Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
533d4c72
Commit
533d4c72
authored
Aug 27, 2025
by
jujl1
Browse files
feat: w4a8 marlin合入主分支, 通过-q slimquant_w4a8_marlin 可以启用
parent
d0181e5a
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
300 additions
and
48 deletions
+300
-48
vllm/config.py
vllm/config.py
+3
-1
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+3
-3
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+4
-1
vllm/model_executor/layers/quantization/slimquant_w4a8.py
vllm/model_executor/layers/quantization/slimquant_w4a8.py
+5
-40
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
...del_executor/layers/quantization/slimquant_w4a8_marlin.py
+282
-0
vllm/model_executor/layers/quantization/utils/w4a8_utils.py
vllm/model_executor/layers/quantization/utils/w4a8_utils.py
+2
-2
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-1
No files found.
vllm/config.py
View file @
533d4c72
...
@@ -893,7 +893,8 @@ class ModelConfig:
...
@@ -893,7 +893,8 @@ class ModelConfig:
optimized_quantization_methods
=
[
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed-tensors"
,
"experts_int8"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed-tensors"
,
"experts_int8"
,
"quark"
,
"modelopt_fp4"
,
"bitblas"
,
"gptq_bitblas"
"quark"
,
"modelopt_fp4"
,
"bitblas"
,
"gptq_bitblas"
,
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
]
]
if
self
.
quantization
is
not
None
:
if
self
.
quantization
is
not
None
:
self
.
quantization
=
cast
(
me_quant
.
QuantizationMethods
,
self
.
quantization
=
cast
(
me_quant
.
QuantizationMethods
,
...
@@ -920,6 +921,7 @@ class ModelConfig:
...
@@ -920,6 +921,7 @@ class ModelConfig:
"awq_marlin"
,
"awq_marlin"
,
"ipex"
,
"ipex"
,
"moe_wna16"
,
"moe_wna16"
,
"slimquant_w4a8_marlin"
]
]
quantization_methods
=
[
quantization_methods
=
[
q
for
q
in
supported_quantization
if
q
not
in
overrides
q
for
q
in
supported_quantization
if
q
not
in
overrides
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
533d4c72
...
@@ -819,9 +819,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -819,9 +819,9 @@ class FusedMoE(torch.nn.Module):
"CompressedTensorsWNA16MoEMethod"
)):
"CompressedTensorsWNA16MoEMethod"
)):
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
moe_quant_params
[
"intermediate_size_full"
]
=
intermediate_size
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"BlockInt8MoEMethod"
)):
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"BlockInt8MoEMethod"
,
moe_quant_params
[
"intermediate_size"
]
=
self
.
intermediate_size_per_partition
"SlimQuantW4A8Int8MoEMethod"
,
if
(
self
.
quant_method
.
__class__
.
__name__
in
(
"SlimQuantW4A8Int8MoEMethod"
)):
"SlimQuantW4A8Int8M
arlinM
oEMethod"
)):
moe_quant_params
[
"intermediate_size"
]
=
self
.
intermediate_size_per_partition
moe_quant_params
[
"intermediate_size"
]
=
self
.
intermediate_size_per_partition
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
533d4c72
...
@@ -37,7 +37,8 @@ QuantizationMethods = Literal[
...
@@ -37,7 +37,8 @@ QuantizationMethods = Literal[
"auto-round"
,
"auto-round"
,
"rtn"
,
"rtn"
,
"blockwise_int8"
,
"blockwise_int8"
,
"slimquant_w4a8"
"slimquant_w4a8"
,
"slimquant_w4a8_marlin"
]
]
QUANTIZATION_METHODS
:
list
[
str
]
=
list
(
get_args
(
QuantizationMethods
))
QUANTIZATION_METHODS
:
list
[
str
]
=
list
(
get_args
(
QuantizationMethods
))
...
@@ -118,6 +119,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -118,6 +119,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from
.tpu_int8
import
Int8TpuConfig
from
.tpu_int8
import
Int8TpuConfig
from
.blockwise_int8
import
BlockInt8Config
from
.blockwise_int8
import
BlockInt8Config
from
.slimquant_w4a8
import
SlimQuantW4A8Int8Config
from
.slimquant_w4a8
import
SlimQuantW4A8Int8Config
from
.slimquant_w4a8_marlin
import
SlimQuantW4A8Int8MarlinConfig
method_to_config
:
dict
[
str
,
type
[
QuantizationConfig
]]
=
{
method_to_config
:
dict
[
str
,
type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
"aqlm"
:
AQLMConfig
,
...
@@ -151,6 +153,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -151,6 +153,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"rtn"
:
RTNConfig
,
"rtn"
:
RTNConfig
,
"blockwise_int8"
:
BlockInt8Config
,
"blockwise_int8"
:
BlockInt8Config
,
"slimquant_w4a8"
:
SlimQuantW4A8Int8Config
,
"slimquant_w4a8"
:
SlimQuantW4A8Int8Config
,
"slimquant_w4a8_marlin"
:
SlimQuantW4A8Int8MarlinConfig
,
}
}
# Update the `method_to_config` with customized quantization methods.
# Update the `method_to_config` with customized quantization methods.
method_to_config
.
update
(
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
)
method_to_config
.
update
(
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG
)
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8.py
View file @
533d4c72
...
@@ -7,7 +7,6 @@ from torch.nn.parameter import Parameter
...
@@ -7,7 +7,6 @@ from torch.nn.parameter import Parameter
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.w4a8_utils
import
w4a8_2_marlin_weight
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
...
@@ -19,29 +18,10 @@ from lmslim.layers.gemm.int8_utils import (
...
@@ -19,29 +18,10 @@ from lmslim.layers.gemm.int8_utils import (
per_token_quant_int8
)
per_token_quant_int8
)
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.utils
import
W8a8GetCacheJSON
from
vllm.utils
import
W8a8GetCacheJSON
import
vllm.envs
as
envs
import
os
import
os
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
try
:
from
lmslim.layers.fused_moe.fuse_moe_w4a8_marlin
import
fused_experts_impl_w4a8_marlin
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
workspace
=
None
global_reduce_buffer
=
None
def
get_marlin_moe_workspace
(
device
):
global
workspace
global
global_reduce_buffer
if
workspace
is
None
:
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
workspace
=
torch
.
zeros
(
500
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
if
global_reduce_buffer
is
None
:
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
global_reduce_buffer
=
torch
.
zeros
(
sms
*
6
*
128
*
512
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
return
workspace
,
global_reduce_buffer
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
...
@@ -339,27 +319,14 @@ class SlimQuantW4A8Int8MoEMethod:
...
@@ -339,27 +319,14 @@ class SlimQuantW4A8Int8MoEMethod:
if
configs_dict
:
if
configs_dict
:
self
.
tritonsingleton
.
triton_moejson_dict
.
update
(
configs_dict
)
self
.
tritonsingleton
.
triton_moejson_dict
.
update
(
configs_dict
)
#
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer
.
w13_weight
=
Parameter
(
layer
.
w13_weight
,
requires_grad
=
False
)
#
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer
.
w2_weight
=
Parameter
(
layer
.
w2_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
)
)
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
)
w1_marlin_list
=
[]
for
e
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
w1_marlin_in
=
w4a8_2_marlin_weight
(
layer
.
w13_weight
[
e
])
w1_marlin_list
.
append
(
w1_marlin_in
)
layer
.
w13_weight
=
Parameter
(
torch
.
stack
(
w1_marlin_list
,
dim
=
0
),
requires_grad
=
False
)
w2_marlin_list
=
[]
for
e
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
w2_marlin_in
=
w4a8_2_marlin_weight
(
layer
.
w2_weight
[
e
])
w2_marlin_list
.
append
(
w2_marlin_in
)
layer
.
w2_weight
=
Parameter
(
torch
.
stack
(
w2_marlin_list
,
dim
=
0
),
requires_grad
=
False
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -403,15 +370,13 @@ class SlimQuantW4A8Int8MoEMethod:
...
@@ -403,15 +370,13 @@ class SlimQuantW4A8Int8MoEMethod:
routed_scaling_factor
=
routed_scaling_factor
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
use_fused_gate
=
use_fused_gate
)
)
workspace
,
global_reduce_buffer
=
get_marlin_moe_workspace
(
device
=
x
.
device
)
return
fused_experts
_impl_w4a8_marlin
(
return
fused_experts
(
x
,
x
,
layer
.
w13_weight
,
layer
.
w13_weight
,
layer
.
w2_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
workspace
=
workspace
,
global_reduce_buffer
=
global_reduce_buffer
,
inplace
=
True
,
inplace
=
True
,
use_int4_w4a8
=
True
,
use_int4_w4a8
=
True
,
per_channel_quant
=
True
,
per_channel_quant
=
True
,
...
...
vllm/model_executor/layers/quantization/slimquant_w4a8_marlin.py
0 → 100644
View file @
533d4c72
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
os
import
torch
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.distributed
import
get_tensor_model_parallel_world_size
from
torch.nn.parameter
import
Parameter
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.w4a8_utils
import
w4a8_2_marlin_weight
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
)
from
vllm.model_executor.layers.quantization.slimquant_w4a8
import
SlimQuantW4A8Int8LinearMethod
try
:
from
lmslim.layers.fused_moe.fuse_moe_w4a8_marlin
import
fused_experts_impl_w4a8_marlin
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
class
MarlinMoeWorkspace
:
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
global_reduce_buffer will take 1.5MB * cus (about 120MB for BW200) memoery in each device
"""
_instances
=
{}
def
__new__
(
cls
,
device
):
if
device
not
in
cls
.
_instances
:
instance
=
super
().
__new__
(
cls
)
instance
.
_initialized
=
False
cls
.
_instances
[
device
]
=
instance
return
cls
.
_instances
[
device
]
def
__init__
(
self
,
device
):
if
self
.
_initialized
:
return
sms
=
torch
.
cuda
.
get_device_properties
(
device
).
multi_processor_count
self
.
workspace
=
torch
.
zeros
(
500
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
self
.
global_reduce_buffer
=
torch
.
zeros
(
sms
*
6
*
128
*
512
,
dtype
=
torch
.
int
,
device
=
device
,
requires_grad
=
False
)
self
.
_initialized
=
True
def
get_buffers
(
self
):
return
self
.
workspace
,
self
.
global_reduce_buffer
def
baseline_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
scales
=
scale_a
*
scale_b
.
T
gemmout
=
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))
output
=
(
scales
*
gemmout
).
to
(
out_dtype
)
if
bias
is
not
None
:
output
=
output
+
bias
return
output
.
to
(
out_dtype
)
class
SlimQuantW4A8Int8MarlinConfig
(
QuantizationConfig
):
"""Config class for W4A8 Int8 Quantization.
- Weight: static, per-channel, symmetric
- Activation: dynamic, per-token, symmetric
"""
def
__init__
(
self
):
pass
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
float16
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
75
@
classmethod
def
get_name
(
self
)
->
str
:
return
"slimquant_w4a8_marlin"
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"SlimQuantW4A8Int8MarlinConfig"
:
return
cls
()
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
QuantizationMethods
]:
if
hf_quant_cfg
.
get
(
"quant_method"
)
==
"slimquant_w4a8"
\
and
user_quant
==
"slimquant_w4a8_marlin"
:
return
cls
.
get_name
()
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
,
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
return
SlimQuantW4A8Int8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
SlimQuantW4A8Int8MarlinMoEMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
class
SlimQuantW4A8Int8MarlinMoEMethod
:
"""MoE method for W4A8INT8 Marlin.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Args:
quant_config: The quantization config.
"""
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
not
hasattr
(
cls
,
"_initialized"
):
original_init
=
cls
.
__init__
new_cls
=
type
(
cls
.
__name__
,
(
FusedMoEMethodBase
,),
{
"__init__"
:
original_init
,
**
{
k
:
v
for
k
,
v
in
cls
.
__dict__
.
items
()
if
k
!=
"__dict__"
},
},
)
obj
=
super
(
new_cls
,
new_cls
).
__new__
(
new_cls
)
obj
.
__init__
(
*
args
,
**
kwargs
)
return
obj
return
super
().
__new__
(
cls
)
def
__init__
(
self
,
quant_config
):
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
tp_size
=
get_tensor_model_parallel_world_size
()
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
//
2
,
dtype
=
torch
.
int8
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
}
)
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
w13_input_scale
=
None
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
w2_input_scale
=
None
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
w1_marlin_list
=
[]
for
e
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
w1_marlin_in
=
w4a8_2_marlin_weight
(
layer
.
w13_weight
[
e
])
w1_marlin_list
.
append
(
w1_marlin_in
)
layer
.
w13_weight
=
Parameter
(
torch
.
stack
(
w1_marlin_list
,
dim
=
0
),
requires_grad
=
False
)
w2_marlin_list
=
[]
for
e
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
w2_marlin_in
=
w4a8_2_marlin_weight
(
layer
.
w2_weight
[
e
])
w2_marlin_list
.
append
(
w2_marlin_in
)
layer
.
w2_weight
=
Parameter
(
torch
.
stack
(
w2_marlin_list
,
dim
=
0
),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
**
_
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet."
)
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
routed_scaling_factor
=
routed_scaling_factor
,
use_fused_gate
=
use_fused_gate
)
workspace
,
global_reduce_buffer
=
MarlinMoeWorkspace
(
x
.
device
).
get_buffers
()
return
fused_experts_impl_w4a8_marlin
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
workspace
=
workspace
,
global_reduce_buffer
=
global_reduce_buffer
,
inplace
=
True
,
use_int4_w4a8
=
True
,
per_channel_quant
=
True
,
activation
=
activation
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
global_num_experts
=
global_num_experts
,
w1_scale
=
(
layer
.
w13_weight_scale
),
w2_scale
=
(
layer
.
w2_weight_scale
),
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
use_nn_moe
,
)
vllm/model_executor/layers/quantization/utils/w4a8_utils.py
View file @
533d4c72
...
@@ -18,8 +18,8 @@ def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
...
@@ -18,8 +18,8 @@ def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
N
,
K_half
=
tensor_int8
.
shape
N
,
K_half
=
tensor_int8
.
shape
tensor_uint8
=
tensor_int8
.
to
(
torch
.
uint8
)
tensor_uint8
=
tensor_int8
.
to
(
torch
.
uint8
)
low
4
=
tensor_uint8
&
0x0F
high
4
=
tensor_uint8
&
0x0F
high
4
=
(
tensor_uint8
>>
4
)
&
0x0F
low
4
=
(
tensor_uint8
>>
4
)
&
0x0F
unpacked
=
torch
.
empty
((
N
,
K_half
*
2
),
dtype
=
torch
.
int32
,
device
=
tensor_int8
.
device
)
unpacked
=
torch
.
empty
((
N
,
K_half
*
2
),
dtype
=
torch
.
int32
,
device
=
tensor_int8
.
device
)
unpacked
[:,
0
::
2
]
=
low4
.
to
(
torch
.
int32
)
unpacked
[:,
0
::
2
]
=
low4
.
to
(
torch
.
int32
)
unpacked
[:,
1
::
2
]
=
high4
.
to
(
torch
.
int32
)
unpacked
[:,
1
::
2
]
=
high4
.
to
(
torch
.
int32
)
...
...
vllm/platforms/rocm.py
View file @
533d4c72
...
@@ -180,7 +180,7 @@ class RocmPlatform(Platform):
...
@@ -180,7 +180,7 @@ class RocmPlatform(Platform):
supported_quantization
:
list
[
str
]
=
[
supported_quantization
:
list
[
str
]
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"awq"
,
"gptq"
,
"fp8"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
,
"quark"
,
"ptpc_fp8"
,
"moe_wna16"
,
"blockwise_int8"
,
"slimquant_w4a8"
,
"awq_marlin"
"quark"
,
"ptpc_fp8"
,
"moe_wna16"
,
"blockwise_int8"
,
"slimquant_w4a8"
,
"awq_marlin"
,
"slimquant_w4a8_marlin"
]
]
@
classmethod
@
classmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment