Commit 0640f227 authored by zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
...@@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ...@@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase from vllm.worker.worker_base import WorkerBase
......
...@@ -160,6 +160,9 @@ def _bgmv_expand( ...@@ -160,6 +160,9 @@ def _bgmv_expand(
return return
bgmv_expand = torch.library.custom_op("lora::bgmv_expand", try:
_bgmv_expand, bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
mutates_args=["output_tensor"]) _bgmv_expand,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand = _bgmv_expand
...@@ -173,6 +173,9 @@ def _bgmv_expand_slice( ...@@ -173,6 +173,9 @@ def _bgmv_expand_slice(
return return
bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", try:
_bgmv_expand_slice, bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
mutates_args=["output_tensor"]) _bgmv_expand_slice,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand_slice = _bgmv_expand_slice
...@@ -142,6 +142,9 @@ def _bgmv_shrink( ...@@ -142,6 +142,9 @@ def _bgmv_shrink(
return return
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", try:
_bgmv_shrink, bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
mutates_args=["output_tensor"]) _bgmv_shrink,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_shrink = _bgmv_shrink
...@@ -192,6 +192,9 @@ def _sgmv_expand( ...@@ -192,6 +192,9 @@ def _sgmv_expand(
return return
sgmv_expand = torch.library.custom_op("lora::sgmv_expand", try:
_sgmv_expand, sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
mutates_args=["output_tensor"]) _sgmv_expand,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand = _sgmv_expand
...@@ -205,6 +205,9 @@ def _sgmv_expand_slice( ...@@ -205,6 +205,9 @@ def _sgmv_expand_slice(
return return
sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", try:
_sgmv_expand_slice, sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
mutates_args=["output_tensor"]) _sgmv_expand_slice,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand_slice = _sgmv_expand_slice
...@@ -189,6 +189,9 @@ def _sgmv_shrink( ...@@ -189,6 +189,9 @@ def _sgmv_shrink(
return return
sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", try:
_sgmv_shrink, sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
mutates_args=["output_tensor"]) _sgmv_shrink,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_shrink = _sgmv_shrink
...@@ -10,10 +10,8 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union ...@@ -10,10 +10,8 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
import torch import torch
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm.utils import is_xpu
# FIXME: xpu path doesn't support torch.library.custom_op if HAS_TRITON:
if HAS_TRITON and not is_xpu():
from vllm.lora.ops.bgmv_expand import bgmv_expand from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.bgmv_shrink import bgmv_shrink
......
...@@ -5,9 +5,6 @@ from vllm.entrypoints.openai.protocol import ( ...@@ -5,9 +5,6 @@ from vllm.entrypoints.openai.protocol import (
CompletionRequest) CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import ( from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest) GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor from vllm.sampling_params import LogitsProcessor
...@@ -18,6 +15,9 @@ async def get_guided_decoding_logits_processor( ...@@ -18,6 +15,9 @@ async def get_guided_decoding_logits_processor(
request = _adapt_request_for_tool_use(request) request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines': if guided_decoding_backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor( return await get_outlines_guided_decoding_logits_processor(
request, tokenizer) request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer': if guided_decoding_backend == 'lm-format-enforcer':
...@@ -37,6 +37,9 @@ def get_local_guided_decoding_logits_processor( ...@@ -37,6 +37,9 @@ def get_local_guided_decoding_logits_processor(
# request = _adapt_request_for_tool_use(request) # request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines': if guided_decoding_backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor( return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer) guided_options, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer': if guided_decoding_backend == 'lm-format-enforcer':
...@@ -56,8 +59,9 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest, ...@@ -56,8 +59,9 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest,
if type(request) is CompletionRequest: if type(request) is CompletionRequest:
return request return request
# user has chosen to not use any tool # user has chosen to not use any tool,
if request.tool_choice == "none": # OR is allowing the model to choose a tool.
if request.tool_choice == "none" or request.tool_choice == "auto":
return request return request
# user has chosen to use a named tool # user has chosen to use a named tool
......
...@@ -14,9 +14,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ...@@ -14,9 +14,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest) CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import ( from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest) GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor from vllm.sampling_params import LogitsProcessor
...@@ -43,12 +40,23 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( ...@@ -43,12 +40,23 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
character_level_parser = RegexParser(request.guided_regex) character_level_parser = RegexParser(request.guided_regex)
elif request.guided_grammar: elif request.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines # CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor( return await get_outlines_guided_decoding_logits_processor(
request, tokenizer) request, tokenizer)
elif (request.response_format is not None elif (request.response_format is not None
and request.response_format.type == "json_object"): and request.response_format.type == "json_object"):
character_level_parser = JsonSchemaParser( character_level_parser = JsonSchemaParser(
None) # None means any json object None) # None means any json object
elif (request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
schema = _normalize_json_schema_object(
request.response_format.json_schema.json_schema)
character_level_parser = JsonSchemaParser(schema)
else: else:
return None return None
...@@ -80,6 +88,10 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor( ...@@ -80,6 +88,10 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
character_level_parser = RegexParser(guided_options.guided_regex) character_level_parser = RegexParser(guided_options.guided_regex)
elif guided_options.guided_grammar: elif guided_options.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines # CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor( return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer) guided_options, tokenizer)
elif guided_options.guided_json_object: elif guided_options.guided_json_object:
......
...@@ -8,8 +8,9 @@ from typing import Tuple, Union ...@@ -8,8 +8,9 @@ from typing import Tuple, Union
from pydantic import BaseModel from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, from vllm.entrypoints.openai.protocol import (
CompletionRequest) ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import ( from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest) GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
...@@ -101,16 +102,30 @@ def _get_guide_and_mode( ...@@ -101,16 +102,30 @@ def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest, request: Union[CompletionRequest, ChatCompletionRequest,
GuidedDecodingRequest] GuidedDecodingRequest]
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
# if the request is a chat completion request, AND the tool choice is a
# named tool choice, do guided decoding
# using that tool as the JSON schema
if isinstance(request, ChatCompletionRequest) and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam):
# Guided generation for tools/functions parameters
if request.tool_choice.type == "function":
for tool in request.tools:
if (tool.type == "function" and tool.function.name
== request.tool_choice.function.name):
json = json_dumps(tool.function.parameters, sort_keys=True)
return json, GuidedDecodingMode.JSON
return None, None
if request.guided_json: elif request.guided_json:
json = request.guided_json if isinstance(request.guided_json, dict):
if isinstance(json, dict):
# turn dict into hashable string # turn dict into hashable string
json = json_dumps(json) json = json_dumps(request.guided_json)
elif isinstance(json, BaseModel): elif isinstance(request.guided_json, BaseModel):
# use pydantic signature so that different model classes # use pydantic signature so that different model classes
# with the same fields will get hashed the same # with the same fields will get hashed the same
json = str(json.__signature__) json = str(request.guided_json.__signature__)
else:
json = request.guided_json
return json, GuidedDecodingMode.JSON return json, GuidedDecodingMode.JSON
elif request.guided_regex: elif request.guided_regex:
return request.guided_regex, GuidedDecodingMode.REGEX return request.guided_regex, GuidedDecodingMode.REGEX
...@@ -127,6 +142,13 @@ def _get_guide_and_mode( ...@@ -127,6 +142,13 @@ def _get_guide_and_mode(
and request.response_format is not None and request.response_format is not None
and request.response_format.type == "json_object"): and request.response_format.type == "json_object"):
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
json = json_dumps(request.response_format.json_schema.json_schema)
return json, GuidedDecodingMode.JSON
else: else:
return None, None return None, None
......
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (
FusedMoEMethodBase) FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
__all__ = [ __all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"]
"FusedMoE",
"FusedMoEMethodBase",
]
if HAS_TRITON: if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name, fused_experts, fused_marlin_moe, fused_moe, fused_topk,
grouped_topk) get_config_file_name, grouped_topk)
__all__ += [ __all__ += [
"fused_marlin_moe",
"fused_moe", "fused_moe",
"fused_topk", "fused_topk",
"fused_experts", "fused_experts",
......
{
"3328": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2560": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3584": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3840": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"2304": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
}
}
\ No newline at end of file
{
"3840": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3584": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3328": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2560": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"2304": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
}
}
\ No newline at end of file
{
"2048": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3328": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2560": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"768": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2304": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3840": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3584": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
\ No newline at end of file
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import functools import functools
import json import json
import os import os
from typing import Any, Dict, Optional, Tuple from typing import Any, Callable, Dict, Optional, Tuple
import torch import torch
import triton import triton
...@@ -323,21 +323,16 @@ def get_moe_configs(E: int, N: int, ...@@ -323,21 +323,16 @@ def get_moe_configs(E: int, N: int,
return None return None
def get_default_config( def get_default_config(M: int, E: int, N: int, K: int, topk: int,
M: int, dtype: Optional[str],
E: int, is_marlin: bool) -> Dict[str, int]:
N: int,
K: int,
topk: int,
dtype: Optional[str],
) -> Dict[str, int]:
config = { config = {
'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32, 'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8 'GROUP_SIZE_M': 8
} }
if M <= E: if M <= E or (is_marlin and M <= 32):
config = { config = {
'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_N': 32,
...@@ -347,14 +342,14 @@ def get_default_config( ...@@ -347,14 +342,14 @@ def get_default_config(
return config return config
def try_get_optimal_moe_config( def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
w1_shape: Tuple[int, ...], w2_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...], top_k: int,
top_k: int, dtype: Optional[str],
dtype: Optional[str], M: int,
M: int, override_config: Optional[Dict[str,
override_config: Optional[Dict[str, Any]] = None, Any]] = None,
): is_marlin: bool = False):
if override_config: if override_config:
config = override_config config = override_config
else: else:
...@@ -368,7 +363,8 @@ def try_get_optimal_moe_config( ...@@ -368,7 +363,8 @@ def try_get_optimal_moe_config(
config = configs[min(configs.keys(), key=lambda x: abs(x - M))] config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else: else:
# Else use the default config # Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype) config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
is_marlin)
return config return config
...@@ -441,6 +437,113 @@ def grouped_topk(hidden_states: torch.Tensor, ...@@ -441,6 +437,113 @@ def grouped_topk(hidden_states: torch.Tensor,
return topk_weights, topk_ids return topk_weights, topk_ids
def fused_marlin_moe(hidden_states: torch.Tensor,
                     w1: torch.Tensor,
                     w2: torch.Tensor,
                     gating_output: torch.Tensor,
                     g_idx1: torch.Tensor,
                     g_idx2: torch.Tensor,
                     rand_perm1: torch.Tensor,
                     rand_perm2: torch.Tensor,
                     topk: int,
                     custom_routing_function: Optional[Callable] = None,
                     renormalize: bool = True,
                     override_config: Optional[Dict[str, Any]] = None,
                     use_fp8: bool = False,
                     w1_scale: Optional[torch.Tensor] = None,
                     w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and top-k gating mechanism, dispatching both expert
    GEMMs to the ``marlin_gemm_moe`` custom op.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
        Must be 2-D (tokens x hidden size), contiguous, and of dtype
        float32, float16 or bfloat16.
    - w1 (torch.Tensor): The first set of expert weights (contiguous).
        Its leading dimension is the number of experts.
    - w2 (torch.Tensor): The second set of expert weights (contiguous).
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax).
    - g_idx1 (torch.Tensor): Forwarded unchanged to the w1 GEMM.
        NOTE(review): presumably the quantization group indices for w1 --
        confirm against the kernel.
    - g_idx2 (torch.Tensor): Same as g_idx1, for the w2 GEMM.
    - rand_perm1 (torch.Tensor): Forwarded unchanged to the w1 GEMM.
        NOTE(review): presumably the row permutation matching g_idx1 --
        confirm against the kernel.
    - rand_perm2 (torch.Tensor): Same as rand_perm1, for the w2 GEMM.
    - topk (int): The number of top-k experts to select.
    - custom_routing_function (Optional[Callable]): If given, called as
        ``f(hidden_states, gating_output, topk, renormalize)`` and must
        return ``(topk_weights, topk_ids)``; replaces ``fused_topk``.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - use_fp8 (bool): Reserved for fp8 arithmetic. Not implemented yet;
        must be False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
        w2.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.

    Raises:
    - AssertionError: If any shape, contiguity or dtype constraint fails,
        or if use_fp8 is True.
    """
    # Check constraints.
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")
    # NOTE(review): the `* 16` and `// 2` factors presumably reflect the
    # packed Marlin weight layout -- confirm against the weight packing code.
    assert hidden_states.shape[
        1] == w1.shape[1] * 16, "Hidden size mismatch w1"
    assert hidden_states.shape[
        1] == w2.shape[2] // 2, "Hidden size mismatch w2"
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]

    # TODO fp8 is not implemented yet
    assert not use_fp8

    M, K = hidden_states.shape
    E = w1.shape[0]
    N = w2.shape[1] * 16

    # Expert selection: built-in top-k unless the caller supplies a router.
    if custom_routing_function is None:
        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                            renormalize)
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states, gating_output, topk, renormalize)

    config = try_get_optimal_moe_config(w1.shape,
                                        w2.shape,
                                        topk_ids.shape[1],
                                        "float8" if use_fp8 else None,
                                        M,
                                        override_config=override_config,
                                        is_marlin=True)

    block_size_m = config['BLOCK_SIZE_M']

    sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)

    # Scratch buffer shared by both GEMM calls. Allocate it on the same
    # device as the activations (rather than whatever the current CUDA device
    # happens to be) so every tensor handed to the kernel is co-located.
    max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
    workspace = torch.zeros(max_workspace_size,
                            dtype=torch.int,
                            device=hidden_states.device,
                            requires_grad=False)

    # Output of the activation; fed into the second GEMM.
    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
                                      device=hidden_states.device,
                                      dtype=hidden_states.dtype)

    # First GEMM: produces 2 * N columns that silu_and_mul reduces to N.
    intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
        hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
        g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
        block_size_m, True, False)

    ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))

    # Second GEMM: maps the activated intermediate back to K columns.
    intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
        intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
        w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
        block_size_m, False, True)

    # Reduce over dim 1 (NOTE(review): presumably the per-token top-k axis).
    return torch.sum(intermediate_cache3, dim=1)
def get_config_dtype_str(dtype: torch.dtype, def get_config_dtype_str(dtype: torch.dtype,
use_int8_w8a16: Optional[bool] = False, use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False): use_fp8_w8a8: Optional[bool] = False):
...@@ -597,6 +700,7 @@ def fused_moe( ...@@ -597,6 +700,7 @@ def fused_moe(
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None, w1_scale: Optional[torch.Tensor] = None,
...@@ -644,9 +748,12 @@ def fused_moe( ...@@ -644,9 +748,12 @@ def fused_moe(
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
topk, renormalize, topk, renormalize,
num_expert_group, topk_group) num_expert_group, topk_group)
else: elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize) renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)
return fused_experts(hidden_states, return fused_experts(hidden_states,
w1, w1,
......
from abc import abstractmethod from abc import abstractmethod
from typing import List, Optional, Tuple from enum import Enum
from typing import Callable, List, Optional, Tuple
import torch import torch
...@@ -15,6 +16,12 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -15,6 +16,12 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__) logger = init_logger(__name__)
class FusedMoeWeightScaleSupported(Enum):
    """Enumeration of weight-scale kinds used with fused MoE layers.

    Each member's value is the lowercase string form of its name.

    NOTE(review): the member names suggest the granularity at which weight
    scales are stored (one per tensor, per channel, or per group) -- confirm
    against the weight loaders that consume these values.
    """
    TENSOR = "tensor"
    CHANNEL = "channel"
    GROUP = "group"
class FusedMoEMethodBase(QuantizeMethodBase): class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod @abstractmethod
...@@ -55,15 +62,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -55,15 +62,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
router_logits: torch.Tensor, x: torch.Tensor,
top_k: int, router_logits: torch.Tensor,
renormalize: bool, top_k: int,
use_grouped_topk: bool, renormalize: bool,
topk_group: Optional[int] = None, use_grouped_topk: bool,
num_expert_group: Optional[int] = None) -> torch.Tensor: topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
return self.forward(x=x, return self.forward(x=x,
layer=layer, layer=layer,
...@@ -72,17 +82,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -72,17 +82,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize=renormalize, renormalize=renormalize,
use_grouped_topk=use_grouped_topk, use_grouped_topk=use_grouped_topk,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group) num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
def forward_cuda(self,
layer: torch.nn.Module, def forward_cuda(
x: torch.Tensor, self,
use_grouped_topk: bool, layer: torch.nn.Module,
top_k: int, x: torch.Tensor,
router_logits: torch.Tensor, use_grouped_topk: bool,
renormalize: bool, top_k: int,
topk_group: Optional[int] = None, router_logits: torch.Tensor,
num_expert_group: Optional[int] = None) -> torch.Tensor: renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts) fused_experts)
...@@ -94,7 +108,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -94,7 +108,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
top_k=top_k, top_k=top_k,
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group) num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(hidden_states=x, return fused_experts(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
...@@ -107,20 +122,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -107,20 +122,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
raise NotImplementedError( raise NotImplementedError(
"The CPU backend currently does not support MoE.") "The CPU backend currently does not support MoE.")
def forward_tpu(self, def forward_tpu(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
use_grouped_topk: bool, x: torch.Tensor,
top_k: int, use_grouped_topk: bool,
router_logits: torch.Tensor, top_k: int,
renormalize: bool, router_logits: torch.Tensor,
topk_group: Optional[int] = None, renormalize: bool,
num_expert_group: Optional[int] = None) -> torch.Tensor: topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk assert not use_grouped_topk
assert num_expert_group is None assert num_expert_group is None
assert topk_group is None assert topk_group is None
assert custom_routing_function is None
return fused_moe(hidden_states=x, return fused_moe(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
...@@ -165,6 +184,7 @@ class FusedMoE(torch.nn.Module): ...@@ -165,6 +184,7 @@ class FusedMoE(torch.nn.Module):
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
prefix: str = "", prefix: str = "",
custom_routing_function: Optional[Callable] = None,
): ):
super().__init__() super().__init__()
...@@ -183,6 +203,7 @@ class FusedMoE(torch.nn.Module): ...@@ -183,6 +203,7 @@ class FusedMoE(torch.nn.Module):
assert num_expert_group is not None and topk_group is not None assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group self.num_expert_group = num_expert_group
self.topk_group = topk_group self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
if quant_config is None: if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = ( self.quant_method: Optional[QuantizeMethodBase] = (
...@@ -199,55 +220,182 @@ class FusedMoE(torch.nn.Module): ...@@ -199,55 +220,182 @@ class FusedMoE(torch.nn.Module):
params_dtype=params_dtype, params_dtype=params_dtype,
weight_loader=self.weight_loader) weight_loader=self.weight_loader)
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
expert_id: int):
param_data = param.data
# for per tensor weight quantization
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
elif shard_id == "w2":
param_data[expert_id] = loaded_weight
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.tensor,
tp_rank: int):
# Load grouped weight scales for group quantization
# or model weights
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
shard_dim: int, shard_id: str,
loaded_weight: torch.tensor,
tp_rank: int):
# for per channel weight quantization
if shard_id == "w2":
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
def _load_single_value(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, expert_id: int):
param_data = param.data
# Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight
def weight_loader(self, param: torch.nn.Parameter, def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str, loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None: shard_id: str, expert_id: int) -> None:
if shard_id not in ("w1", "w2", "w3"): if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but " raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.") f"got {shard_id}.")
# Special case for fp8 scales. WEIGHT_SCALE_SUPPORTED = [
if getattr(param, "is_fp8_scale", False): e.value for e in FusedMoeWeightScaleSupported
self._load_fp8_scale(param.data, loaded_weight, weight_name, ]
shard_id, expert_id) # Fetch the dim to shard the parameter/loaded weight
return # based on the shard id. This will be whatever
# dimension intermediate_size is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_data = param.data[expert_id] expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
# If transposed, weight is saved as [input_dim, output_dim] # is_transposed: whether or not the parameter is transposed on disk
# Otherwise, weight is saved as [output_dim, input_dim] # If transposed, the loaded weight will be transposed and the dim
# Default is not transposed/input dim is dim 1 # to shard the loaded weight will be flipped.
input_dim = getattr(param, "input_dim", 1) is_transposed = getattr(param, "is_transposed", False)
output_dim = getattr(param, "output_dim", 0) shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
loaded_weight = loaded_weight.t().contiguous()
shard_dim = ~shard_dim
# Case weight_scales
if "weight_scale" in weight_name:
# load the weight scaling based on the quantization scheme
# supported weight scales can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
else:
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return
# Index the loaded weight for tp sharding. if "weight_shape" in weight_name:
# down_proj: "RowParallel" so tp sharding on input_dim self._load_single_value(param=param,
if shard_id == "w2": loaded_weight=loaded_weight,
shard_dim = input_dim expert_id=expert_id)
shard_size = expert_data.shape[shard_dim] return
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
elif shard_id in ("w1", "w3"):
shard_dim = output_dim
shard_size = expert_data.shape[output_dim] // 2
offset = shard_size * tp_rank
loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)
# Narrow parameter and load. # Case input scale
# w1, gate_proj: Load into first logical weight of w13. if "input_scale" in weight_name:
if shard_id == "w1": # Note: input_scale loading is only supported for fp8
expert_data = expert_data.narrow(shard_dim, 0, shard_size) if param.data[expert_id] != 1 and (param.data[expert_id] -
expert_data.copy_(loaded_weight) loaded_weight).abs() > 1e-5:
# w3, up_proj: Load into second logical weight of w13. raise ValueError(
elif shard_id == "w3": "input_scales of w1 and w3 of a layer "
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) f"must be equal. But got {param.data[expert_id]} "
expert_data.copy_(loaded_weight) f"vs. {loaded_weight}")
# w2, down_proj: Load into only logical weight of w2.
elif shard_id == "w2": self._load_single_value(param=param,
expert_data.copy_(loaded_weight) loaded_weight=loaded_weight,
else: expert_id=expert_id)
raise ValueError( return
f"Expected shard_id w1,w2 or w3 but got {shard_id}")
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
return
@staticmethod @staticmethod
def select_experts(hidden_states: torch.Tensor, def select_experts(hidden_states: torch.Tensor,
...@@ -256,7 +404,8 @@ class FusedMoE(torch.nn.Module): ...@@ -256,7 +404,8 @@ class FusedMoE(torch.nn.Module):
use_grouped_topk: bool, use_grouped_topk: bool,
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None): num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk) fused_topk, grouped_topk)
...@@ -271,11 +420,17 @@ class FusedMoE(torch.nn.Module): ...@@ -271,11 +420,17 @@ class FusedMoE(torch.nn.Module):
renormalize=renormalize, renormalize=renormalize,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group) topk_group=topk_group)
else: elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=top_k,
renormalize=renormalize) renormalize=renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
return topk_weights, topk_ids return topk_weights, topk_ids
...@@ -292,7 +447,8 @@ class FusedMoE(torch.nn.Module): ...@@ -292,7 +447,8 @@ class FusedMoE(torch.nn.Module):
renormalize=self.renormalize, renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group, topk_group=self.topk_group,
num_expert_group=self.num_expert_group) num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
...@@ -342,4 +498,4 @@ class FusedMoE(torch.nn.Module): ...@@ -342,4 +498,4 @@ class FusedMoE(torch.nn.Module):
param_data[expert_id][idx] = loaded_weight param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj) # If we are in the row parallel case (down_proj)
else: else:
param_data[expert_id] = loaded_weight param_data[expert_id] = loaded_weight
\ No newline at end of file
...@@ -14,8 +14,10 @@ from vllm.logger import init_logger ...@@ -14,8 +14,10 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
PackedColumnParameter,
PackedvLLMParameter, PackedvLLMParameter,
PerTensorScaleParameter) PerTensorScaleParameter,
RowvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
import os import os
...@@ -26,7 +28,8 @@ logger = init_logger(__name__) ...@@ -26,7 +28,8 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [ WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod" "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod"
] ]
...@@ -38,9 +41,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset): ...@@ -38,9 +41,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
def adjust_bitsandbytes_shard(param: Parameter, def adjust_bitsandbytes_4bit_shard(param: Parameter,
qkv_offsets: Dict[str, Tuple[int, int]], qkv_offsets: Dict[str, Tuple[int, int]],
loaded_shard_id: str) -> Tuple[int, int]: loaded_shard_id: str) -> Tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
total, _ = qkv_offsets["total"] total, _ = qkv_offsets["total"]
...@@ -227,8 +230,7 @@ class ReplicatedLinear(LinearBase): ...@@ -227,8 +230,7 @@ class ReplicatedLinear(LinearBase):
self.input_size, self.input_size,
self.output_size, self.output_size,
self.params_dtype, self.params_dtype,
weight_loader=self.weight_loader, weight_loader=self.weight_loader)
prefix=prefix)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
...@@ -326,8 +328,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -326,8 +328,7 @@ class ColumnParallelLinear(LinearBase):
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
weight_loader=( weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__ self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
prefix=prefix)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size_per_partition, torch.empty(self.output_size_per_partition,
...@@ -525,8 +526,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -525,8 +526,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
if use_bitsandbytes: False)
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim] shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \ shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id loaded_shard_id
...@@ -593,8 +595,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -593,8 +595,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Quantization. # Special case for Quantization.
# If quantized, we need to adjust the offset and size to account # If quantized, we need to adjust the offset and size to account
# for the packing. # for the packing.
if isinstance(param, PackedvLLMParameter if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
) and param.packed_dim == param.output_dim: )) and param.packed_dim == param.output_dim:
shard_size, shard_offset = \ shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing( param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset) shard_size=shard_size, shard_offset=shard_offset)
...@@ -613,9 +615,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -613,9 +615,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param.load_merged_column_weight(loaded_weight=loaded_weight, param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0) shard_id=0)
return return
elif type(param) is BasevLLMParameter: elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight) param.load_merged_column_weight(loaded_weight=loaded_weight)
return return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight) self._load_fused_module_from_checkpoint(param, loaded_weight)
return return
...@@ -743,8 +746,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -743,8 +746,8 @@ class QKVParallelLinear(ColumnParallelLinear):
# Special case for Quantization. # Special case for Quantization.
# If quantized, we need to adjust the offset and size to account # If quantized, we need to adjust the offset and size to account
# for the packing. # for the packing.
if isinstance(param, PackedvLLMParameter if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
) and param.packed_dim == param.output_dim: )) and param.packed_dim == param.output_dim:
shard_size, shard_offset = \ shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing( param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset) shard_size=shard_size, shard_offset=shard_offset)
...@@ -760,12 +763,12 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -760,12 +763,12 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: Optional[str] = None): loaded_shard_id: Optional[str] = None):
if loaded_shard_id is None: # special case for certain models if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter): if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight, param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
shard_id=0)
return return
elif type(param) is BasevLLMParameter: elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight) param.load_qkv_weight(loaded_weight=loaded_weight)
return return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight) self._load_fused_module_from_checkpoint(param, loaded_weight)
return return
...@@ -878,8 +881,9 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -878,8 +881,9 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard( shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset) param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
if use_bitsandbytes: False)
if use_bitsandbytes_4bit:
orig_qkv_offsets = { orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size), "q": (0, self.num_heads * self.head_size),
"k": (self.num_heads * self.head_size, "k": (self.num_heads * self.head_size,
...@@ -891,7 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -891,7 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
((self.num_heads + 2 * self.num_kv_heads) * self.head_size, ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
0) 0)
} }
shard_size, shard_offset = adjust_bitsandbytes_shard( shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id) param, orig_qkv_offsets, loaded_shard_id)
if is_gguf_weight: if is_gguf_weight:
...@@ -995,8 +999,7 @@ class RowParallelLinear(LinearBase): ...@@ -995,8 +999,7 @@ class RowParallelLinear(LinearBase):
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
weight_loader=( weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__ self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
prefix=prefix)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results") "results can lead to incorrect results")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment