Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6880bf15
Commit
6880bf15
authored
Jun 06, 2025
by
zhuwenwen
Browse files
[Model] Add VLLM_USE_NN to use nn layout
parent
fafe3ca7
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
114 additions
and
36 deletions
+114
-36
README.md
README.md
+4
-4
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+1
-1
vllm/envs.py
vllm/envs.py
+6
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+58
-11
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+17
-5
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+16
-14
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+10
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+2
-1
No files found.
README.md
View file @
6880bf15
...
...
@@ -12,10 +12,10 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| 结构 | 模型 | FP16/BF16 | AWQ | GPTQ | 支持版本 | 是否优化 |
| :------: | :------: | :------: | :------: |:------: | :------: |:------: |
| LlamaForCausalLM | Llama 3.2, Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,DeepSeek-R1-Distill-Llama | Yes | Yes | Yes | v0.5.0,Llama 3.2>=v0.6.2 | Yes |
| Llama4ForConditionalGeneration | Llama 4 | No/Yes | - | - | v0.8.5.post1 | No |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes | v0.5.0,Qwen-VL>=v0.6.2 | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen,gte_Qwen2-1.5B-instruct | Yes | Yes | Yes | v0.5.0,gte>=v0.7.2 | Yes |
| Qwen3ForCausalLM | QWen3
| Yes | - | - | v0.8.4 | Yes |
| Llama4ForConditionalGeneration | Llama 4
| No/Yes | - | - | v0.8.5.post1 | No |
| QWenLMHeadModel | QWen,Qwen-VL
| Yes | Yes | Yes | v0.5.0,Qwen-VL>=v0.6.2 | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen,gte_Qwen2-1.5B-instruct
| Yes | Yes | Yes | v0.5.0,gte>=v0.7.2 | Yes |
| Qwen3ForCausalLM | QWen3
,Qwen3-Embedding,Qwen3-Reranker
| Yes | - | - | v0.8.4 | Yes |
| Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | v0.8.4 | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | v0.5.0 | Yes |
| Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | v0.8.5.post1 | Yes |
...
...
vllm/attention/backends/mla/common.py
View file @
6880bf15
...
...
@@ -1192,7 +1192,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
del
eye
# standardize to (output, input)
return
dequant_weights
.
T
return
layer
.
weight
return
layer
.
weight
if
not
envs
.
VLLM_USE_NN
else
layer
.
weight
.
T
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
...
...
vllm/envs.py
View file @
6880bf15
...
...
@@ -125,6 +125,7 @@ if TYPE_CHECKING:
VLLM_ENFORCE_EAGER_BS_THRESHOLD
:
Optional
[
int
]
=
None
VLLM_HAS_CONTEXT_DEFAULT
:
bool
=
False
VLLM_FLASH_ATTN_BACKEND
:
bool
=
False
VLLM_USE_NN
:
bool
=
False
VLLM_ENABLE_TBO
:
bool
=
False
VLLM_TBO_REQ_DELAY_MS
:
int
=
0
...
...
@@ -814,6 +815,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_FLASH_ATTN_BACKEND"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_FLASH_ATTN_BACKEND"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# If set, vLLM will transpose weight to use nn layout
"VLLM_USE_NN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_NN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# Enable two batch overlap.
"VLLM_ENABLE_TBO"
:
...
...
vllm/model_executor/layers/linear.py
View file @
6880bf15
...
...
@@ -108,7 +108,10 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
assert
loaded_weight
.
shape
[
0
]
==
1
loaded_weight
=
loaded_weight
[
0
]
return
param
[
shard_id
],
loaded_weight
if
envs
.
VLLM_USE_NN
:
return
param
[
shard_id
],
loaded_weight
.
t
()
else
:
return
param
[
shard_id
],
loaded_weight
# TODO(Isotr0py): We might need a more flexible structure to handle
...
...
@@ -194,10 +197,16 @@ class UnquantizedLinearMethod(LinearMethodBase):
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
if
envs
.
VLLM_USE_NN
:
weight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
,
sum
(
output_partition_sizes
),
dtype
=
params_dtype
),
requires_grad
=
False
)
else
:
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
...
...
@@ -219,7 +228,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
else
:
return
torch
.
matmul
(
x
,
layer
.
weight
)
else
:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
,
bias
)
if
envs
.
VLLM_USE_NN
and
x
.
shape
[
-
1
]
==
layer
.
weight
.
shape
[
0
]:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
.
t
(),
bias
)
else
:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
,
bias
)
class
LinearBase
(
torch
.
nn
.
Module
):
...
...
@@ -339,6 +351,10 @@ class ReplicatedLinear(LinearBase):
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param
.
size
()
==
loaded_weight
.
size
(),
(
f
"Tried to load weights of size
{
loaded_weight
.
size
()
}
"
f
"to a parameter of size
{
param
.
size
()
}
"
)
...
...
@@ -456,6 +472,7 @@ class ColumnParallelLinear(LinearBase):
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight
=
is_sharded_weight
or
use_bitsandbytes_4bit
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
# Special case for GGUF
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
...
...
@@ -474,7 +491,10 @@ class ColumnParallelLinear(LinearBase):
param_data
=
param
.
data
if
output_dim
is
not
None
and
not
is_sharded_weight
:
shard_size
=
param_data
.
shape
[
output_dim
]
if
not
envs
.
VLLM_USE_NN
or
len
(
param_data
.
shape
)
==
1
or
is_quantization
:
shard_size
=
param_data
.
shape
[
output_dim
]
else
:
shard_size
=
param_data
.
shape
[
int
(
not
(
output_dim
))]
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
...
...
@@ -484,6 +504,9 @@ class ColumnParallelLinear(LinearBase):
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
@@ -615,6 +638,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array
=
getattr
(
param
,
"needs_scalar_to_array"
,
False
)
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
if
loaded_shard_id
is
None
:
# Loaded weight is already fused on disk (mlp).
...
...
@@ -694,9 +718,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size
=
loaded_weight
.
shape
[
output_dim
]
shard_offset
=
loaded_weight
.
shape
[
output_dim
]
*
\
loaded_shard_id
if
not
envs
.
VLLM_USE_NN
or
is_quantization
:
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
else
:
param_data
=
param_data
.
narrow
(
int
(
not
(
output_dim
)),
shard_offset
,
shard_size
)
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
start_idx
=
tp_rank
*
shard_size
if
not
is_sharded_weight
:
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
...
...
@@ -721,6 +748,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions."
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
@@ -1013,6 +1043,7 @@ class QKVParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scales in fused case.
needs_scalar_to_array
=
getattr
(
param
,
"needs_scalar_to_array"
,
False
)
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
if
loaded_shard_id
is
None
:
# Loaded weight is already fused on disk (qkv).
...
...
@@ -1120,8 +1151,13 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size
,
shard_offset
=
adjust_bitsandbytes_4bit_shard
(
param
,
orig_qkv_offsets
,
loaded_shard_id
)
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
if
not
envs
.
VLLM_USE_NN
or
len
(
param_data
.
shape
)
==
1
or
is_quantization
:
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
else
:
param_data
=
param_data
.
narrow
(
int
(
not
(
output_dim
)),
shard_offset
,
shard_size
)
if
loaded_shard_id
==
"q"
:
shard_id
=
tp_rank
else
:
...
...
@@ -1151,6 +1187,9 @@ class QKVParallelLinear(ColumnParallelLinear):
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
@@ -1262,10 +1301,15 @@ class RowParallelLinear(LinearBase):
if
input_dim
:
weight_shape
[
input_dim
]
=
weight_shape
[
input_dim
]
//
tp_size
param
.
materialize
(
tuple
(
weight_shape
),
dtype
=
loaded_weight
.
dtype
)
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
param_data
=
param
.
data
if
input_dim
is
not
None
and
not
is_sharded_weight
:
shard_size
=
param_data
.
shape
[
input_dim
]
if
not
envs
.
VLLM_USE_NN
or
is_quantization
:
shard_size
=
param_data
.
shape
[
input_dim
]
else
:
shard_size
=
param_data
.
shape
[
int
(
not
(
input_dim
))]
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
shard_size
)
...
...
@@ -1275,6 +1319,9 @@ class RowParallelLinear(LinearBase):
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
@@ -1543,4 +1590,4 @@ class QKVCrossParallelLinear(LinearBase):
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world_size
()
}
"
s
+=
", gather_output=False"
return
s
return
s
\ No newline at end of file
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
6880bf15
...
...
@@ -35,10 +35,16 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
"""Create weights for embedding layer."""
# if envs.VLLM_USE_NN:
# weight = Parameter(torch.empty(input_size_per_partition,
# sum(output_partition_sizes),
# dtype=params_dtype),
# requires_grad=False)
# else:
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
input_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
})
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
...
...
@@ -56,7 +62,10 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
else
:
return
torch
.
matmul
(
x
,
layer
.
weight
)
else
:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
,
bias
)
if
envs
.
VLLM_USE_NN
and
x
.
shape
[
-
1
]
==
layer
.
weight
.
shape
[
0
]:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
.
t
(),
bias
)
else
:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
,
bias
)
def
embedding
(
self
,
layer
:
torch
.
nn
.
Module
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -404,6 +413,9 @@ class VocabParallelEmbedding(torch.nn.Module):
# Copy the data. Select chunk corresponding to current shard.
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# if envs.VLLM_USE_NN and self.quant_method is not None:
# loaded_weight = loaded_weight.t()
if
current_platform
.
is_hpu
():
# FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
# so we're using a workaround. Remove this when fixed in
...
...
@@ -502,4 +514,4 @@ class ParallelLMHead(VocabParallelEmbedding):
def
forward
(
self
,
input_
):
del
input_
raise
RuntimeError
(
"LMHead's weights should be used in the sampler."
)
raise
RuntimeError
(
"LMHead's weights should be used in the sampler."
)
\ No newline at end of file
vllm/model_executor/model_loader/utils.py
View file @
6880bf15
...
...
@@ -18,6 +18,7 @@ from vllm.model_executor.models import ModelRegistry
from
vllm.model_executor.models.adapters
import
(
as_classification_model
,
as_embedding_model
,
as_reward_model
)
import
vllm.envs
as
envs
logger
=
init_logger
(
__name__
)
...
...
@@ -94,19 +95,20 @@ def get_model_architecture(
'ChatGLMModel'
,
'Glm4ForCausalLM'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'TeleChat2ForCausalLM'
,
'MixtralForCausalLM'
,
'FalconForCausalLM'
,
'MedusaModel'
,
'MLPSpeculatorPreTrainedModel'
,
'DeepseekV2ForCausalLM'
,
'DeepseekV3ForCausalLM'
,
'DeepSeekMTPModel'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
os
.
environ
[
'LLAMA_NN'
]
=
'0'
else
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
if
(
architectures
==
[
'BloomForCausalLM'
]
or
architectures
==
[
'FalconForCausalLM'
])
or
os
.
getenv
(
'LM_NN'
)
==
'0'
:
os
.
environ
[
'LM_NN'
]
=
'0'
else
:
os
.
environ
[
'LM_NN'
]
=
'1'
if
os
.
getenv
(
'GEMM_PAD'
)
!=
'1'
:
os
.
environ
[
'GEMM_PAD'
]
=
'0'
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
os
.
environ
[
'FA_PAD'
]
=
'0'
if
not
envs
.
VLLM_USE_NN
:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
os
.
environ
[
'LLAMA_NN'
]
=
'0'
else
:
os
.
environ
[
'LLAMA_NN'
]
=
'1'
if
(
architectures
==
[
'BloomForCausalLM'
]
or
architectures
==
[
'FalconForCausalLM'
])
or
os
.
getenv
(
'LM_NN'
)
==
'0'
:
os
.
environ
[
'LM_NN'
]
=
'0'
else
:
os
.
environ
[
'LM_NN'
]
=
'1'
if
os
.
getenv
(
'GEMM_PAD'
)
!=
'1'
:
os
.
environ
[
'GEMM_PAD'
]
=
'0'
if
os
.
getenv
(
'FA_PAD'
)
!=
'1'
:
os
.
environ
[
'FA_PAD'
]
=
'0'
# awq相关配置
try
:
if
os
.
getenv
(
'AWQ_MOE_SZ'
)
==
None
:
...
...
@@ -205,4 +207,4 @@ def configure_quant_config(quant_config: QuantizationConfig,
logger
.
warning
(
"The model class %s has not defined `packed_modules_mapping`, "
"this may lead to incorrect mapping of quantized or ignored "
"modules"
,
model_class
.
__name__
)
"modules"
,
model_class
.
__name__
)
\ No newline at end of file
vllm/model_executor/models/deepseek.py
View file @
6880bf15
...
...
@@ -54,6 +54,7 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
import
vllm.envs
as
envs
class
DeepseekMLP
(
nn
.
Module
):
...
...
@@ -152,6 +153,15 @@ class DeepseekMoE(nn.Module):
param
.
data
=
data
self
.
w2
=
self
.
w2
.
view
(
len
(
w2
),
*
w2s
[
0
].
shape
)
if
envs
.
VLLM_USE_NN
:
self
.
w1
=
self
.
w1
.
permute
(
0
,
2
,
1
).
contiguous
()
for
expert
,
w
in
zip
(
self
.
experts
,
self
.
w1
):
expert
.
gate_up_proj
.
weight
.
data
=
w
.
permute
(
1
,
0
)
self
.
w2
=
self
.
w2
.
permute
(
0
,
2
,
1
).
contiguous
()
for
expert
,
w
in
zip
(
self
.
experts
,
self
.
w2
):
expert
.
down_proj
.
weight
.
data
=
w
.
permute
(
1
,
0
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
...
...
vllm/v1/attention/backends/mla/common.py
View file @
6880bf15
...
...
@@ -193,6 +193,7 @@ import torch
import
os
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
MLAAttentionImpl
)
...
...
@@ -739,7 +740,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
del
eye
# standardize to (output, input)
return
dequant_weights
.
T
return
layer
.
weight
return
layer
.
weight
if
not
envs
.
VLLM_USE_NN
else
layer
.
weight
.
T
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment