Commit 6880bf15 authored by zhuwenwen's avatar zhuwenwen
Browse files

[Model] Add VLLM_USE_NN to use nn layout

parent fafe3ca7
......@@ -15,7 +15,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| Llama4ForConditionalGeneration | Llama 4 | No/Yes | - | - | v0.8.5.post1 | No |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes | v0.5.0,Qwen-VL>=v0.6.2 | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen,gte_Qwen2-1.5B-instruct | Yes | Yes | Yes | v0.5.0,gte>=v0.7.2 | Yes |
| Qwen3ForCausalLM | QWen3 | Yes | - | - | v0.8.4 | Yes |
| Qwen3ForCausalLM | QWen3,Qwen3-Embedding,Qwen3-Reranker | Yes | - | - | v0.8.4 | Yes |
| Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | v0.8.4 | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | v0.5.0 | Yes |
| Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | v0.8.5.post1 | Yes |
......
......@@ -1192,7 +1192,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
return layer.weight if not envs.VLLM_USE_NN else layer.weight.T
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
......
......@@ -125,6 +125,7 @@ if TYPE_CHECKING:
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_FLASH_ATTN_BACKEND: bool = False
VLLM_USE_NN: bool = False
VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS:int = 0
......@@ -815,6 +816,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_FLASH_ATTN_BACKEND", "False").lower() in
("true", "1")),
# If set, vLLM will transpose weights to use the NN layout
"VLLM_USE_NN":
lambda: (os.environ.get("VLLM_USE_NN", "False").lower() in
("true", "1")),
# Enable two batch overlap.
"VLLM_ENABLE_TBO":
lambda: bool(int(os.getenv("VLLM_ENABLE_TBO", "0"))),
......
......@@ -108,6 +108,9 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
assert loaded_weight.shape[0] == 1
loaded_weight = loaded_weight[0]
if envs.VLLM_USE_NN:
return param[shard_id], loaded_weight.t()
else:
return param[shard_id], loaded_weight
......@@ -194,6 +197,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if envs.VLLM_USE_NN:
weight = Parameter(torch.empty(input_size_per_partition,
sum(output_partition_sizes),
dtype=params_dtype),
requires_grad=False)
else:
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
......@@ -218,6 +227,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
return torch.matmul(x, layer.weight) + bias
else:
return torch.matmul(x, layer.weight)
else:
if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)
else:
return dispatch_unquantized_gemm()(x, layer.weight, bias)
......@@ -339,6 +351,10 @@ class ReplicatedLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param.size() == loaded_weight.size(), (
f"Tried to load weights of size {loaded_weight.size()}"
f"to a parameter of size {param.size()}")
......@@ -456,6 +472,7 @@ class ColumnParallelLinear(LinearBase):
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
......@@ -474,7 +491,10 @@ class ColumnParallelLinear(LinearBase):
param_data = param.data
if output_dim is not None and not is_sharded_weight:
if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization:
shard_size = param_data.shape[output_dim]
else:
shard_size = param_data.shape[int(not(output_dim))]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
......@@ -484,6 +504,9 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......@@ -615,6 +638,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
is_metadata = getattr(param, "is_metadata", False)
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (mlp).
......@@ -695,8 +719,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if not envs.VLLM_USE_NN or is_quantization:
param_data = param_data.narrow(output_dim, shard_offset, shard_size)
else:
param_data = param_data.narrow(int(not(output_dim)), shard_offset, shard_size)
start_idx = tp_rank * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
......@@ -721,6 +748,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......@@ -1013,6 +1043,7 @@ class QKVParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scales in fused case.
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
if loaded_shard_id is None:
# Loaded weight is already fused on disk (qkv).
......@@ -1120,8 +1151,13 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
param, orig_qkv_offsets, loaded_shard_id)
if not envs.VLLM_USE_NN or len(param_data.shape)==1 or is_quantization:
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
else:
param_data = param_data.narrow(int(not(output_dim)), shard_offset,
shard_size)
if loaded_shard_id == "q":
shard_id = tp_rank
else:
......@@ -1151,6 +1187,9 @@ class QKVParallelLinear(ColumnParallelLinear):
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......@@ -1263,9 +1302,14 @@ class RowParallelLinear(LinearBase):
weight_shape[input_dim] = weight_shape[input_dim] // tp_size
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
is_quantization = not isinstance(self.quant_method, UnquantizedLinearMethod)
param_data = param.data
if input_dim is not None and not is_sharded_weight:
if not envs.VLLM_USE_NN or is_quantization:
shard_size = param_data.shape[input_dim]
else:
shard_size = param_data.shape[int(not(input_dim))]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
......@@ -1275,6 +1319,9 @@ class RowParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if envs.VLLM_USE_NN and not is_quantization:
loaded_weight = loaded_weight.t()
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
......
......@@ -35,6 +35,12 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for embedding layer."""
# if envs.VLLM_USE_NN:
# weight = Parameter(torch.empty(input_size_per_partition,
# sum(output_partition_sizes),
# dtype=params_dtype),
# requires_grad=False)
# else:
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
......@@ -55,6 +61,9 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
return torch.matmul(x, layer.weight) + bias
else:
return torch.matmul(x, layer.weight)
else:
if envs.VLLM_USE_NN and x.shape[-1] == layer.weight.shape[0]:
return dispatch_unquantized_gemm()(x, layer.weight.t(), bias)
else:
return dispatch_unquantized_gemm()(x, layer.weight, bias)
......@@ -404,6 +413,9 @@ class VocabParallelEmbedding(torch.nn.Module):
# Copy the data. Select chunk corresponding to current shard.
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# if envs.VLLM_USE_NN and self.quant_method is not None:
# loaded_weight = loaded_weight.t()
if current_platform.is_hpu():
# FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
# so we're using a workaround. Remove this when fixed in
......
......@@ -18,6 +18,7 @@ from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)
import vllm.envs as envs
logger = init_logger(__name__)
......@@ -94,6 +95,7 @@ def get_model_architecture(
'ChatGLMModel', 'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM', 'MixtralForCausalLM', 'FalconForCausalLM',
'MedusaModel', 'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures):
if not envs.VLLM_USE_NN:
if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
os.environ['LLAMA_NN'] = '0'
......
......@@ -54,6 +54,7 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
import vllm.envs as envs
class DeepseekMLP(nn.Module):
......@@ -153,6 +154,15 @@ class DeepseekMoE(nn.Module):
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
if envs.VLLM_USE_NN:
self.w1 = self.w1.permute(0,2,1).contiguous()
for expert, w in zip(self.experts, self.w1):
expert.gate_up_proj.weight.data = w.permute(1,0)
self.w2 = self.w2.permute(0, 2, 1).contiguous()
for expert, w in zip(self.experts, self.w2):
expert.down_proj.weight.data = w.permute(1, 0)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
......
......@@ -193,6 +193,7 @@ import torch
import os
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
......@@ -739,7 +740,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
return layer.weight if not envs.VLLM_USE_NN else layer.weight.T
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment