Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bac201c9
Commit
bac201c9
authored
Feb 11, 2026
by
zhuwenwen
Browse files
fix: 修复ep的变量未定义
set VLLM_USE_FUSED_QA_KVA_GEMM=1 feat:w4a8Linear调用apply_int8_linear,以支持blaslt
parent
ffd123f6
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
23 additions
and
86 deletions
+23
-86
setup.py
setup.py
+2
-2
vllm/envs.py
vllm/envs.py
+1
-1
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+11
-76
vllm/model_executor/model_loader/utils.py
vllm/model_executor/model_loader/utils.py
+6
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+2
-6
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+1
-1
No files found.
setup.py
View file @
bac201c9
...
@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
...
@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if
sha
is
None
:
if
sha
is
None
:
sha
=
get_sha
(
vllm_root
)
sha
=
get_sha
(
vllm_root
)
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
version
=
'das.opt
5
.'
+
sha
[:
7
]
version
=
'das.opt
6
.'
+
sha
[:
7
]
else
:
else
:
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
if
(
major
,
minor
)
>=
(
'2'
,
'5'
):
version
=
'das.opt
5
'
version
=
'das.opt
6
'
# dtk version
# dtk version
...
...
vllm/envs.py
View file @
bac201c9
...
@@ -1365,7 +1365,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1365,7 +1365,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Only quantized DeepSeek models supported.
# Only quantized DeepSeek models supported.
# Unquantized versions are not supported.
# Unquantized versions are not supported.
"VLLM_USE_FUSED_QA_KVA_GEMM"
:
"VLLM_USE_FUSED_QA_KVA_GEMM"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_QA_KVA_GEMM"
,
"
Fals
e"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FUSED_QA_KVA_GEMM"
,
"
Tru
e"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
"VLLM_ZERO_OVERHEAD_ENHANCE"
:
"VLLM_ZERO_OVERHEAD_ENHANCE"
:
lambda
:
(
os
.
getenv
(
'VLLM_ZERO_OVERHEAD_ENHANCE'
,
'0'
).
lower
()
in
lambda
:
(
os
.
getenv
(
'VLLM_ZERO_OVERHEAD_ENHANCE'
,
'0'
).
lower
()
in
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
bac201c9
...
@@ -5,6 +5,7 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -5,6 +5,7 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
apply_int8_linear
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
...
@@ -111,6 +112,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
...
@@ -111,6 +112,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
for
key
,
value
in
configs_dict
.
items
():
for
key
,
value
in
configs_dict
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
m
=
int
(
key
.
split
(
'_'
)[
0
])
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
device
=
layer
.
weight
.
device
,
best_config
=
value
)
ops
.
triton_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
device
=
layer
.
weight
.
device
,
best_config
=
value
)
elif
self
.
w8a8_strategy
==
3
:
layer
.
weight
.
data
=
layer
.
weight
.
data
.
T
else
:
else
:
weight_data
=
layer
.
weight
.
data
weight_data
=
layer
.
weight
.
data
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
...
@@ -158,81 +161,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
...
@@ -158,81 +161,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
input_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
silu_quant_args
:
Optional
[
list
[
torch
.
Tensor
]]
=
None
,
**
_
):
):
if
envs
.
USE_FUSED_RMS_QUANT
and
input_quant_args
is
not
None
:
return
apply_int8_linear
(
input
=
x
,
assert
len
(
input_quant_args
)
==
2
weight
=
layer
.
weight
,
x_q
,
x_scale
=
input_quant_args
weight_scale
=
layer
.
weight_scale
,
elif
envs
.
USE_FUSED_SILU_MUL_QUANT
and
silu_quant_args
is
not
None
:
bias
=
bias
,
assert
len
(
silu_quant_args
)
==
2
w8a8_strategy
=
self
.
w8a8_strategy
,
x_q
,
x_scale
=
silu_quant_args
input_quant_args
=
input_quant_args
,
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
input_quant_args
is
not
None
:
silu_quant_args
=
silu_quant_args
)
assert
len
(
input_quant_args
)
==
2
x_q
,
x_scale
=
input_quant_args
elif
envs
.
USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
and
silu_quant_args
is
not
None
:
assert
len
(
silu_quant_args
)
==
2
x_q
,
x_scale
=
silu_quant_args
else
:
x_q
,
x_scale
=
per_token_quant_int8
(
x
)
if
self
.
w8a8_strategy
==
1
:
m
=
x_q
.
shape
[
0
]
k
=
x_q
.
shape
[
1
]
n
=
layer
.
weight
.
shape
[
1
]
if
len
(
W8A8_TRITONJSON
.
triton_json_dict
)
==
0
:
best_config
=
None
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
:
if
m
<=
16
:
m_
=
m
elif
m
<=
64
:
m_
=
((
m
+
3
)
//
4
)
*
4
#取值到最近的4的倍数
elif
m
<=
160
:
m_
=
(
m
//
8
)
*
8
elif
m
<
200
:
#256
m_
=
160
elif
m
<
480
:
#512
m_
=
256
elif
m
<
960
:
#1024
m_
=
512
elif
m
<
2048
:
m_
=
1024
elif
m
<
4096
:
m_
=
2048
elif
m
<
6000
:
m_
=
4096
else
:
m_
=
8192
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
f
"
{
m_
}
_
{
n
}
_
{
k
}
"
]
else
:
best_config
=
None
#if best_config==None:
# print("m:{},n:{},k:{}".format(m,n,k))
# print("config not found!")
return
ops
.
triton_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
,
best_config
=
best_config
)
elif
self
.
w8a8_strategy
==
2
:
return
ops
.
cutlass_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
else
:
return
ops
.
rocblas_scaled_mm
(
x_q
,
layer
.
weight
,
scale_a
=
x_scale
,
scale_b
=
layer
.
weight_scale
,
out_dtype
=
x
.
dtype
,
bias
=
bias
)
class
SlimQuantW4A8Int8MoEMethod
:
class
SlimQuantW4A8Int8MoEMethod
:
...
...
vllm/model_executor/model_loader/utils.py
View file @
bac201c9
...
@@ -288,6 +288,9 @@ def get_model_architecture(
...
@@ -288,6 +288,9 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_OPT_RESHAPE_AND_CACHE'
]
=
'1'
os
.
environ
[
'VLLM_USE_OPT_RESHAPE_AND_CACHE'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_FUSED_RMS_ROPE"
):
if
not
envs
.
is_set
(
"VLLM_USE_FUSED_RMS_ROPE"
):
os
.
environ
[
'VLLM_USE_FUSED_RMS_ROPE'
]
=
'1'
os
.
environ
[
'VLLM_USE_FUSED_RMS_ROPE'
]
=
'1'
if
architectures
in
[[
'Qwen3ForCausalLM'
]]:
if
not
envs
.
is_set
(
"VLLM_USE_OPT_RESHAPE_AND_CACHE"
):
os
.
environ
[
'VLLM_USE_OPT_RESHAPE_AND_CACHE'
]
=
'1'
if
architectures
in
[[
'DeepseekV32ForCausalLM'
]]:
if
architectures
in
[[
'DeepseekV32ForCausalLM'
]]:
if
not
envs
.
is_set
(
"VLLM_USE_V32_ENCODE"
):
if
not
envs
.
is_set
(
"VLLM_USE_V32_ENCODE"
):
...
@@ -336,6 +339,9 @@ def get_model_architecture(
...
@@ -336,6 +339,9 @@ def get_model_architecture(
os
.
environ
[
'VLLM_USE_OPT_RESHAPE_AND_CACHE'
]
=
'1'
os
.
environ
[
'VLLM_USE_OPT_RESHAPE_AND_CACHE'
]
=
'1'
if
not
envs
.
is_set
(
"VLLM_USE_FUSED_RMS_ROPE"
):
if
not
envs
.
is_set
(
"VLLM_USE_FUSED_RMS_ROPE"
):
os
.
environ
[
'VLLM_USE_FUSED_RMS_ROPE'
]
=
'1'
os
.
environ
[
'VLLM_USE_FUSED_RMS_ROPE'
]
=
'1'
if
architectures
in
[[
'Qwen3ForCausalLM'
]]:
if
not
envs
.
is_set
(
"VLLM_USE_OPT_RESHAPE_AND_CACHE"
):
os
.
environ
[
'VLLM_USE_OPT_RESHAPE_AND_CACHE'
]
=
'1'
if
architectures
in
[[
'DeepseekV32ForCausalLM'
]]:
if
architectures
in
[[
'DeepseekV32ForCausalLM'
]]:
if
not
envs
.
is_set
(
"VLLM_USE_V32_ENCODE"
):
if
not
envs
.
is_set
(
"VLLM_USE_V32_ENCODE"
):
...
...
vllm/model_executor/models/deepseek_v2.py
View file @
bac201c9
...
@@ -422,9 +422,6 @@ class DeepseekV2MoE(nn.Module):
...
@@ -422,9 +422,6 @@ class DeepseekV2MoE(nn.Module):
# Fix FP16 overflow
# Fix FP16 overflow
# See DeepseekV2DecoderLayer for more details.
# See DeepseekV2DecoderLayer for more details.
# fp16 mode not fused quant
# fp16 mode not fused quant
if
i_q
is
not
None
:
i_q
=
iqis
[
0
]
i_s
=
iqis
[
1
]
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
i_q
=
i_q
,
i_s
=
i_s
)
i_q
=
i_q
,
i_s
=
i_s
)
...
@@ -469,9 +466,8 @@ class DeepseekV2MoE(nn.Module):
...
@@ -469,9 +466,8 @@ class DeepseekV2MoE(nn.Module):
assert
shared_output
is
not
None
assert
shared_output
is
not
None
final_hidden_states
+=
(
shared_output
*
(
1.
/
self
.
routed_scaling_factor
))
final_hidden_states
+=
(
shared_output
*
(
1.
/
self
.
routed_scaling_factor
))
else
:
else
:
if
i_q
is
not
None
:
if
iqis
is
not
None
:
i_q
=
iqis
[
0
]
i_q
,
i_s
=
iqis
i_s
=
iqis
[
1
]
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
final_hidden_states
=
self
.
experts
(
hidden_states
=
hidden_states
,
router_logits
=
router_logits
,
router_logits
=
router_logits
,
i_q
=
i_q
,
i_s
=
i_s
)
i_q
=
i_q
,
i_s
=
i_s
)
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
bac201c9
...
@@ -577,7 +577,7 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -577,7 +577,7 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_v_scale
,
layer
.
_v_scale
,
)
)
else
:
else
:
if
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
and
key
.
dtype
==
value
.
dtype
==
torch
.
float16
:
if
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
:
from
lightop
import
reshape_and_cache_cuda
from
lightop
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
reshape_and_cache_cuda
(
key
,
value
,
key
,
value
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment