Commit 6880bf15 authored by zhuwenwen's avatar zhuwenwen
Browse files

[Model] Add VLLM_USE_NN to use nn layout

parent fafe3ca7
...@@ -12,10 +12,10 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention ...@@ -12,10 +12,10 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| 结构 | 模型 | FP16/BF16 | AWQ | GPTQ | 支持版本 | 是否优化 | | 结构 | 模型 | FP16/BF16 | AWQ | GPTQ | 支持版本 | 是否优化 |
| :------: | :------: | :------: | :------: |:------: | :------: |:------: | | :------: | :------: | :------: | :------: |:------: | :------: |:------: |
| LlamaForCausalLM | Llama 3.2, Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,DeepSeek-R1-Distill-Llama | Yes | Yes | Yes | v0.5.0,Llama 3.2>=v0.6.2 | Yes | | LlamaForCausalLM | Llama 3.2, Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,DeepSeek-R1-Distill-Llama | Yes | Yes | Yes | v0.5.0,Llama 3.2>=v0.6.2 | Yes |
| Llama4ForConditionalGeneration | Llama 4 | No/Yes | - | - | v0.8.5.post1 | No | | Llama4ForConditionalGeneration | Llama 4 | No/Yes | - | - | v0.8.5.post1 | No |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes | v0.5.0,Qwen-VL>=v0.6.2 | Yes | | QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes | v0.5.0,Qwen-VL>=v0.6.2 | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen,gte_Qwen2-1.5B-instruct | Yes | Yes | Yes | v0.5.0,gte>=v0.7.2 | Yes | | Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen,gte_Qwen2-1.5B-instruct | Yes | Yes | Yes | v0.5.0,gte>=v0.7.2 | Yes |
| Qwen3ForCausalLM | QWen3 | Yes | - | - | v0.8.4 | Yes | | Qwen3ForCausalLM | QWen3,Qwen3-Embedding,Qwen3-Reranker | Yes | - | - | v0.8.4 | Yes |
| Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | v0.8.4 | Yes | | Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | v0.8.4 | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | v0.5.0 | Yes | | ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | v0.5.0 | Yes |
| Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | v0.8.5.post1 | Yes | | Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | v0.8.5.post1 | Yes |
......
...@@ -1192,7 +1192,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1192,7 +1192,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
del eye del eye
# standardize to (output, input) # standardize to (output, input)
return dequant_weights.T return dequant_weights.T
return layer.weight return layer.weight if not envs.VLLM_USE_NN else layer.weight.T
# we currently do not have quantized bmm's which are needed for # we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
......
...@@ -125,6 +125,7 @@ if TYPE_CHECKING: ...@@ -125,6 +125,7 @@ if TYPE_CHECKING:
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_FLASH_ATTN_BACKEND: bool = False VLLM_FLASH_ATTN_BACKEND: bool = False
VLLM_USE_NN: bool = False
VLLM_ENABLE_TBO: bool = False VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS:int = 0 VLLM_TBO_REQ_DELAY_MS:int = 0
...@@ -814,6 +815,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -814,6 +815,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_FLASH_ATTN_BACKEND": "VLLM_FLASH_ATTN_BACKEND":
lambda: (os.environ.get("VLLM_FLASH_ATTN_BACKEND", "False").lower() in lambda: (os.environ.get("VLLM_FLASH_ATTN_BACKEND", "False").lower() in
("true", "1")), ("true", "1")),
# If set, vLLM will transpose weight to use nn layout
"VLLM_USE_NN":
lambda: (os.environ.get("VLLM_USE_NN", "False").lower() in
("true", "1")),
# Enable two batch overlap. # Enable two batch overlap.
"VLLM_ENABLE_TBO": "VLLM_ENABLE_TBO":
......
...@@ -108,7 +108,10 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id): ...@@ -108,7 +108,10 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
assert loaded_weight.shape[0] == 1 assert loaded_weight.shape[0] == 1
loaded_weight = loaded_weight[0] loaded_weight = loaded_weight[0]
return param[shard_id], loaded_weight if envs.VLLM_USE_NN:
return param[shard_id], loaded_weight.t()
else:
return param[shard_id], loaded_weight
# TODO(Isotr0py): We might need a more flexible structure to handle # TODO(Isotr0py): We might need a more flexible structure to handle
...@@ -194,10 +197,16 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -194,10 +197,16 @@ class UnquantizedLinearMethod(LinearMethodBase):
output_partition_sizes: list[int], input_size: int, output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
weight = Parameter(torch.empty(sum(output_partition_sizes), if envs.VLLM_USE_NN:
input_size_per_partition, weight = Parameter(torch.empty(input_size_per_partition,
sum(output_partition_sizes),
dtype=params_dtype), dtype=params_dtype),
requires_grad=False) requires_grad=False)
else:
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs) set_weight_attrs(weight, extra_weight_attrs)
...@@ -219,7 +228,10 @@ class UnquantizedLinearMethod(LinearMethodBase): ...@@ -219,7 +228,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
else: else:
return torch.matmul(x, layer.weight) return torch.matmul(x, layer.weight)
else: else:
return dispatch_unquantized_gemm()(x, layer.weight, bias) if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)
else:
return dispatch_unquantized_gemm()(x, layer.weight, bias)
class LinearBase(torch.nn.Module): class LinearBase(torch.nn.Module):
...@@ -339,6 +351,10 @@ class ReplicatedLinear(LinearBase): ...@@ -339,6 +351,10 @@ class ReplicatedLinear(LinearBase):
if len(loaded_weight.shape) == 0: if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param.size() == loaded_weight.size(), ( assert param.size() == loaded_weight.size(), (
f"Tried to load weights of size {loaded_weight.size()}" f"Tried to load weights of size {loaded_weight.size()}"
f"to a parameter of size {param.size()}") f"to a parameter of size {param.size()}")
...@@ -456,6 +472,7 @@ class ColumnParallelLinear(LinearBase): ...@@ -456,6 +472,7 @@ class ColumnParallelLinear(LinearBase):
# bitsandbytes loads the weights of the specific portion # bitsandbytes loads the weights of the specific portion
# no need to narrow # no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
# Special case for GGUF # Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False) is_gguf_weight = getattr(param, "is_gguf_weight", False)
...@@ -474,7 +491,10 @@ class ColumnParallelLinear(LinearBase): ...@@ -474,7 +491,10 @@ class ColumnParallelLinear(LinearBase):
param_data = param.data param_data = param.data
if output_dim is not None and not is_sharded_weight: if output_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[output_dim] if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization:
shard_size = param_data.shape[output_dim]
else:
shard_size = param_data.shape[int(not(output_dim))]
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size) shard_size)
...@@ -484,6 +504,9 @@ class ColumnParallelLinear(LinearBase): ...@@ -484,6 +504,9 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0: if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
...@@ -615,6 +638,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -615,6 +638,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
is_metadata = getattr(param, "is_metadata", False) is_metadata = getattr(param, "is_metadata", False)
# Special case for per-tensor scale to load scalar into fused array. # Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None: if loaded_shard_id is None:
# Loaded weight is already fused on disk (mlp). # Loaded weight is already fused on disk (mlp).
...@@ -694,9 +718,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -694,9 +718,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = loaded_weight.shape[output_dim] shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \ shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id loaded_shard_id
if not envs.VLLM_USE_NN or is_quantization:
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
else:
param_data = param_data.narrow(int(not(output_dim)), shard_offset, shard_size)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
if not is_sharded_weight: if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, loaded_weight = loaded_weight.narrow(output_dim, start_idx,
...@@ -721,6 +748,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -721,6 +748,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"MergedColumnParallelLinear, assume the weight is " "MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.") "the same for all partitions.")
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
...@@ -1013,6 +1043,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1013,6 +1043,7 @@ class QKVParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scales in fused case. # Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False) needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None: if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv). # Loaded weight is already fused on disk (qkv).
...@@ -1120,8 +1151,13 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1120,8 +1151,13 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id) param, orig_qkv_offsets, loaded_shard_id)
param_data = param_data.narrow(output_dim, shard_offset, if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization:
shard_size) param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(int(not(output_dim)), shard_offset,
shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q":
shard_id = tp_rank shard_id = tp_rank
else: else:
...@@ -1151,6 +1187,9 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -1151,6 +1187,9 @@ class QKVParallelLinear(ColumnParallelLinear):
"QKVParallelLinear, assume the weight is the same " "QKVParallelLinear, assume the weight is the same "
"for all partitions.") "for all partitions.")
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
...@@ -1262,10 +1301,15 @@ class RowParallelLinear(LinearBase): ...@@ -1262,10 +1301,15 @@ class RowParallelLinear(LinearBase):
if input_dim: if input_dim:
weight_shape[input_dim] = weight_shape[input_dim] // tp_size weight_shape[input_dim] = weight_shape[input_dim] // tp_size
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
param_data = param.data param_data = param.data
if input_dim is not None and not is_sharded_weight: if input_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[input_dim] if not envs.VLLM_USE_NN or is_quantization:
shard_size = param_data.shape[input_dim]
else:
shard_size = param_data.shape[int(not(input_dim))]
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx, loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size) shard_size)
...@@ -1275,6 +1319,9 @@ class RowParallelLinear(LinearBase): ...@@ -1275,6 +1319,9 @@ class RowParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0: if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1) loaded_weight = loaded_weight.reshape(1)
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight) param_data.copy_(loaded_weight)
...@@ -1543,4 +1590,4 @@ class QKVCrossParallelLinear(LinearBase): ...@@ -1543,4 +1590,4 @@ class QKVCrossParallelLinear(LinearBase):
s += f", bias={self.bias is not None}" s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}" s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += ", gather_output=False" s += ", gather_output=False"
return s return s
\ No newline at end of file
...@@ -35,10 +35,16 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -35,10 +35,16 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
output_size: int, params_dtype: torch.dtype, output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs): **extra_weight_attrs):
"""Create weights for embedding layer.""" """Create weights for embedding layer."""
# if envs.VLLM_USE_NN:
# weight = Parameter(torch.empty(input_size_per_partition,
# sum(output_partition_sizes),
# dtype=params_dtype),
# requires_grad=False)
# else:
weight = Parameter(torch.empty(sum(output_partition_sizes), weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition, input_size_per_partition,
dtype=params_dtype), dtype=params_dtype),
requires_grad=False) requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs) set_weight_attrs(weight, extra_weight_attrs)
...@@ -56,7 +62,10 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase): ...@@ -56,7 +62,10 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
else: else:
return torch.matmul(x, layer.weight) return torch.matmul(x, layer.weight)
else: else:
return dispatch_unquantized_gemm()(x, layer.weight, bias) if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)
else:
return dispatch_unquantized_gemm()(x, layer.weight, bias)
def embedding(self, layer: torch.nn.Module, def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor: input_: torch.Tensor) -> torch.Tensor:
...@@ -404,6 +413,9 @@ class VocabParallelEmbedding(torch.nn.Module): ...@@ -404,6 +413,9 @@ class VocabParallelEmbedding(torch.nn.Module):
# Copy the data. Select chunk corresponding to current shard. # Copy the data. Select chunk corresponding to current shard.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# if envs.VLLM_USE_NN and self.quant_method is not None:
# loaded_weight = loaded_weight.t()
if current_platform.is_hpu(): if current_platform.is_hpu():
# FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here, # FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
# so we're using a workaround. Remove this when fixed in # so we're using a workaround. Remove this when fixed in
...@@ -502,4 +514,4 @@ class ParallelLMHead(VocabParallelEmbedding): ...@@ -502,4 +514,4 @@ class ParallelLMHead(VocabParallelEmbedding):
def forward(self, input_): def forward(self, input_):
del input_ del input_
raise RuntimeError("LMHead's weights should be used in the sampler.") raise RuntimeError("LMHead's weights should be used in the sampler.")
\ No newline at end of file
...@@ -18,6 +18,7 @@ from vllm.model_executor.models import ModelRegistry ...@@ -18,6 +18,7 @@ from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model, from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model, as_embedding_model,
as_reward_model) as_reward_model)
import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -94,19 +95,20 @@ def get_model_architecture( ...@@ -94,19 +95,20 @@ def get_model_architecture(
'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM', 'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM',
'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel'] 'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if not envs.VLLM_USE_NN:
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []: if os.getenv('LLAMA_NN') != '0':
os.environ['LLAMA_NN'] = '0' if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
else: os.environ['LLAMA_NN'] = '0'
os.environ['LLAMA_NN'] = '1' else:
if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0': os.environ['LLAMA_NN'] = '1'
os.environ['LM_NN'] = '0' if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0':
else: os.environ['LM_NN'] = '0'
os.environ['LM_NN'] = '1' else:
if os.getenv('GEMM_PAD') != '1': os.environ['LM_NN'] = '1'
os.environ['GEMM_PAD'] = '0' if os.getenv('GEMM_PAD') != '1':
if os.getenv('FA_PAD') != '1': os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0' if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
# awq相关配置 # awq相关配置
try: try:
if os.getenv('AWQ_MOE_SZ') == None: if os.getenv('AWQ_MOE_SZ') == None:
...@@ -205,4 +207,4 @@ def configure_quant_config(quant_config: QuantizationConfig, ...@@ -205,4 +207,4 @@ def configure_quant_config(quant_config: QuantizationConfig,
logger.warning( logger.warning(
"The model class %s has not defined `packed_modules_mapping`, " "The model class %s has not defined `packed_modules_mapping`, "
"this may lead to incorrect mapping of quantized or ignored " "this may lead to incorrect mapping of quantized or ignored "
"modules", model_class.__name__) "modules", model_class.__name__)
\ No newline at end of file
...@@ -54,6 +54,7 @@ from .utils import (AutoWeightsLoader, extract_layer_index, ...@@ -54,6 +54,7 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
import vllm.envs as envs
class DeepseekMLP(nn.Module): class DeepseekMLP(nn.Module):
...@@ -152,6 +153,15 @@ class DeepseekMoE(nn.Module): ...@@ -152,6 +153,15 @@ class DeepseekMoE(nn.Module):
param.data = data param.data = data
self.w2 = self.w2.view(len(w2), *w2s[0].shape) self.w2 = self.w2.view(len(w2), *w2s[0].shape)
if envs.VLLM_USE_NN:
self.w1 = self.w1.permute(0,2,1).contiguous()
for expert, w in zip(self.experts, self.w1):
expert.gate_up_proj.weight.data = w.permute(1,0)
self.w2 = self.w2.permute(0, 2, 1).contiguous()
for expert, w in zip(self.experts, self.w2):
expert.down_proj.weight.data = w.permute(1, 0)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
......
...@@ -193,6 +193,7 @@ import torch ...@@ -193,6 +193,7 @@ import torch
import os import os
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata, AttentionMetadata,
MLAAttentionImpl) MLAAttentionImpl)
...@@ -739,7 +740,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -739,7 +740,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
del eye del eye
# standardize to (output, input) # standardize to (output, input)
return dequant_weights.T return dequant_weights.T
return layer.weight return layer.weight if not envs.VLLM_USE_NN else layer.weight.T
# we currently do not have quantized bmm's which are needed for # we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment