Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6880bf15
Commit
6880bf15
authored
Jun 06, 2025
by
zhuwenwen
Browse files
[Model] Add VLLM_USE_NN to use nn layout
parent
fafe3ca7
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
114 additions
and
36 deletions
+114
-36
README.md
README.md
+4
-4
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+1
-1
vllm/envs.py
vllm/envs.py
+6
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+58
-11
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+17
-5
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+16
-14
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+10
-0
vllm/v1/attention/backends/mla/common.py
vllm/v1/attention/backends/mla/common.py
+2
-1
No files found.
README.md
View file @
6880bf15
...
...
@@ -15,7 +15,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
| Llama4ForConditionalGeneration | Llama 4 | No/Yes | - | - | v0.8.5.post1 | No |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes | v0.5.0,Qwen-VL>=v0.6.2 | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5,DeepSeek-R1-Distill-Qwen,gte_Qwen2-1.5B-instruct | Yes | Yes | Yes | v0.5.0,gte>=v0.7.2 | Yes |
| Qwen3ForCausalLM | QWen3 | Yes | - | - | v0.8.4 | Yes |
| Qwen3ForCausalLM | QWen3,Qwen3-Embedding,Qwen3-Reranker | Yes | - | - | v0.8.4 | Yes |
| Qwen3MoeForCausalLM | QWen3MoE | Yes | - | - | v0.8.4 | Yes |
| ChatGLMModel | glm-4v-9b,chatglm3,chatglm2 | Yes | No | Yes | v0.5.0 | Yes |
| Glm4ForCausalLM | GLM-4-0414 | No/Yes | - | - | v0.8.5.post1 | Yes |
...
...
vllm/attention/backends/mla/common.py
View file @
6880bf15
...
...
@@ -1192,7 +1192,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
del
eye
# standardize to (output, input)
return
dequant_weights
.
T
return
layer
.
weight
return
layer
.
weight
if
not
envs
.
VLLM_USE_NN
else
layer
.
weight
.
T
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
...
...
vllm/envs.py
View file @
6880bf15
...
...
@@ -125,6 +125,7 @@ if TYPE_CHECKING:
VLLM_ENFORCE_EAGER_BS_THRESHOLD
:
Optional
[
int
]
=
None
VLLM_HAS_CONTEXT_DEFAULT
:
bool
=
False
VLLM_FLASH_ATTN_BACKEND
:
bool
=
False
VLLM_USE_NN
:
bool
=
False
VLLM_ENABLE_TBO
:
bool
=
False
VLLM_TBO_REQ_DELAY_MS
:
int
=
0
...
...
@@ -815,6 +816,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_FLASH_ATTN_BACKEND"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# If set, vLLM will transpose weights to use the NN layout
"VLLM_USE_NN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_NN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# Enable two batch overlap.
"VLLM_ENABLE_TBO"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_ENABLE_TBO"
,
"0"
))),
...
...
vllm/model_executor/layers/linear.py
View file @
6880bf15
...
...
@@ -108,6 +108,9 @@ def adjust_scalar_to_fused_array(param, loaded_weight, shard_id):
assert
loaded_weight
.
shape
[
0
]
==
1
loaded_weight
=
loaded_weight
[
0
]
if
envs
.
VLLM_USE_NN
:
return
param
[
shard_id
],
loaded_weight
.
t
()
else
:
return
param
[
shard_id
],
loaded_weight
...
...
@@ -194,6 +197,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
if
envs
.
VLLM_USE_NN
:
weight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
,
sum
(
output_partition_sizes
),
dtype
=
params_dtype
),
requires_grad
=
False
)
else
:
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
...
...
@@ -218,6 +227,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
return
torch
.
matmul
(
x
,
layer
.
weight
)
+
bias
else
:
return
torch
.
matmul
(
x
,
layer
.
weight
)
else
:
if
envs
.
VLLM_USE_NN
and
x
.
shape
[
-
1
]
==
layer
.
weight
.
shape
[
0
]:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
.
t
(),
bias
)
else
:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
,
bias
)
...
...
@@ -339,6 +351,10 @@ class ReplicatedLinear(LinearBase):
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param
.
size
()
==
loaded_weight
.
size
(),
(
f
"Tried to load weights of size
{
loaded_weight
.
size
()
}
"
f
"to a parameter of size
{
param
.
size
()
}
"
)
...
...
@@ -456,6 +472,7 @@ class ColumnParallelLinear(LinearBase):
# bitsandbytes loads the weights of the specific portion
# no need to narrow
is_sharded_weight
=
is_sharded_weight
or
use_bitsandbytes_4bit
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
# Special case for GGUF
is_gguf_weight
=
getattr
(
param
,
"is_gguf_weight"
,
False
)
...
...
@@ -474,7 +491,10 @@ class ColumnParallelLinear(LinearBase):
param_data
=
param
.
data
if
output_dim
is
not
None
and
not
is_sharded_weight
:
if
not
envs
.
VLLM_USE_NN
or
len
(
param_data
.
shape
)
==
1
or
is_quantization
:
shard_size
=
param_data
.
shape
[
output_dim
]
else
:
shard_size
=
param_data
.
shape
[
int
(
not
(
output_dim
))]
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
...
...
@@ -484,6 +504,9 @@ class ColumnParallelLinear(LinearBase):
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
@@ -615,6 +638,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
# Special case for per-tensor scale to load scalar into fused array.
needs_scalar_to_array
=
getattr
(
param
,
"needs_scalar_to_array"
,
False
)
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
if
loaded_shard_id
is
None
:
# Loaded weight is already fused on disk (mlp).
...
...
@@ -695,8 +719,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset
=
loaded_weight
.
shape
[
output_dim
]
*
\
loaded_shard_id
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
if
not
envs
.
VLLM_USE_NN
or
is_quantization
:
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
else
:
param_data
=
param_data
.
narrow
(
int
(
not
(
output_dim
)),
shard_offset
,
shard_size
)
start_idx
=
tp_rank
*
shard_size
if
not
is_sharded_weight
:
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
...
...
@@ -721,6 +748,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions."
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
@@ -1013,6 +1043,7 @@ class QKVParallelLinear(ColumnParallelLinear):
# Special case for per-tensor scales in fused case.
needs_scalar_to_array
=
getattr
(
param
,
"needs_scalar_to_array"
,
False
)
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
if
loaded_shard_id
is
None
:
# Loaded weight is already fused on disk (qkv).
...
...
@@ -1120,8 +1151,13 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size
,
shard_offset
=
adjust_bitsandbytes_4bit_shard
(
param
,
orig_qkv_offsets
,
loaded_shard_id
)
if
not
envs
.
VLLM_USE_NN
or
len
(
param_data
.
shape
)
==
1
or
is_quantization
:
param_data
=
param_data
.
narrow
(
output_dim
,
shard_offset
,
shard_size
)
else
:
param_data
=
param_data
.
narrow
(
int
(
not
(
output_dim
)),
shard_offset
,
shard_size
)
if
loaded_shard_id
==
"q"
:
shard_id
=
tp_rank
else
:
...
...
@@ -1151,6 +1187,9 @@ class QKVParallelLinear(ColumnParallelLinear):
"QKVParallelLinear, assume the weight is the same "
"for all partitions."
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
@@ -1263,9 +1302,14 @@ class RowParallelLinear(LinearBase):
weight_shape
[
input_dim
]
=
weight_shape
[
input_dim
]
//
tp_size
param
.
materialize
(
tuple
(
weight_shape
),
dtype
=
loaded_weight
.
dtype
)
is_quantization
=
not
isinstance
(
self
.
quant_method
,
UnquantizedLinearMethod
)
param_data
=
param
.
data
if
input_dim
is
not
None
and
not
is_sharded_weight
:
if
not
envs
.
VLLM_USE_NN
or
is_quantization
:
shard_size
=
param_data
.
shape
[
input_dim
]
else
:
shard_size
=
param_data
.
shape
[
int
(
not
(
input_dim
))]
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
shard_size
)
...
...
@@ -1275,6 +1319,9 @@ class RowParallelLinear(LinearBase):
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
if
envs
.
VLLM_USE_NN
and
not
is_quantization
:
loaded_weight
=
loaded_weight
.
t
()
assert
param_data
.
shape
==
loaded_weight
.
shape
param_data
.
copy_
(
loaded_weight
)
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
6880bf15
...
...
@@ -35,6 +35,12 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
"""Create weights for embedding layer."""
# if envs.VLLM_USE_NN:
# weight = Parameter(torch.empty(input_size_per_partition,
# sum(output_partition_sizes),
# dtype=params_dtype),
# requires_grad=False)
# else:
weight
=
Parameter
(
torch
.
empty
(
sum
(
output_partition_sizes
),
input_size_per_partition
,
dtype
=
params_dtype
),
...
...
@@ -55,6 +61,9 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
return
torch
.
matmul
(
x
,
layer
.
weight
)
+
bias
else
:
return
torch
.
matmul
(
x
,
layer
.
weight
)
else
:
if
envs
.
VLLM_USE_NN
and
x
.
shape
[
-
1
]
==
layer
.
weight
.
shape
[
0
]:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
.
t
(),
bias
)
else
:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
,
bias
)
...
...
@@ -404,6 +413,9 @@ class VocabParallelEmbedding(torch.nn.Module):
# Copy the data. Select chunk corresponding to current shard.
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
# if envs.VLLM_USE_NN and self.quant_method is not None:
# loaded_weight = loaded_weight.t()
if
current_platform
.
is_hpu
():
# FIXME(kzawora): Weight copy with slicing bugs out on Gaudi here,
# so we're using a workaround. Remove this when fixed in
...
...
vllm/model_executor/model_loader/utils.py
View file @
6880bf15
...
...
@@ -18,6 +18,7 @@ from vllm.model_executor.models import ModelRegistry
from
vllm.model_executor.models.adapters
import
(
as_classification_model
,
as_embedding_model
,
as_reward_model
)
import
vllm.envs
as
envs
logger
=
init_logger
(
__name__
)
...
...
@@ -94,6 +95,7 @@ def get_model_architecture(
'ChatGLMModel'
,
'Glm4ForCausalLM'
,
'ChatGLMForConditionalGeneration'
,
'BaichuanForCausalLM'
,
'BloomForCausalLM'
,
'TeleChat2ForCausalLM'
,
'MixtralForCausalLM'
,
'FalconForCausalLM'
,
'MedusaModel'
,
'MLPSpeculatorPreTrainedModel'
,
'DeepseekV2ForCausalLM'
,
'DeepseekV3ForCausalLM'
,
'DeepSeekMTPModel'
]
if
any
(
arch
in
architectures
for
arch
in
support_nn_architectures
):
if
not
envs
.
VLLM_USE_NN
:
if
os
.
getenv
(
'LLAMA_NN'
)
!=
'0'
:
if
(
architectures
==
[
'QWenLMHeadModel'
]
or
architectures
==
[
'ChatGLMModel'
]
)
and
visions
!=
[]:
os
.
environ
[
'LLAMA_NN'
]
=
'0'
...
...
vllm/model_executor/models/deepseek.py
View file @
6880bf15
...
...
@@ -54,6 +54,7 @@ from .utils import (AutoWeightsLoader, extract_layer_index,
is_pp_missing_parameter
,
make_empty_intermediate_tensors_factory
,
make_layers
,
maybe_prefix
)
import
vllm.envs
as
envs
class
DeepseekMLP
(
nn
.
Module
):
...
...
@@ -153,6 +154,15 @@ class DeepseekMoE(nn.Module):
self
.
w2
=
self
.
w2
.
view
(
len
(
w2
),
*
w2s
[
0
].
shape
)
if
envs
.
VLLM_USE_NN
:
self
.
w1
=
self
.
w1
.
permute
(
0
,
2
,
1
).
contiguous
()
for
expert
,
w
in
zip
(
self
.
experts
,
self
.
w1
):
expert
.
gate_up_proj
.
weight
.
data
=
w
.
permute
(
1
,
0
)
self
.
w2
=
self
.
w2
.
permute
(
0
,
2
,
1
).
contiguous
()
for
expert
,
w
in
zip
(
self
.
experts
,
self
.
w2
):
expert
.
down_proj
.
weight
.
data
=
w
.
permute
(
1
,
0
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_dim
=
hidden_states
.
shape
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_dim
)
...
...
vllm/v1/attention/backends/mla/common.py
View file @
6880bf15
...
...
@@ -193,6 +193,7 @@ import torch
import
os
from
vllm
import
_custom_ops
as
ops
from
vllm
import
envs
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionLayer
,
AttentionMetadata
,
MLAAttentionImpl
)
...
...
@@ -739,7 +740,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
del
eye
# standardize to (output, input)
return
dequant_weights
.
T
return
layer
.
weight
return
layer
.
weight
if
not
envs
.
VLLM_USE_NN
else
layer
.
weight
.
T
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment