Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
69f30ae0
Commit
69f30ae0
authored
Sep 01, 2025
by
王敏
Browse files
Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev
parents
d04683a4
4a946680
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1263 additions
and
21 deletions
+1263
-21
vllm/attention/backends/dual_chunk_flash_attn.py
vllm/attention/backends/dual_chunk_flash_attn.py
+7
-2
vllm/attention/backends/utils.py
vllm/attention/backends/utils.py
+27
-6
vllm/config.py
vllm/config.py
+3
-1
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+2
-2
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+3
-3
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+4
-1
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+282
-0
vllm/model_executor/layers/quantization/utils/w4a8_utils.py
vllm/model_executor/layers/quantization/utils/w4a8_utils.py
+72
-0
vllm/model_executor/models/deepseek_v2.py
vllm/model_executor/models/deepseek_v2.py
+11
-4
vllm/perf/benchmark_moe.py
vllm/perf/benchmark_moe.py
+846
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+5
-1
vllm/zero_overhead/v1/core.py
vllm/zero_overhead/v1/core.py
+1
-1
No files found.
vllm/attention/backends/dual_chunk_flash_attn.py
View file @
69f30ae0
...
...
@@ -19,8 +19,13 @@ from vllm.attention.backends.flash_attn import (FlashAttentionBackend,
from
vllm.distributed.parallel_state
import
get_tensor_model_parallel_rank
from
vllm.logger
import
init_logger
from
vllm.utils
import
async_tensor_h2d
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
sparse_attn_func
)
from
vllm.platforms
import
current_platform
if
not
current_platform
.
is_rocm
():
from
vllm.vllm_flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
sparse_attn_func
)
else
:
from
flash_attn
import
(
flash_attn_varlen_func
,
flash_attn_with_kvcache
,
sparse_attn_func
)
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUBuilder
...
...
vllm/attention/backends/utils.py
View file @
69f30ae0
...
...
@@ -246,12 +246,33 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device
,
non_blocking
=
True
)
else
:
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
has_empty
:
bool
=
any
(
len
(
bt
)
==
0
for
bt
in
self
.
block_tables
)
has_non_empty
=
any
(
len
(
bt
)
>
0
for
bt
in
self
.
block_tables
)
max_block_length
=
0
if
has_empty
and
has_non_empty
:
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
block_tables
=
inter_data
.
block_tables
if
block_tables
:
for
seq_id
in
inter_data
.
seq_ids
:
if
seq_id
in
block_tables
:
block_table
=
block_tables
[
seq_id
]
max_block_length
=
max
(
max_block_length
,
len
(
block_table
))
if
max_block_length
>
0
:
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
max_len
=
max_block_length
,
)
else
:
block_tables
=
make_tensor_with_pad
(
self
.
block_tables
,
pad
=
0
,
dtype
=
torch
.
int
,
device
=
device
,
)
assert
max_query_len
>
0
,
"query_lens: {}"
.
format
(
query_lens
)
assert
device
is
not
None
...
...
vllm/config.py
View file @
69f30ae0
...
...
@@ -893,7 +893,8 @@ class ModelConfig:
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed-tensors"
,
"experts_int8"
,
"quark"
,
"modelopt_fp4"
,
"bitblas"
,
"gptq_bitblas"
"quark"
,
"modelopt_fp4"
,
"bitblas"
,
"gptq_bitblas"
,
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
cast
(
me_quant
.
QuantizationMethods
,
...
...
@@ -920,6 +921,7 @@ class ModelConfig:
"awq_marlin"
,
"ipex"
,
"moe_wna16"
,
"slimquant_w4a8_marlin"
]
quantization_methods
=
[
q
for
q
in
supported_quantization
if
q
not
in
overrides
...
...
vllm/engine/arg_utils.py
View file @
69f30ae0
...
...
@@ -1107,8 +1107,8 @@ class EngineArgs:
"Cuda graph is not supported with DualChunkFlashAttention. "
"To run the model in eager mode, set 'enforce_eager=True' "
"or use '--enforce-eager' in the CLI."
)
assert
current_platform
.
is_cuda
(),
(
"DualChunkFlashAttention is
only
supported on CUDA platform."
)
assert
current_platform
.
is_cuda
()
or
current_platform
.
is_rocm
()
,
(
"DualChunkFlashAttention is supported on CUDA
/ROCM
platform."
)
assert
not
use_v1
,
(
"DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'"
)
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
69f30ae0
...
...
@@ -811,9 +811,9 @@ class FusedMoE(torch.nn.Module):
"CompressedTensorsWNA16MoEMethod"
)):
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"BlockInt8MoEMethod"
)):
moe_quant_params
[
"intermediate_size"
]
=
self
.
intermediate_size_per_partition
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"SlimQuantW4A8Int8MoEMethod"
)):
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"BlockInt8MoEMethod"
,
"SlimQuantW4A8Int8MoEMethod"
,
"SlimQuantW4A8Int8M
arlinM
oEMethod"
)):
moe_quant_params
[
"intermediate_size"
]
=
self
.
intermediate_size_per_partition
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
69f30ae0
...
...
@@ -37,7 +37,8 @@ QuantizationMethods = Literal[
"auto-round"
,
"rtn"
,
"blockwise_int8"
,
"slimquant_w4a8"
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
]
QUANTIZATION_METHODS
:
list
[
str
]
=
list
(
get_args
(
QuantizationMethods
))
...
...
@@ -118,6 +119,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from
.tpu_int8
import
Int8TpuConfig
from
.blockwise_int8
import
BlockInt8Config
from
.slimquant_w4a8
import
SlimQuantW4A8Int8Config
from
.slimquant_w4a8_marlin
import
SlimQuantW4A8Int8MarlinConfig
method_to_config
:
dict
[
str
,
type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
...
...
@@ -151,6 +153,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"rtn"
:
RTNConfig
,
"blockwise_int8"
:
BlockInt8Config
,
"slimquant_w4a8"
:
SlimQuantW4A8Int8Config
,
"slimquant_w4a8_marlin"
:
SlimQuantW4A8Int8MarlinConfig
,
}
# Update the `method_to_config` with customized quantization methods.
method_to_config
.
update
(
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
)
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
0 → 100644
View file @
69f30ae0
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
os
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.w4a8_utils
import
w4a8_2_marlin_weight
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
)
from
vllm.model_executor.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
try
:
from
lmslim.layers.fused_moe.fuse_moe_w4a8_marlin
import
fused_experts_impl_w4a8_marlin
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
class
MarlinMoeWorkspace
:
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
global_reduce_buffer will take 1.5MB * cus (about 120MB for BW200) memoery in each device
"""
_instances
=
{}
def
__new__
(
cls
,
device
):
if
device
not
in
cls
.
_instances
:
instance
=
super
().
__new__
(
cls
)
instance
.
_initialized
=
False
cls
.
_instances
[
device
]
=
instance
return
cls
.
_instances
[
device
]
def
__init__
(
self
,
device
):
if
self
.
_initialized
:
return
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
self
.
workspace
=
torch
.
zeros
(
500
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
self
.
global_reduce_buffer
=
torch
.
zeros
(
sms
*
6
*
128
*
512
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
self
.
_initialized
=
True
def
get_buffers
(
self
):
return
self
.
workspace
,
self
.
global_reduce_buffer
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
scales
=
scale_a
*
scale_b
.
T
gemmout
=
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))
output
=
(
scales
*
gemmout
).
to
(
out_dtype
)
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
out_dtype
)
class
SlimQuantW4A8Int8MarlinConfig
(
QuantizationConfig
):
"""Config class for W4A8 Int8 Quantization.
- Weight: static, per-channel, symmetric
- Activation: dynamic, per-token, symmetric
"""
def
__init__
(
self
):
pass
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
75
@
classmethod
def
get_name
(
self
)
->
str
:
return
"slimquant_w4a8_marlin"
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"SlimQuantW4A8Int8MarlinConfig"
:
return
cls
()
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
QuantizationMethods
]:
if
hf_quant_cfg
.
get
(
"quant_method"
)
==
"slimquant_w4a8"
\
and
user_quant
==
"slimquant_w4a8_marlin"
:
return
cls
.
get_name
()
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
,
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
return
SlimQuantW4A8Int8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
SlimQuantW4A8Int8MarlinMoEMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
SlimQuantW4A8Int8MarlinMoEMethod
:
"""MoE method for W4A8INT8 Marlin.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Args:
quant_config: The quantization config.
"""
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
not
hasattr
(
cls
,
"_initialized"
):
original_init
=
cls
.
__init__
new_cls
=
type
(
cls
.
__name__
,
(
FusedMoEMethodBase
,),
{
"__init__"
:
original_init
,
**
{
k
:
v
for
k
,
v
in
cls
.
__dict__
.
items
()
if
k
!=
"__dict__"
},
},
)
obj
=
super
(
new_cls
,
new_cls
).
__new__
(
new_cls
)
obj
.
__init__
(
*
args
,
**
kwargs
)
return
obj
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
quant_config
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
tp_size
=
get_tensor_model_parallel_world_size
()
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
}
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
w13_input_scale
=
None
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
w2_input_scale
=
None
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
w1_marlin_list
=
[]
for
e
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
w1_marlin_in
=
w4a8_2_marlin_weight
(
layer
.
w13_weight
[
e
])
w1_marlin_list
.
append
(
w1_marlin_in
)
layer
.
w13_weight
=
Parameter
(
torch
.
stack
(
w1_marlin_list
,
dim
=
0
),
requires_grad
=
False
)
w2_marlin_list
=
[]
for
e
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
w2_marlin_in
=
w4a8_2_marlin_weight
(
layer
.
w2_weight
[
e
])
w2_marlin_list
.
append
(
w2_marlin_in
)
layer
.
w2_weight
=
Parameter
(
torch
.
stack
(
w2_marlin_list
,
dim
=
0
),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
_
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet."
)
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
return
fused_experts_impl_w4a8_marlin
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
workspace
=
workspace
,
global_reduce_buffer
=
global_reduce_buffer
,
inplace
=
True
,
use_int4_w4a8
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
w1_scale
=
(
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
use_nn_moe
,
)
vllm/model_executor/layers/quantization/utils/w4a8_utils.py
0 → 100644
View file @
69f30ae0
import
torch
import
numpy
as
np
def
unpack_int8_to_int4
(
tensor_int8
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
Args:
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
Returns:
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
"""
if
tensor_int8
.
dtype
!=
torch
.
int8
:
raise
ValueError
(
"Input tensor must be of type torch.int8"
)
N
,
K_half
=
tensor_int8
.
shape
tensor_uint8
=
tensor_int8
.
to
(
torch
.
uint8
)
high4
=
tensor_uint8
&
0x0F
low4
=
(
tensor_uint8
>>
4
)
&
0x0F
unpacked
=
torch
.
empty
((
N
,
K_half
*
2
),
dtype
=
torch
.
int32
,
device
=
tensor_int8
.
device
)
unpacked
[:,
0
::
2
]
=
low4
.
to
(
torch
.
int32
)
unpacked
[:,
1
::
2
]
=
high4
.
to
(
torch
.
int32
)
return
unpacked
def
get_weight_perms
(
interleave
:
bool
=
True
):
perm
=
[]
for
i
in
range
(
64
):
for
col
in
range
(
4
):
cur_col
=
(
i
%
16
)
*
4
+
col
for
row
in
range
(
8
):
cur_row
=
(
i
//
16
)
*
8
+
row
cur_idx
=
cur_row
*
64
+
cur_col
perm
.
append
(
cur_idx
)
perm
=
np
.
array
(
perm
)
if
interleave
:
interleave
=
np
.
array
([
4
,
0
,
5
,
1
,
6
,
2
,
7
,
3
])
perm
=
perm
.
reshape
((
-
1
,
8
))[:,
interleave
].
ravel
()
perm
=
torch
.
from_numpy
(
perm
)
return
perm
def
marlin_weights
(
q_w
,
weight_perm
,
k_tile
=
32
,
n_tile
=
64
,
pack_factor
=
8
):
size_k
,
size_n
=
q_w
.
shape
q_w
=
q_w
.
reshape
((
size_k
//
k_tile
,
k_tile
,
size_n
//
n_tile
,
n_tile
))
q_w
=
q_w
.
permute
((
0
,
2
,
1
,
3
))
q_w
=
q_w
.
reshape
((
size_k
//
k_tile
,
size_n
*
k_tile
))
q_w
=
q_w
.
reshape
((
-
1
,
weight_perm
.
numel
()))[:,
weight_perm
].
reshape
(
q_w
.
shape
)
orig_device
=
q_w
.
device
q_w
=
q_w
.
cpu
().
numpy
().
astype
(
np
.
uint32
)
q_packed
=
np
.
zeros
((
q_w
.
shape
[
0
],
q_w
.
shape
[
1
]
//
pack_factor
),
dtype
=
np
.
uint32
)
for
i
in
range
(
pack_factor
):
q_packed
|=
q_w
[:,
i
::
pack_factor
]
<<
4
*
i
q_packed
=
torch
.
from_numpy
(
q_packed
.
astype
(
np
.
int32
)).
to
(
orig_device
)
return
q_packed
def
w4a8_2_marlin_weight
(
w4a8_w
):
full_w4a8_w
=
unpack_int8_to_int4
(
w4a8_w
)
full_w4a8_w
=
full_w4a8_w
.
T
weight_perm
=
get_weight_perms
()
marlin_q_w
=
marlin_weights
(
full_w4a8_w
,
weight_perm
,
k_tile
=
32
,
n_tile
=
64
,
pack_factor
=
8
)
return
marlin_q_w
vllm/model_executor/models/deepseek_v2.py
View file @
69f30ae0
...
...
@@ -67,7 +67,7 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
W8a8GetCacheJSON
os
.
environ
[
'DPSK_FP16_QUICK'
]
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
,
'
1
'
)
os
.
environ
[
'DPSK_FP16_QUICK'
]
=
os
.
environ
.
get
(
'DPSK_FP16_QUICK'
,
'
0
'
)
class
DeepseekV2MLP
(
nn
.
Module
):
def
__init__
(
...
...
@@ -622,9 +622,13 @@ class DeepseekV2DecoderLayer(nn.Module):
residual
:
Optional
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
# Self Attention
# Fix residual FP16 overflow
residual_fix_overflow
=
False
if
residual
is
None
:
residual
=
hidden_states
hidden_states
=
self
.
input_layernorm
(
hidden_states
)
residual_fix_overflow
=
True
else
:
hidden_states
,
residual
=
self
.
input_layernorm
(
hidden_states
,
residual
)
...
...
@@ -640,7 +644,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
hidden_states
*=
1.
/
self
.
routed_scaling_factor
if
self
.
layer_idx
==
0
:
if
self
.
layer_idx
==
0
or
residual_fix_overflow
:
# The residual is shared by all layers, we only scale it on
# first layer.
residual
*=
1.
/
self
.
routed_scaling_factor
...
...
@@ -778,14 +782,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
self
.
num_expert_groups
=
config
.
n_group
self
.
moe_layers
:
list
[
FusedMoE
]
=
[]
example_moe
=
None
for
layer
in
self
.
model
.
layers
:
if
isinstance
(
layer
,
PPMissingLayer
):
continue
assert
isinstance
(
layer
,
DeepseekV2DecoderLayer
)
if
isinstance
(
layer
.
mlp
,
DeepseekV2MoE
):
example_moe
=
layer
.
mlp
self
.
moe_layers
.
append
(
layer
.
mlp
.
experts
)
# Pick last one layer since the first ones may be dense layers.
example_moe
=
typing
.
cast
(
DeepseekV2MoE
,
self
.
model
.
layers
[
config
.
num_hidden_layers
-
1
].
mlp
)
self
.
num_logical_experts
=
example_moe
.
n_logical_experts
self
.
num_physical_experts
=
example_moe
.
n_physical_experts
self
.
num_local_physical_experts
=
example_moe
.
n_local_physical_experts
...
...
vllm/perf/benchmark_moe.py
0 → 100644
View file @
69f30ae0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
argparse
import
json
import
time
from
contextlib
import
nullcontext
from
datetime
import
datetime
from
itertools
import
product
from
typing
import
Any
,
TypedDict
,
Optional
import
ray
import
torch
from
ray.experimental.tqdm_ray
import
tqdm
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.transformers_utils.config
import
get_config
from
vllm.triton_utils
import
triton
from
vllm.utils
import
FlexibleArgumentParser
# 移除全局的 current_platform 导入,改为在需要时局部导入
# FP8_DTYPE = current_platform.fp8_dtype()
class
BenchmarkConfig
(
TypedDict
):
BLOCK_SIZE_M
:
int
BLOCK_SIZE_N
:
int
BLOCK_SIZE_K
:
int
GROUP_SIZE_M
:
int
num_warps
:
int
num_stages
:
int
num_ldmatrixes
:
Optional
[
int
]
def
benchmark_config
(
config
:
BenchmarkConfig
,
num_tokens
:
int
,
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
num_iters
:
int
=
100
,
block_quant_shape
:
list
[
int
]
=
None
,
use_deep_gemm
:
bool
=
False
,
nn_moe
:
Optional
[
bool
]
=
False
)
->
float
:
from
vllm.platforms
import
current_platform
device
=
torch
.
cuda
.
current_device
()
init_dtype
=
torch
.
float16
if
use_fp8_w8a8
else
dtype
x
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
device
)
if
use_int8_w8a16
:
if
not
nn_moe
:
w1
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
shard_intermediate_size
,
hidden_size
,
),
dtype
=
torch
.
int8
,
device
=
device
,
)
w2
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
),
dtype
=
torch
.
int8
,
device
=
device
,
)
else
:
w1
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
hidden_size
,
shard_intermediate_size
,
),
dtype
=
torch
.
int8
,
device
=
device
,
)
w2
=
torch
.
randint
(
-
127
,
127
,
(
num_experts
,
shard_intermediate_size
//
2
,
hidden_size
,
),
dtype
=
torch
.
int8
,
device
=
device
,
)
else
:
if
not
nn_moe
:
w1
=
torch
.
randn
(
num_experts
,
shard_intermediate_size
,
hidden_size
,
dtype
=
init_dtype
,
device
=
device
)
w2
=
torch
.
randn
(
num_experts
,
hidden_size
,
shard_intermediate_size
//
2
,
dtype
=
init_dtype
,
device
=
device
)
else
:
w1
=
torch
.
randn
(
num_experts
,
hidden_size
,
shard_intermediate_size
,
dtype
=
init_dtype
,
device
=
device
)
w2
=
torch
.
randn
(
num_experts
,
shard_intermediate_size
//
2
,
hidden_size
,
dtype
=
init_dtype
,
device
=
device
)
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
w1_scale
=
None
w2_scale
=
None
a1_scale
=
None
a2_scale
=
None
if
use_int8_w8a16
:
w1_scale
=
torch
.
randn
(
(
num_experts
,
2
*
shard_intermediate_size
),
dtype
=
torch
.
float32
,
device
=
device
)
w2_scale
=
torch
.
randn
((
hidden_size
,
num_experts
),
dtype
=
torch
.
float32
,
device
=
device
)
if
use_fp8_w8a8
:
if
block_quant_shape
:
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
E
=
num_experts
N
=
shard_intermediate_size
//
2
K
=
hidden_size
factor_for_scale
=
1e-2
n_tiles_w1
=
(
2
*
N
+
block_n
-
1
)
//
block_n
n_tiles_w2
=
(
K
+
block_n
-
1
)
//
block_n
k_tiles_w1
=
(
K
+
block_k
-
1
)
//
block_k
k_tiles_w2
=
(
N
+
block_k
-
1
)
//
block_k
w1_scale
=
(
torch
.
rand
((
E
,
n_tiles_w1
,
k_tiles_w1
),
dtype
=
torch
.
float32
,
device
=
device
)
*
factor_for_scale
)
w2_scale
=
(
torch
.
rand
((
E
,
n_tiles_w2
,
k_tiles_w2
),
dtype
=
torch
.
float32
,
device
=
device
)
*
factor_for_scale
)
else
:
w1_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
w2_scale
=
torch
.
randn
(
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
a1_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
a2_scale
=
torch
.
randn
(
1
,
dtype
=
torch
.
float32
,
device
=
device
)
# 获取 FP8_DTYPE
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
w1
=
w1
.
to
(
FP8_DTYPE
)
w2
=
w2
.
to
(
FP8_DTYPE
)
input_gating
=
torch
.
empty
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
,
device
=
device
)
def
prepare
(
i
:
int
):
input_gating
.
copy_
(
gating_output
[
i
])
def
run
():
from
vllm.model_executor.layers.fused_moe
import
override_config
with
override_config
(
config
):
if
use_deep_gemm
:
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
x
,
input_gating
,
topk
,
False
)
return
fused_experts
(
x
,
w1
,
w2
,
topk_weights
,
topk_ids
,
inplace
=
True
,
use_fp8_w8a8
=
use_fp8_w8a8
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_quant_shape
,
allow_deep_gemm
=
True
,
use_nn_moe
=
nn_moe
,
)
else
:
fused_moe
(
x
,
w1
,
w2
,
input_gating
,
topk
,
renormalize
=
True
,
inplace
=
True
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a16
=
use_int8_w8a16
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_quant_shape
,
use_nn_moe
=
nn_moe
,
)
# JIT compilation & warmup
run
()
torch
.
cuda
.
synchronize
()
# Capture 10 invocations with CUDA graph
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
for
_
in
range
(
10
):
run
()
torch
.
cuda
.
synchronize
()
# Warmup
for
_
in
range
(
5
):
graph
.
replay
()
# run()
torch
.
cuda
.
synchronize
()
start_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
end_event
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
latencies
:
list
[
float
]
=
[]
for
i
in
range
(
num_iters
):
prepare
(
i
)
torch
.
cuda
.
synchronize
()
start_event
.
record
()
graph
.
replay
()
# run()
end_event
.
record
()
end_event
.
synchronize
()
latencies
.
append
(
start_event
.
elapsed_time
(
end_event
))
avg
=
sum
(
latencies
)
/
(
num_iters
*
10
)
*
1000
# us
graph
.
reset
()
return
avg
def
get_rocm_tuning_space
(
use_fp16
,
nn_moe
:
Optional
[
bool
]
=
False
):
block_m_range
=
[
16
,
32
,
64
,
128
,
256
]
block_n_range
=
[
32
,
64
,
128
,
256
]
block_k_range
=
[
32
,
64
,
128
,
256
]
if
not
use_fp16
:
block_k_range
.
remove
(
16
)
# BLOCK_K=16 not supported for fp8
num_warps_range
=
[
2
,
4
,
8
]
group_m_range
=
[
1
,
16
,
32
,
64
]
num_stage_range
=
[
2
,
3
,
4
,
5
]
# waves_per_eu_range = [0]
# matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
# kpack_range = [1, 2] if use_fp16 else []
param_ranges
=
{
"BLOCK_SIZE_M"
:
block_m_range
,
"BLOCK_SIZE_N"
:
block_n_range
,
"BLOCK_SIZE_K"
:
block_k_range
,
"GROUP_SIZE_M"
:
group_m_range
,
"num_warps"
:
num_warps_range
,
"num_stages"
:
num_stage_range
,
# "waves_per_eu": waves_per_eu_range,
}
if
nn_moe
:
param_ranges
[
"num_ldmatrixes"
]
=
[
1
]
# DCU currently does not support the following parameters
# if use_fp16:
# param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
# param_ranges["kpack"] = kpack_range
return
param_ranges
def
get_configs_compute_bound
(
use_fp16
,
block_quant_shape
,
nn_moe
:
Optional
[
bool
]
=
False
)
->
list
[
dict
[
str
,
int
]]:
configs
:
list
[
BenchmarkConfig
]
=
[]
# 局部导入 current_platform
from
vllm.platforms
import
current_platform
if
current_platform
.
is_rocm
():
param_ranges
=
get_rocm_tuning_space
(
use_fp16
,
nn_moe
)
else
:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
block_m_range
=
[
16
,
32
,
64
,
128
,
256
]
block_n_range
=
[
32
,
64
,
128
,
256
]
block_k_range
=
[
64
,
128
,
256
]
num_warps_range
=
[
4
,
8
]
group_m_range
=
[
1
,
16
,
32
,
64
]
num_stage_range
=
[
2
,
3
,
4
,
5
]
param_ranges
=
{
"BLOCK_SIZE_M"
:
block_m_range
,
"BLOCK_SIZE_N"
:
block_n_range
,
"BLOCK_SIZE_K"
:
block_k_range
,
"GROUP_SIZE_M"
:
group_m_range
,
"num_warps"
:
num_warps_range
,
"num_stages"
:
num_stage_range
,
}
keys
,
values
=
zip
(
*
param_ranges
.
items
())
for
config_values
in
product
(
*
values
):
config
=
dict
(
zip
(
keys
,
config_values
))
configs
.
append
(
config
)
# Remove configs that are not compatible with fp8 block quantization
# BLOCK_SIZE_K must be a multiple of block_k
# BLOCK_SIZE_N must be a multiple of block_n
if
block_quant_shape
is
not
None
and
not
use_fp16
:
block_n
,
block_k
=
block_quant_shape
[
0
],
block_quant_shape
[
1
]
for
config
in
configs
[:]:
if
(
config
[
"BLOCK_SIZE_K"
]
%
block_k
!=
0
or
config
[
"BLOCK_SIZE_N"
]
%
block_n
!=
0
):
configs
.
remove
(
config
)
return
configs
def
prune_rocm_search_space
(
num_tokens
,
shard_intermediate_size
,
hidden_size
,
search_space
,
is_fp16
,
topk
):
N1
,
K1
=
shard_intermediate_size
,
hidden_size
N2
,
K2
=
hidden_size
,
shard_intermediate_size
//
2
pruned_space_1
=
prune_rocm_configs
(
num_tokens
*
topk
,
N1
,
K1
,
search_space
,
is_fp16
)
pruned_space_2
=
prune_rocm_configs
(
num_tokens
*
topk
,
N2
,
K2
,
search_space
,
is_fp16
)
search_space
=
merge_unique_dicts
(
pruned_space_1
,
pruned_space_2
)
return
search_space
# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def
prune_rocm_configs
(
M
,
N
,
K
,
configs
,
is_fp16
=
True
):
pruned_configs
=
[]
elemBytes_a
=
2
if
is_fp16
else
1
elemBytes_b
=
2
if
is_fp16
else
1
mfma
=
16
if
M
<
32
or
N
<
32
else
32
# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm
=
False
if
M
>=
2048
and
N
>=
2048
:
large_gemm
=
True
for
config
in
configs
:
BLOCK_SIZE_M
=
config
.
get
(
"BLOCK_SIZE_M"
)
BLOCK_SIZE_N
=
config
.
get
(
"BLOCK_SIZE_N"
)
BLOCK_SIZE_K
=
config
.
get
(
"BLOCK_SIZE_K"
)
num_warps
=
config
.
get
(
"num_warps"
)
# DCU currently does not support matrix_instr_nonkdim param
# if is_fp16:
# matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
# if matrix_instr_nonkdim > mfma:
# continue
if
mfma
==
4
and
BLOCK_SIZE_K
<
64
:
continue
# some layouts could not work properly in case
# number elements per thread is less 1
if
BLOCK_SIZE_M
*
BLOCK_SIZE_N
<
64
:
continue
SPLIT_K
=
config
.
get
(
"SPLIT_K"
,
1
)
GROUP_M
=
config
.
get
(
"GROUP_SIZE_M"
)
# DCU currently does not support matrix_instr_nonkdim param
# if is_fp16:
# if (
# matrix_instr_nonkdim > BLOCK_SIZE_M
# or matrix_instr_nonkdim > BLOCK_SIZE_N
# ):
# continue
# if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
# continue
# if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
# continue
# Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough
if
M
*
2
<
BLOCK_SIZE_M
and
BLOCK_SIZE_M
!=
16
:
continue
if
N
*
2
<
BLOCK_SIZE_N
and
BLOCK_SIZE_N
!=
16
:
continue
# skip large split_k when not necessary
if
SPLIT_K
!=
1
and
not
need_split_k
(
M
,
N
,
K
):
continue
# skip split_k that leads to EVEN_K = false
leap
=
SPLIT_K
*
BLOCK_SIZE_K
modv
=
K
%
leap
if
modv
!=
0
:
continue
# skip large GROUP_M
if
GROUP_M
*
BLOCK_SIZE_M
>
M
and
GROUP_M
!=
1
:
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS
=
(
BLOCK_SIZE_K
*
BLOCK_SIZE_M
*
elemBytes_a
+
BLOCK_SIZE_K
*
BLOCK_SIZE_N
*
elemBytes_b
)
if
LDS
>
65536
:
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
if
large_gemm
:
if
BLOCK_SIZE_M
<
64
or
BLOCK_SIZE_N
<
64
:
continue
if
BLOCK_SIZE_K
<
64
:
continue
if
num_warps
<
4
:
continue
pruned_configs
.
append
(
config
)
return
pruned_configs
def
need_split_k
(
SIZE_M
,
SIZE_N
,
SIZE_K
):
return
(
SIZE_M
<
64
or
SIZE_N
<
64
)
and
SIZE_K
>
1024
def
merge_unique_dicts
(
list1
,
list2
):
result
=
[]
combined_list
=
list1
.
copy
()
combined_list
.
extend
(
list2
)
for
dictionary
in
combined_list
:
if
dictionary
not
in
result
:
result
.
append
(
dictionary
)
return
result
@
ray
.
remote
(
num_gpus
=
1
)
class
BenchmarkWorker
:
def
__init__
(
self
,
seed
:
int
,
device_id
:
int
)
->
None
:
from
vllm.platforms
import
current_platform
import
os
if
current_platform
.
is_rocm
():
# In ROCm environment with Ray, let Ray handle device assignment
# Don't manually set default device as it may conflict with Ray's device mapping
pass
else
:
torch
.
set_default_device
(
"cuda:"
+
str
(
device_id
))
current_platform
.
seed_everything
(
seed
)
self
.
seed
=
seed
# Store the logical device ID for Ray
self
.
device_id
=
device_id
def
benchmark
(
self
,
num_tokens
:
int
,
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
block_quant_shape
:
list
[
int
]
=
None
,
use_deep_gemm
:
bool
=
False
,
nn_moe
:
Optional
[
bool
]
=
False
,
)
->
tuple
[
dict
[
str
,
int
],
float
]:
# 局部导入 current_platform
from
vllm.platforms
import
current_platform
current_platform
.
seed_everything
(
self
.
seed
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
get_config_dtype_str
,
get_moe_configs
,
get_default_config
)
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
op_config
=
get_moe_configs
(
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
,
use_nn_moe
=
nn_moe
)
if
op_config
is
None
:
config
=
get_default_config
(
num_tokens
,
num_experts
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype_str
,
is_marlin
=
False
,
use_nn_moe
=
nn_moe
)
else
:
config
=
op_config
[
min
(
op_config
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
num_tokens
))]
kernel_time
=
benchmark_config
(
config
,
num_tokens
,
num_experts
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
100
,
block_quant_shape
=
block_quant_shape
,
use_deep_gemm
=
use_deep_gemm
,
use_nn_moe
=
nn_moe
)
return
config
,
kernel_time
def
tune
(
self
,
num_tokens
:
int
,
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
search_space
:
list
[
dict
[
str
,
int
]],
block_quant_shape
:
list
[
int
],
use_deep_gemm
:
bool
,
nn_moe
:
Optional
[
bool
]
=
False
,
)
->
dict
[
str
,
int
]:
from
vllm.platforms
import
current_platform
import
os
best_config
=
None
best_time
=
float
(
"inf"
)
if
current_platform
.
is_rocm
():
is_fp16
=
not
(
use_fp8_w8a8
or
use_int8_w8a16
)
search_space
=
prune_rocm_search_space
(
num_tokens
,
shard_intermediate_size
,
hidden_size
,
search_space
,
is_fp16
,
topk
,
)
# In ROCm environments with Ray, device context is already handled by Ray
# Using torch.cuda.device() may cause device ordinal conflicts
need_device_guard
=
False
if
current_platform
.
is_rocm
():
# For ROCm with Ray, skip additional device context management
need_device_guard
=
False
else
:
# For other platforms, use device guard if needed
visible_devices
=
os
.
environ
.
get
(
"CUDA_VISIBLE_DEVICES"
,
None
)
if
visible_devices
is
not
None
and
len
(
visible_devices
.
split
(
','
))
>
1
:
need_device_guard
=
True
with
torch
.
cuda
.
device
(
self
.
device_id
)
if
need_device_guard
else
nullcontext
():
for
config
in
tqdm
(
search_space
):
try
:
kernel_time
=
benchmark_config
(
config
,
num_tokens
,
num_experts
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
num_iters
=
20
,
block_quant_shape
=
block_quant_shape
,
use_deep_gemm
=
use_deep_gemm
,
nn_moe
=
nn_moe
)
except
triton
.
runtime
.
autotuner
.
OutOfResources
:
# Some configurations may be invalid and fail to compile.
continue
if
kernel_time
<
best_time
:
best_time
=
kernel_time
best_config
=
config
now
=
datetime
.
now
()
print
(
f
"
{
now
.
ctime
()
}
] Completed tuning for batch_size=
{
num_tokens
}
"
)
assert
best_config
is
not
None
return
best_config
def
sort_config
(
config
:
BenchmarkConfig
)
->
BenchmarkConfig
:
return
{
"BLOCK_SIZE_M"
:
config
[
"BLOCK_SIZE_M"
],
"BLOCK_SIZE_N"
:
config
[
"BLOCK_SIZE_N"
],
"BLOCK_SIZE_K"
:
config
[
"BLOCK_SIZE_K"
],
"GROUP_SIZE_M"
:
config
[
"GROUP_SIZE_M"
],
"num_warps"
:
config
[
"num_warps"
],
"num_stages"
:
config
[
"num_stages"
],
**
(
{
"num_ldmatrixes"
:
config
[
"num_ldmatrixes"
]}
if
"num_ldmatrixes"
in
config
else
{}
),
**
(
{
"waves_per_eu"
:
config
[
"waves_per_eu"
]}
if
"waves_per_eu"
in
config
else
{}
),
**
(
{
"matrix_instr_nonkdim"
:
config
[
"matrix_instr_nonkdim"
]}
if
"matrix_instr_nonkdim"
in
config
else
{}
),
**
({
"kpack"
:
config
[
"kpack"
]}
if
"kpack"
in
config
else
{}),
}
def
save_configs
(
configs
:
dict
[
int
,
BenchmarkConfig
],
num_experts
:
int
,
shard_intermediate_size
:
int
,
hidden_size
:
int
,
topk
:
int
,
dtype
:
torch
.
dtype
,
use_fp8_w8a8
:
bool
,
use_int8_w8a16
:
bool
,
block_quant_shape
:
list
[
int
],
use_nn_moe
:
Optional
[
bool
]
=
False
,
)
->
None
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
get_config_dtype_str
,
get_config_file_name
)
dtype_str
=
get_config_dtype_str
(
dtype
,
use_int8_w8a16
=
use_int8_w8a16
,
use_fp8_w8a8
=
use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename
=
get_config_file_name
(
num_experts
,
shard_intermediate_size
//
2
,
dtype_str
,
block_quant_shape
,
use_nn_moe
=
use_nn_moe
)
print
(
f
"Writing best config to
{
filename
}
..."
)
with
open
(
filename
,
"w"
)
as
f
:
json
.
dump
(
configs
,
f
,
indent
=
4
)
f
.
write
(
"
\n
"
)
def
get_weight_block_size_safety
(
config
,
default_value
=
None
):
quantization_config
=
getattr
(
config
,
"quantization_config"
,
{})
if
isinstance
(
quantization_config
,
dict
):
return
quantization_config
.
get
(
"weight_block_size"
,
default_value
)
return
default_value
def
main
(
args
:
argparse
.
Namespace
):
import
os
import
logging
from
vllm.platforms
import
current_platform
logger
=
logging
.
getLogger
(
__name__
)
print
(
args
)
tp_size
=
args
.
tp_size
config
=
get_config
(
model
=
args
.
model
,
trust_remote_code
=
args
.
trust_remote_code
)
if
args
.
model_prefix
:
config
=
getattr
(
config
,
args
.
model_prefix
)
if
config
.
architectures
[
0
]
==
"DbrxForCausalLM"
:
E
=
config
.
ffn_config
.
moe_num_experts
topk
=
config
.
ffn_config
.
moe_top_k
intermediate_size
=
config
.
ffn_config
.
ffn_hidden_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
==
"JambaForCausalLM"
:
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
in
(
"DeepseekV3ForCausalLM"
,
"DeepseekV2ForCausalLM"
,
"Glm4MoeForCausalLM"
):
E
=
config
.
n_routed_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
in
(
"Qwen2MoeForCausalLM"
,
"Qwen3MoeForCausalLM"
):
E
=
config
.
num_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
elif
config
.
architectures
[
0
]
in
(
"Step3VLForConditionalGeneration"
):
E
=
config
.
text_config
.
moe_num_experts
topk
=
config
.
text_config
.
moe_top_k
intermediate_size
=
config
.
text_config
.
moe_intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
else
:
# Support for llama4
config
=
config
.
get_text_config
()
# Default: Mixtral.
E
=
config
.
num_local_experts
topk
=
config
.
num_experts_per_tok
intermediate_size
=
config
.
intermediate_size
shard_intermediate_size
=
2
*
intermediate_size
//
tp_size
hidden_size
=
config
.
hidden_size
dtype
=
torch
.
float16
if
current_platform
.
is_rocm
()
else
config
.
torch_dtype
use_fp8_w8a8
=
args
.
dtype
==
"fp8_w8a8"
use_int8_w8a16
=
args
.
dtype
==
"int8_w8a16"
block_quant_shape
=
get_weight_block_size_safety
(
config
)
if
args
.
batch_size
is
None
:
batch_sizes
=
[
1
,
2
,
4
,
8
,
16
,
24
,
32
,
48
,
64
,
96
,
128
,
256
,
512
,
1024
,
1536
,
2048
,
3072
,
4096
,
]
else
:
batch_sizes
=
args
.
batch_size
use_deep_gemm
=
bool
(
args
.
use_deep_gemm
)
if
current_platform
.
is_rocm
()
and
"HIP_VISIBLE_DEVICES"
in
os
.
environ
:
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
logger
.
warning
(
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
)
val
=
os
.
environ
[
"HIP_VISIBLE_DEVICES"
]
os
.
environ
[
"ROCR_VISIBLE_DEVICES"
]
=
val
del
os
.
environ
[
"HIP_VISIBLE_DEVICES"
]
ray
.
init
(
address
=
None
,
ignore_reinit_error
=
True
,
num_gpus
=
args
.
num_gpus
)
num_gpus
=
int
(
ray
.
available_resources
()[
"GPU"
])
workers
=
[
BenchmarkWorker
.
remote
(
args
.
seed
,
i
)
for
i
in
range
(
num_gpus
)]
def
_distribute
(
method
:
str
,
inputs
:
list
[
Any
])
->
list
[
Any
]:
outputs
=
[]
worker_idx
=
0
for
input_args
in
inputs
:
worker
=
workers
[
worker_idx
]
worker_method
=
getattr
(
worker
,
method
)
output
=
worker_method
.
remote
(
*
input_args
)
outputs
.
append
(
output
)
worker_idx
=
(
worker_idx
+
1
)
%
num_gpus
return
ray
.
get
(
outputs
)
if
args
.
tune
:
is_fp16
=
not
(
use_fp8_w8a8
or
use_int8_w8a16
)
search_space
=
get_configs_compute_bound
(
is_fp16
,
block_quant_shape
,
args
.
nn_moe
)
print
(
f
"Start tuning over
{
len
(
search_space
)
}
configurations..."
)
start
=
time
.
time
()
configs
=
_distribute
(
"tune"
,
[
(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
search_space
,
block_quant_shape
,
use_deep_gemm
,
args
.
nn_moe
,
)
for
batch_size
in
batch_sizes
],
)
best_configs
=
{
M
:
sort_config
(
config
)
for
M
,
config
in
zip
(
batch_sizes
,
configs
)
}
save_configs
(
best_configs
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
block_quant_shape
,
use_nn_moe
=
args
.
nn_moe
,
)
end
=
time
.
time
()
print
(
f
"Tuning took
{
end
-
start
:.
2
f
}
seconds"
)
else
:
outputs
=
_distribute
(
"benchmark"
,
[
(
batch_size
,
E
,
shard_intermediate_size
,
hidden_size
,
topk
,
dtype
,
use_fp8_w8a8
,
use_int8_w8a16
,
block_quant_shape
,
use_deep_gemm
,
args
.
nn_moe
,
)
for
batch_size
in
batch_sizes
],
)
for
batch_size
,
(
config
,
kernel_time
)
in
zip
(
batch_sizes
,
outputs
):
print
(
f
"Batch size:
{
batch_size
}
, config:
{
config
}
"
)
print
(
f
"Kernel time:
{
kernel_time
:.
2
f
}
us"
)
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
()
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser
.
add_argument
(
"--tp-size"
,
"-tp"
,
"--tensor-parallel-size"
,
type
=
int
,
default
=
2
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"auto"
,
"fp8_w8a8"
,
"int8_w8a16"
],
default
=
"auto"
)
parser
.
add_argument
(
"--use-deep-gemm"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
nargs
=
"+"
,
required
=
False
)
parser
.
add_argument
(
"--tune"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--nn-moe"
,
action
=
'store_true'
,
default
=
False
)
parser
.
add_argument
(
"--trust-remote-code"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--model-prefix"
,
type
=
str
,
required
=
False
)
parser
.
add_argument
(
"--num-gpus"
,
type
=
int
,
default
=
1
)
args
=
parser
.
parse_args
()
main
(
args
)
vllm/platforms/rocm.py
View file @
69f30ae0
...
...
@@ -180,7 +180,7 @@ class RocmPlatform(Platform):
supported_quantization
:
list
[
str
]
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"quark"
,
"ptpc_fp8"
,
"moe_wna16"
,
"blockwise_int8"
,
"slimquant_w4a8"
,
"awq_marlin"
"quark"
,
"ptpc_fp8"
,
"moe_wna16"
,
"blockwise_int8"
,
"slimquant_w4a8"
,
"awq_marlin"
,
"slimquant_w4a8_marlin"
]
@
classmethod
...
...
@@ -282,6 +282,10 @@ class RocmPlatform(Platform):
logger
.
info_once
(
"Using Triton backend on V1 engine."
)
return
TRITON_ATTN_VLLM_V1
if
selected_backend
==
_Backend
.
DUAL_CHUNK_FLASH_ATTN
:
logger
.
info
(
"Using DualChunkFlashAttention backend."
)
return
(
"vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend"
)
if
selected_backend
==
_Backend
.
ROCM_FLASH
:
if
not
cls
.
has_device_capability
(
90
):
# not Instinct series GPUs.
...
...
vllm/zero_overhead/v1/core.py
View file @
69f30ae0
...
...
@@ -177,11 +177,11 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop.
for
request
in
scheduler
.
running
:
req_id
=
request
.
request_id
if
request
.
is_finished
():
if
req_id
in
requsets_valid_token_len
:
requsets_valid_token_len
.
pop
(
req_id
)
continue
req_id
=
request
.
request_id
num_tokens_scheduled
=
num_scheduled_tokens
.
get
(
req_id
,
0
)
if
num_tokens_scheduled
==
0
:
# The request was not scheduled in this step.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment