Commit 0640f227 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.0' into v0.6.0-dev

parents 82f1ffdf 32e7db25
......@@ -9,7 +9,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase
......
......@@ -160,6 +160,9 @@ def _bgmv_expand(
return
bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
_bgmv_expand,
mutates_args=["output_tensor"])
try:
bgmv_expand = torch.library.custom_op("lora::bgmv_expand",
_bgmv_expand,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand = _bgmv_expand
......@@ -173,6 +173,9 @@ def _bgmv_expand_slice(
return
bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
_bgmv_expand_slice,
mutates_args=["output_tensor"])
try:
bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice",
_bgmv_expand_slice,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_expand_slice = _bgmv_expand_slice
......@@ -142,6 +142,9 @@ def _bgmv_shrink(
return
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
_bgmv_shrink,
mutates_args=["output_tensor"])
try:
bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink",
_bgmv_shrink,
mutates_args=["output_tensor"])
except AttributeError:
bgmv_shrink = _bgmv_shrink
......@@ -192,6 +192,9 @@ def _sgmv_expand(
return
sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
_sgmv_expand,
mutates_args=["output_tensor"])
try:
sgmv_expand = torch.library.custom_op("lora::sgmv_expand",
_sgmv_expand,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand = _sgmv_expand
......@@ -205,6 +205,9 @@ def _sgmv_expand_slice(
return
sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
_sgmv_expand_slice,
mutates_args=["output_tensor"])
try:
sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice",
_sgmv_expand_slice,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_expand_slice = _sgmv_expand_slice
......@@ -189,6 +189,9 @@ def _sgmv_shrink(
return
sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
_sgmv_shrink,
mutates_args=["output_tensor"])
try:
sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink",
_sgmv_shrink,
mutates_args=["output_tensor"])
except AttributeError:
sgmv_shrink = _sgmv_shrink
......@@ -10,10 +10,8 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
import torch
from vllm.triton_utils import HAS_TRITON
from vllm.utils import is_xpu
# FIXME: xpu path doesn't support torch.library.custom_op
if HAS_TRITON and not is_xpu():
if HAS_TRITON:
from vllm.lora.ops.bgmv_expand import bgmv_expand
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
......
......@@ -5,9 +5,6 @@ from vllm.entrypoints.openai.protocol import (
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor
......@@ -18,6 +15,9 @@ async def get_guided_decoding_logits_processor(
request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
......@@ -37,6 +37,9 @@ def get_local_guided_decoding_logits_processor(
# request = _adapt_request_for_tool_use(request)
if guided_decoding_backend == 'outlines':
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
if guided_decoding_backend == 'lm-format-enforcer':
......@@ -56,8 +59,9 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest,
if type(request) is CompletionRequest:
return request
# user has chosen to not use any tool
if request.tool_choice == "none":
# user has chosen to not use any tool,
# OR is allowing the model to choose a tool.
if request.tool_choice == "none" or request.tool_choice == "auto":
return request
# user has chosen to use a named tool
......
......@@ -14,9 +14,6 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor,
get_outlines_guided_decoding_logits_processor)
from vllm.sampling_params import LogitsProcessor
......@@ -43,12 +40,23 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
character_level_parser = RegexParser(request.guided_regex)
elif request.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_outlines_guided_decoding_logits_processor)
return await get_outlines_guided_decoding_logits_processor(
request, tokenizer)
elif (request.response_format is not None
and request.response_format.type == "json_object"):
character_level_parser = JsonSchemaParser(
None) # None means any json object
elif (request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
schema = _normalize_json_schema_object(
request.response_format.json_schema.json_schema)
character_level_parser = JsonSchemaParser(schema)
else:
return None
......@@ -80,6 +88,10 @@ def get_local_lm_format_enforcer_guided_decoding_logits_processor(
character_level_parser = RegexParser(guided_options.guided_regex)
elif guided_options.guided_grammar:
# CFG grammar not supported by LMFE, revert to outlines
# NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193
from vllm.model_executor.guided_decoding.outlines_decoding import (
get_local_outlines_guided_decoding_logits_processor)
return get_local_outlines_guided_decoding_logits_processor(
guided_options, tokenizer)
elif guided_options.guided_json_object:
......
......@@ -8,8 +8,9 @@ from typing import Tuple, Union
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest)
from vllm.entrypoints.openai.protocol import (
ChatCompletionNamedToolChoiceParam, ChatCompletionRequest,
CompletionRequest)
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest)
from vllm.model_executor.guided_decoding.outlines_logits_processors import (
......@@ -101,16 +102,30 @@ def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest,
GuidedDecodingRequest]
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:
# if the request is a chat completion request, AND the tool choice is a
# named tool choice, do guided decoding
# using that tool as the JSON schema
if isinstance(request, ChatCompletionRequest) and isinstance(
request.tool_choice, ChatCompletionNamedToolChoiceParam):
# Guided generation for tools/functions parameters
if request.tool_choice.type == "function":
for tool in request.tools:
if (tool.type == "function" and tool.function.name
== request.tool_choice.function.name):
json = json_dumps(tool.function.parameters, sort_keys=True)
return json, GuidedDecodingMode.JSON
return None, None
if request.guided_json:
json = request.guided_json
if isinstance(json, dict):
elif request.guided_json:
if isinstance(request.guided_json, dict):
# turn dict into hashable string
json = json_dumps(json)
elif isinstance(json, BaseModel):
json = json_dumps(request.guided_json)
elif isinstance(request.guided_json, BaseModel):
# use pydantic signature so that different model classes
# with the same fields will get hashed the same
json = str(json.__signature__)
json = str(request.guided_json.__signature__)
else:
json = request.guided_json
return json, GuidedDecodingMode.JSON
elif request.guided_regex:
return request.guided_regex, GuidedDecodingMode.REGEX
......@@ -127,6 +142,13 @@ def _get_guide_and_mode(
and request.response_format is not None
and request.response_format.type == "json_object"):
return JSON_GRAMMAR, GuidedDecodingMode.GRAMMAR
elif (not isinstance(request, GuidedDecodingRequest)
and request.response_format is not None
and request.response_format.type == "json_schema"
and request.response_format.json_schema is not None
and request.response_format.json_schema.json_schema is not None):
json = json_dumps(request.response_format.json_schema.json_schema)
return json, GuidedDecodingMode.JSON
else:
return None, None
......
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.triton_utils import HAS_TRITON
__all__ = [
"FusedMoE",
"FusedMoEMethodBase",
]
__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"]
if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk)
fused_experts, fused_marlin_moe, fused_moe, fused_topk,
get_config_file_name, grouped_topk)
__all__ += [
"fused_marlin_moe",
"fused_moe",
"fused_topk",
"fused_experts",
......
{
"3328": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2560": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3584": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3840": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"2304": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
}
}
\ No newline at end of file
{
"3840": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3584": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"768": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3328": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"2560": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"2304": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
}
}
\ No newline at end of file
{
"2048": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"1792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"3328": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2560": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 4
},
"768": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"2816": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2304": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2
},
"1280": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3840": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3584": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
\ No newline at end of file
......@@ -2,7 +2,7 @@
import functools
import json
import os
from typing import Any, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple
import torch
import triton
......@@ -323,21 +323,16 @@ def get_moe_configs(E: int, N: int,
return None
def get_default_config(
M: int,
E: int,
N: int,
K: int,
topk: int,
dtype: Optional[str],
) -> Dict[str, int]:
def get_default_config(M: int, E: int, N: int, K: int, topk: int,
dtype: Optional[str],
is_marlin: bool) -> Dict[str, int]:
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}
if M <= E:
if M <= E or (is_marlin and M <= 32):
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
......@@ -347,14 +342,14 @@ def get_default_config(
return config
def try_get_optimal_moe_config(
w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
override_config: Optional[Dict[str, Any]] = None,
):
def try_get_optimal_moe_config(w1_shape: Tuple[int, ...],
w2_shape: Tuple[int, ...],
top_k: int,
dtype: Optional[str],
M: int,
override_config: Optional[Dict[str,
Any]] = None,
is_marlin: bool = False):
if override_config:
config = override_config
else:
......@@ -368,7 +363,8 @@ def try_get_optimal_moe_config(
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
is_marlin)
return config
......@@ -441,6 +437,113 @@ def grouped_topk(hidden_states: torch.Tensor,
return topk_weights, topk_ids
def fused_marlin_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
g_idx1: torch.Tensor,
g_idx2: torch.Tensor,
rand_perm1: torch.Tensor,
rand_perm2: torch.Tensor,
topk: int,
custom_routing_function: Optional[Callable] = None,
renormalize: bool = True,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[
1] == w1.shape[1] * 16, "Hidden size mismatch w1"
assert hidden_states.shape[
1] == w2.shape[2] // 2, "Hidden size mismatch w2"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
#TODO fp8 is not implemented yet
assert not use_fp8
M, K = hidden_states.shape
E = w1.shape[0]
N = w2.shape[1] * 16
if custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)
get_config_func = functools.partial(try_get_optimal_moe_config,
w1.shape,
w2.shape,
topk_ids.shape[1],
"float8" if use_fp8 else None,
override_config=override_config,
is_marlin=True)
config = get_config_func(M)
block_size_m = config['BLOCK_SIZE_M']
sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E)
max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda",
requires_grad=False)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe(
hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale,
g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk,
block_size_m, True, False)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe(
intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids,
w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk,
block_size_m, False, True)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
def get_config_dtype_str(dtype: torch.dtype,
use_int8_w8a16: Optional[bool] = False,
use_fp8_w8a8: Optional[bool] = False):
......@@ -597,6 +700,7 @@ def fused_moe(
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
w1_scale: Optional[torch.Tensor] = None,
......@@ -644,9 +748,12 @@ def fused_moe(
topk_weights, topk_ids = grouped_topk(hidden_states, gating_output,
topk, renormalize,
num_expert_group, topk_group)
else:
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)
return fused_experts(hidden_states,
w1,
......
from abc import abstractmethod
from typing import List, Optional, Tuple
from enum import Enum
from typing import Callable, List, Optional, Tuple
import torch
......@@ -15,6 +16,12 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class FusedMoeWeightScaleSupported(Enum):
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
class FusedMoEMethodBase(QuantizeMethodBase):
@abstractmethod
......@@ -55,15 +62,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None) -> torch.Tensor:
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
return self.forward(x=x,
layer=layer,
......@@ -72,17 +82,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group)
def forward_cuda(self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None) -> torch.Tensor:
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
def forward_cuda(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts)
......@@ -94,7 +108,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group)
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
......@@ -107,20 +122,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
raise NotImplementedError(
"The CPU backend currently does not support MoE.")
def forward_tpu(self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None) -> torch.Tensor:
def forward_tpu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
return fused_moe(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
......@@ -165,6 +184,7 @@ class FusedMoE(torch.nn.Module):
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
):
super().__init__()
......@@ -183,6 +203,7 @@ class FusedMoE(torch.nn.Module):
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
......@@ -199,55 +220,182 @@ class FusedMoE(torch.nn.Module):
params_dtype=params_dtype,
weight_loader=self.weight_loader)
def _load_per_tensor_weight_scale(self, shard_id: str,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
expert_id: int):
param_data = param.data
# for per tensor weight quantization
if shard_id in ("w1", "w3"):
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
idx = 0 if shard_id == "w1" else 1
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
elif shard_id == "w2":
param_data[expert_id] = loaded_weight
def _load_model_weight_or_group_weight_scale(self, shard_dim: int,
expert_data: torch.Tensor,
shard_id: str,
loaded_weight: torch.tensor,
tp_rank: int):
# Load grouped weight scales for group quantization
# or model weights
if shard_id == "w2":
self._load_w2(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
def _load_per_channel_weight_scale(self, expert_data: torch.Tensor,
shard_dim: int, shard_id: str,
loaded_weight: torch.tensor,
tp_rank: int):
# for per channel weight quantization
if shard_id == "w2":
expert_data.copy_(loaded_weight)
elif shard_id in ("w1", "w3"):
self._load_w13(shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
# w3, up_proj: Load into second logical weight of w13.
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
def _load_w2(self, expert_data: torch.Tensor, shard_dim: int,
shard_id: str, loaded_weight: torch.tensor, tp_rank: int):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
def _load_single_value(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, expert_id: int):
param_data = param.data
# Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight
def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
shard_id: str, expert_id: int) -> None:
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
f"got {shard_id}.")
# Special case for fp8 scales.
if getattr(param, "is_fp8_scale", False):
self._load_fp8_scale(param.data, loaded_weight, weight_name,
shard_id, expert_id)
return
WEIGHT_SCALE_SUPPORTED = [
e.value for e in FusedMoeWeightScaleSupported
]
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size is used.
SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0}
expert_data = param.data[expert_id]
tp_rank = get_tensor_model_parallel_rank()
# If transposed, weight is saved as [input_dim, output_dim]
# Otherwise, weight is saved as [output_dim, input_dim]
# Default is not transposed/input dim is dim 1
input_dim = getattr(param, "input_dim", 1)
output_dim = getattr(param, "output_dim", 0)
# is_transposed: whether or not the parameter is transposed on disk
# If transposed, the loaded weight will be transposed and the dim
# to shard the loaded weight will be flipped.
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
loaded_weight = loaded_weight.t().contiguous()
shard_dim = ~shard_dim
# Case weight_scales
if "weight_scale" in weight_name:
# load the weight scaling based on the quantization scheme
# supported weight scales can be found in
# FusedMoeWeightScaleSupported
# TODO @dsikka: once hardened, refactor to use vLLM Parameters
# specific to each case
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
self._load_per_channel_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif quant_method == FusedMoeWeightScaleSupported.GROUP.value:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
self._load_per_tensor_weight_scale(shard_id=shard_id,
param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
else:
raise ValueError(
f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}")
return
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
if shard_id == "w2":
shard_dim = input_dim
shard_size = expert_data.shape[shard_dim]
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
elif shard_id in ("w1", "w3"):
shard_dim = output_dim
shard_size = expert_data.shape[output_dim] // 2
offset = shard_size * tp_rank
loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size)
if "weight_shape" in weight_name:
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
if shard_id == "w1":
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
expert_data.copy_(loaded_weight)
# w3, up_proj: Load into second logical weight of w13.
elif shard_id == "w3":
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
# w2, down_proj: Load into only logical weight of w2.
elif shard_id == "w2":
expert_data.copy_(loaded_weight)
else:
raise ValueError(
f"Expected shard_id w1,w2 or w3 but got {shard_id}")
# Case input scale
if "input_scale" in weight_name:
# Note: input_scale loading is only supported for fp8
if param.data[expert_id] != 1 and (param.data[expert_id] -
loaded_weight).abs() > 1e-5:
raise ValueError(
"input_scales of w1 and w3 of a layer "
f"must be equal. But got {param.data[expert_id]} "
f"vs. {loaded_weight}")
self._load_single_value(param=param,
loaded_weight=loaded_weight,
expert_id=expert_id)
return
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank)
return
@staticmethod
def select_experts(hidden_states: torch.Tensor,
......@@ -256,7 +404,8 @@ class FusedMoE(torch.nn.Module):
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None):
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk, grouped_topk)
......@@ -271,11 +420,17 @@ class FusedMoE(torch.nn.Module):
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group)
else:
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
return topk_weights, topk_ids
......@@ -292,7 +447,8 @@ class FusedMoE(torch.nn.Module):
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group)
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function)
if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
......@@ -342,4 +498,4 @@ class FusedMoE(torch.nn.Module):
param_data[expert_id][idx] = loaded_weight
# If we are in the row parallel case (down_proj)
else:
param_data[expert_id] = loaded_weight
param_data[expert_id] = loaded_weight
\ No newline at end of file
......@@ -14,8 +14,10 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (BasevLLMParameter,
PackedColumnParameter,
PackedvLLMParameter,
PerTensorScaleParameter)
PerTensorScaleParameter,
RowvLLMParameter)
from vllm.model_executor.utils import set_weight_attrs
import os
......@@ -26,7 +28,8 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod"
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod"
]
......@@ -38,9 +41,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
def adjust_bitsandbytes_shard(param: Parameter,
qkv_offsets: Dict[str, Tuple[int, int]],
loaded_shard_id: str) -> Tuple[int, int]:
def adjust_bitsandbytes_4bit_shard(param: Parameter,
qkv_offsets: Dict[str, Tuple[int, int]],
loaded_shard_id: str) -> Tuple[int, int]:
"""Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
total, _ = qkv_offsets["total"]
......@@ -227,8 +230,7 @@ class ReplicatedLinear(LinearBase):
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader,
prefix=prefix)
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
......@@ -326,8 +328,7 @@ class ColumnParallelLinear(LinearBase):
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
......@@ -525,8 +526,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
......@@ -593,8 +595,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, PackedvLLMParameter
) and param.packed_dim == param.output_dim:
if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
)) and param.packed_dim == param.output_dim:
shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)
......@@ -613,9 +615,10 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0)
return
elif type(param) is BasevLLMParameter:
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
......@@ -743,8 +746,8 @@ class QKVParallelLinear(ColumnParallelLinear):
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if isinstance(param, PackedvLLMParameter
) and param.packed_dim == param.output_dim:
if isinstance(param, (PackedColumnParameter, PackedvLLMParameter
)) and param.packed_dim == param.output_dim:
shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)
......@@ -760,12 +763,12 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: Optional[str] = None):
if loaded_shard_id is None: # special case for certain models
if isinstance(param, PerTensorScaleParameter):
param.load_merged_column_weight(loaded_weight=loaded_weight,
shard_id=0)
param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
return
elif type(param) is BasevLLMParameter:
param.load_merged_column_weight(loaded_weight=loaded_weight)
elif type(param) in (RowvLLMParameter, BasevLLMParameter):
param.load_qkv_weight(loaded_weight=loaded_weight)
return
# TODO: @dsikka - move to parameter.py
self._load_fused_module_from_checkpoint(param, loaded_weight)
return
......@@ -878,8 +881,9 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
if use_bitsandbytes:
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
if use_bitsandbytes_4bit:
orig_qkv_offsets = {
"q": (0, self.num_heads * self.head_size),
"k": (self.num_heads * self.head_size,
......@@ -891,7 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
0)
}
shard_size, shard_offset = adjust_bitsandbytes_shard(
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id)
if is_gguf_weight:
......@@ -995,8 +999,7 @@ class RowParallelLinear(LinearBase):
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader),
prefix=prefix)
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment