Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9f771b3a
Unverified
Commit
9f771b3a
authored
Apr 24, 2026
by
Jinzhen Lin
Committed by
GitHub
Apr 24, 2026
Browse files
[Quantization] add humming quantization kernel (#34556)
parent
c9d3c6e6
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1948 additions
and
9 deletions
+1948
-9
vllm/config/model.py
vllm/config/model.py
+4
-0
vllm/envs.py
vllm/envs.py
+32
-0
vllm/model_executor/layers/fused_moe/fused_humming_moe.py
vllm/model_executor/layers/fused_moe/fused_humming_moe.py
+690
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+5
-1
vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py
vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py
+202
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+13
-6
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-0
vllm/model_executor/layers/quantization/humming.py
vllm/model_executor/layers/quantization/humming.py
+962
-0
vllm/model_executor/layers/quantization/utils/humming_moe_utils.py
...l_executor/layers/quantization/utils/humming_moe_utils.py
+35
-0
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+2
-2
No files found.
vllm/config/model.py
View file @
9f771b3a
...
@@ -953,8 +953,12 @@ class ModelConfig:
...
@@ -953,8 +953,12 @@ class ModelConfig:
"mxfp4"
,
"mxfp4"
,
"gpt_oss_mxfp4"
,
"gpt_oss_mxfp4"
,
"cpu_awq"
,
"cpu_awq"
,
"humming"
,
"gguf"
,
"gguf"
,
]
]
# if the user specifies humming, we should always use humming
if
self
.
quantization
==
"humming"
:
overrides
=
[
"humming"
]
+
overrides
quantization_methods
=
[
quantization_methods
=
[
q
for
q
in
supported_quantization
if
q
not
in
overrides
q
for
q
in
supported_quantization
if
q
not
in
overrides
]
]
...
...
vllm/envs.py
View file @
9f771b3a
...
@@ -152,6 +152,10 @@ if TYPE_CHECKING:
...
@@ -152,6 +152,10 @@ if TYPE_CHECKING:
VLLM_RAY_EXTRA_ENV_VARS_TO_COPY
:
str
=
""
VLLM_RAY_EXTRA_ENV_VARS_TO_COPY
:
str
=
""
VLLM_MARLIN_USE_ATOMIC_ADD
:
bool
=
False
VLLM_MARLIN_USE_ATOMIC_ADD
:
bool
=
False
VLLM_MARLIN_INPUT_DTYPE
:
Literal
[
"int8"
,
"fp8"
]
|
None
=
None
VLLM_MARLIN_INPUT_DTYPE
:
Literal
[
"int8"
,
"fp8"
]
|
None
=
None
VLLM_HUMMING_ONLINE_QUANT_CONFIG
:
dict
[
str
,
Any
]
|
None
=
None
VLLM_HUMMING_INPUT_QUANT_CONFIG
:
dict
[
str
,
Any
]
|
None
=
None
VLLM_HUMMING_USE_F16_ACCUM
:
bool
=
False
VLLM_HUMMING_MOE_GEMM_TYPE
:
Literal
[
"indexed"
,
"grouped"
,
"auto"
]
|
None
=
None
VLLM_MXFP4_USE_MARLIN
:
bool
|
None
=
None
VLLM_MXFP4_USE_MARLIN
:
bool
|
None
=
None
VLLM_DEEPEPLL_NVFP4_DISPATCH
:
bool
=
False
VLLM_DEEPEPLL_NVFP4_DISPATCH
:
bool
=
False
VLLM_V1_USE_OUTLINES_CACHE
:
bool
=
False
VLLM_V1_USE_OUTLINES_CACHE
:
bool
=
False
...
@@ -285,6 +289,15 @@ def maybe_convert_bool(value: str | None) -> bool | None:
...
@@ -285,6 +289,15 @@ def maybe_convert_bool(value: str | None) -> bool | None:
return
bool
(
int
(
value
))
return
bool
(
int
(
value
))
def
maybe_convert_json_str_or_file
(
value
:
str
|
None
)
->
dict
[
str
,
Any
]
|
None
:
if
value
is
None
:
return
None
if
os
.
path
.
exists
(
value
):
with
open
(
value
)
as
f
:
return
json
.
load
(
f
)
return
json
.
loads
(
value
)
def
disable_compile_cache
()
->
bool
:
def
disable_compile_cache
()
->
bool
:
return
bool
(
int
(
os
.
getenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"0"
)))
return
bool
(
int
(
os
.
getenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"0"
)))
...
@@ -1193,6 +1206,25 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1193,6 +1206,25 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MARLIN_INPUT_DTYPE"
:
env_with_choices
(
"VLLM_MARLIN_INPUT_DTYPE"
:
env_with_choices
(
"VLLM_MARLIN_INPUT_DTYPE"
,
None
,
[
"int8"
,
"fp8"
]
"VLLM_MARLIN_INPUT_DTYPE"
,
None
,
[
"int8"
,
"fp8"
]
),
),
# The online quantization dtype for humming kernel
"VLLM_HUMMING_ONLINE_QUANT_CONFIG"
:
lambda
:
maybe_convert_json_str_or_file
(
os
.
environ
.
get
(
"VLLM_HUMMING_ONLINE_QUANT_CONFIG"
,
None
)
),
# The activation dtype config for humming kernel
"VLLM_HUMMING_INPUT_QUANT_CONFIG"
:
lambda
:
maybe_convert_json_str_or_file
(
os
.
environ
.
get
(
"VLLM_HUMMING_INPUT_QUANT_CONFIG"
,
None
)
),
# Whether to use fp16 accumulator mma
"VLLM_HUMMING_USE_F16_ACCUM"
:
lambda
:
maybe_convert_bool
(
os
.
environ
.
get
(
"VLLM_HUMMING_USE_F16_ACCUM"
,
"0"
)
),
# Whether to use indexed gemm for humming moe
# if 1, force use indexed gemm
# if 0, force use grouped gemm
# if None, choose better gemm type automatically
"VLLM_HUMMING_MOE_GEMM_TYPE"
:
lambda
:
maybe_convert_bool
(
os
.
environ
.
get
(
"VLLM_HUMMING_MOE_GEMM_TYPE"
,
None
)
),
# Whether to use DeepEPLL kernels for NVFP4 quantization and dispatch method
# Whether to use DeepEPLL kernels for NVFP4 quantization and dispatch method
# only supported on Blackwell GPUs and with
# only supported on Blackwell GPUs and with
# https://github.com/deepseek-ai/DeepEP/pull/341
# https://github.com/deepseek-ai/DeepEP/pull/341
...
...
vllm/model_executor/layers/fused_moe/fused_humming_moe.py
0 → 100644
View file @
9f771b3a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused MoE utilities for Humming."""
import
json
import
math
from
typing
import
TYPE_CHECKING
,
Any
import
torch
from
humming
import
dtypes
from
humming.config
import
GemmType
as
HummingGemmType
from
humming.layer
import
HummingLayerMeta
,
HummingMethod
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
envs
from
vllm.forward_context
import
get_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
from
vllm.model_executor.layers.fused_moe.config
import
FusedMoEParallelConfig
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
,
)
from
vllm.model_executor.layers.fused_moe.moe_fused_mul_sum
import
moe_fused_mul_sum
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
moe_permute
,
moe_unpermute
,
)
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceDelegate
,
TopKWeightAndReduceNoOP
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
QuantKey
from
vllm.platforms
import
current_platform
from
vllm.v1.worker.workspace
import
current_workspace_manager
if
TYPE_CHECKING
:
from
vllm.model_executor.layers.quantization.humming
import
HummingMoEMethod
logger
=
init_logger
(
__name__
)
def
get_humming_moe_gemm_type
()
->
str
:
env_gemm_type
:
str
=
envs
.
VLLM_HUMMING_MOE_GEMM_TYPE
or
""
env_gemm_type
=
env_gemm_type
.
lower
()
if
env_gemm_type
in
[
"indexed"
,
"grouped"
]:
gemm_type
=
env_gemm_type
elif
current_platform
.
has_device_capability
(
90
):
# for device that supports TMA, use grouped gemm
gemm_type
=
"grouped"
else
:
gemm_type
=
"indexed"
logger
.
info_once
(
f
"Using
{
gemm_type
}
gemm for humming moe"
)
# noqa
return
gemm_type
class
HummingExpertsBase
(
mk
.
FusedMoEExpertsModular
):
def
__init__
(
self
,
layer
:
torch
.
nn
.
Module
,
quant_method
:
"HummingMoEMethod"
,
prepare_finalize
:
mk
.
FusedMoEPrepareAndFinalizeModular
|
None
=
None
,
):
self
.
layer
=
layer
self
.
num_experts
=
self
.
layer
.
num_experts
self
.
global_num_experts
=
self
.
layer
.
global_num_experts
self
.
init_humming_moe
()
if
prepare_finalize
is
not
None
:
max_num_tokens
:
int
|
None
=
None
num_dispatchers
:
int
|
None
=
None
if
self
.
is_batched
:
max_num_tokens
=
prepare_finalize
.
max_num_tokens_per_rank
()
num_dispatchers
=
prepare_finalize
.
num_dispatchers
()
assert
quant_method
.
moe_quant_config
is
not
None
super
().
__init__
(
moe_config
=
quant_method
.
moe
,
quant_config
=
quant_method
.
moe_quant_config
,
max_num_tokens
=
max_num_tokens
,
num_dispatchers
=
num_dispatchers
,
)
else
:
assert
not
self
.
is_batched
def
init_humming_moe
(
self
):
self
.
compute_config
=
{
"use_batch_invariant"
:
envs
.
VLLM_BATCH_INVARIANT
,
"use_f16_accum"
:
envs
.
VLLM_HUMMING_USE_F16_ACCUM
,
"gemm_type"
:
self
.
humming_gemm_type
.
value
,
}
self
.
w13_tuning_config
=
HummingMethod
.
get_default_tuning_configs
(
layer
=
self
.
layer
,
use_f16_accum
=
envs
.
VLLM_HUMMING_USE_F16_ACCUM
,
use_batch_invariant
=
envs
.
VLLM_BATCH_INVARIANT
,
gemm_type
=
self
.
humming_gemm_type
,
sublayer_name
=
"w13"
,
)
self
.
w2_tuning_config
=
HummingMethod
.
get_default_tuning_configs
(
layer
=
self
.
layer
,
use_f16_accum
=
envs
.
VLLM_HUMMING_USE_F16_ACCUM
,
use_batch_invariant
=
envs
.
VLLM_BATCH_INVARIANT
,
gemm_type
=
self
.
humming_gemm_type
,
sublayer_name
=
"w2"
,
)
self
.
compute_config_str
=
json
.
dumps
(
self
.
compute_config
)
self
.
w13_tuning_config_str
=
json
.
dumps
(
self
.
w13_tuning_config
)
self
.
w2_tuning_config_str
=
json
.
dumps
(
self
.
w2_tuning_config
)
def
get_global_valid_shape_m
(
self
,
topk_ids
:
torch
.
Tensor
):
num_tokens
=
topk_ids
.
size
(
0
)
ctx
=
get_forward_context
()
if
ctx
.
dp_metadata
is
not
None
:
num_tokens
=
ctx
.
dp_metadata
.
num_tokens_across_dp_cpu
.
sum
().
item
()
return
num_tokens
*
topk_ids
.
size
(
1
)
def
estimate_local_valid_shape_m
(
self
,
topk_ids
:
torch
.
Tensor
):
# estimate shape_m for kernel tuning
global_valid_shape_m
=
self
.
get_global_valid_shape_m
(
topk_ids
)
num_experts
=
self
.
num_experts
global_num_experts
=
self
.
global_num_experts
return
math
.
ceil
(
global_valid_shape_m
*
num_experts
/
global_num_experts
)
@
property
def
humming_gemm_type
(
self
)
->
HummingGemmType
:
raise
NotImplementedError
@
property
def
is_batched
(
self
)
->
bool
:
return
self
.
activation_format
()
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
@
staticmethod
def
_supports_quant_scheme
(
weight_key
:
QuantKey
|
None
,
activation_key
:
QuantKey
|
None
,
)
->
bool
:
return
True
def
supports_expert_map
(
self
)
->
bool
:
return
True
@
staticmethod
def
_supports_current_device
()
->
bool
:
platform
=
current_platform
return
platform
.
is_cuda
()
and
platform
.
has_device_capability
((
7
,
5
))
@
staticmethod
def
_supports_no_act_and_mul
()
->
bool
:
return
True
@
staticmethod
def
_supports_activation
(
activation
:
MoEActivation
)
->
bool
:
# Humming uses apply_moe_activation() callback for activation,
# so any activation supported there can be used here.
return
activation
in
[
MoEActivation
.
SILU
,
MoEActivation
.
GELU
,
MoEActivation
.
SWIGLUOAI
,
MoEActivation
.
SWIGLUSTEP
,
MoEActivation
.
SILU_NO_MUL
,
MoEActivation
.
GELU_NO_MUL
,
MoEActivation
.
RELU2_NO_MUL
,
]
@
staticmethod
def
_supports_parallel_config
(
moe_parallel_config
:
FusedMoEParallelConfig
)
->
bool
:
return
not
(
moe_parallel_config
.
use_fi_nvl_two_sided_kernels
or
moe_parallel_config
.
use_fi_nvl_one_sided_kernels
)
def
moe_problem_size
(
self
,
a1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
)
->
tuple
[
int
,
int
,
int
,
int
,
int
]:
meta1
:
HummingLayerMeta
=
self
.
layer
.
humming_metas
[
"w13"
]
meta2
:
HummingLayerMeta
=
self
.
layer
.
humming_metas
[
"w2"
]
assert
meta1
.
num_experts
==
meta2
.
num_experts
num_experts
=
meta1
.
num_experts
top_k
=
topk_ids
.
size
(
1
)
assert
w1
.
size
(
0
)
==
num_experts
assert
w2
.
size
(
0
)
==
num_experts
if
not
self
.
is_batched
:
num_tokens
=
a1
.
size
(
0
)
assert
topk_ids
.
size
(
0
)
==
num_tokens
else
:
assert
a1
.
dim
()
==
3
assert
a1
.
size
(
0
)
==
num_experts
num_tokens
=
a1
.
size
(
1
)
return
meta1
.
num_experts
,
num_tokens
,
meta1
.
shape_n
//
2
,
meta1
.
shape_k
,
top_k
def
get_buffer_metas
(
self
,
M
:
int
,
topk
:
int
,
activation
:
MoEActivation
):
num_experts
=
self
.
num_experts
N
=
self
.
layer
.
intermediate_size
K
=
self
.
layer
.
hidden_size
assert
isinstance
(
num_experts
,
int
)
assert
isinstance
(
N
,
int
)
assert
isinstance
(
K
,
int
)
# hidden_states
# (-> quanted_gate_up_input) (if not BF16/FP16 activation)
# -> gate_up_output
# -> activation_output
# (-> quanted_down_input) (if not BF16/FP16 activation)
# -> down_output
# (-> output) (if not is_batched)
# Neighboring nodes are required to utilize distinct workspaces.
# The output must be derived from workspace1.
output_shape
:
tuple
[
int
,
...]
if
self
.
is_batched
:
max_num_tokens
=
self
.
max_num_tokens
num_dispatchers
=
self
.
num_dispatchers
assert
max_num_tokens
is
not
None
and
num_dispatchers
is
not
None
input_shape_m
=
num_experts
*
max_num_tokens
real_shape_m
=
num_experts
*
max_num_tokens
*
num_dispatchers
output_shape
=
(
num_experts
,
max_num_tokens
*
num_dispatchers
,
K
)
else
:
input_shape_m
=
M
if
self
.
humming_gemm_type
!=
HummingGemmType
.
INDEXED
:
input_shape_m
=
M
*
topk
real_shape_m
=
M
*
topk
output_shape
=
(
M
,
K
)
down_input_size
=
N
if
activation
.
is_gated
else
(
N
*
2
)
a_dtype
=
self
.
layer
.
humming_metas
[
"w13"
].
a_dtype
c_dtype
=
self
.
layer
.
humming_metas
[
"w13"
].
c_dtype
num_bits
=
a_dtype
.
num_bits
torch_dtype_map
=
{
dtypes
.
float16
:
torch
.
float16
,
dtypes
.
bfloat16
:
torch
.
bfloat16
,
dtypes
.
float8e4m3
:
torch
.
float8_e4m3fn
,
dtypes
.
int8
:
torch
.
int8
,
dtypes
.
int4
:
torch
.
uint8
,
}
buffer_metas
=
{
"quanted_gate_up_input"
:
{
"shape"
:
(
input_shape_m
,
K
),
"dtype"
:
torch_dtype_map
[
a_dtype
],
},
"gate_up_output"
:
{
"shape"
:
(
real_shape_m
,
N
*
2
),
"dtype"
:
torch_dtype_map
[
c_dtype
],
},
"activation_output"
:
{
"shape"
:
(
real_shape_m
,
down_input_size
),
"dtype"
:
torch_dtype_map
[
c_dtype
],
},
"quanted_down_input"
:
{
"shape"
:
(
real_shape_m
,
down_input_size
),
"dtype"
:
torch_dtype_map
[
a_dtype
],
},
"down_output"
:
{
"shape"
:
output_shape
if
self
.
is_batched
else
(
real_shape_m
,
K
),
"dtype"
:
torch_dtype_map
[
c_dtype
],
},
"output"
:
{
"shape"
:
output_shape
,
"dtype"
:
torch_dtype_map
[
c_dtype
],
},
}
for
key
in
buffer_metas
:
meta
=
buffer_metas
[
key
]
if
"quanted"
in
key
and
a_dtype
.
num_bits
==
4
:
meta
[
"shape"
]
=
meta
[
"shape"
][:
-
1
]
+
(
meta
[
"shape"
][
-
1
]
//
2
,)
if
num_bits
==
16
:
required_buffers
=
[
"gate_up_output"
,
"activation_output"
,
"down_output"
]
else
:
required_buffers
=
[
"quanted_gate_up_input"
,
"gate_up_output"
,
"activation_output"
,
"quanted_down_input"
,
"down_output"
,
]
# batched moe use down_output as output
if
not
self
.
is_batched
:
required_buffers
.
append
(
"output"
)
return
buffer_metas
,
required_buffers
def
_workspace_shapes
(
self
,
M
:
int
,
topk
:
int
,
activation
:
MoEActivation
):
buffer_metas
,
required_buffers
=
self
.
get_buffer_metas
(
M
,
topk
,
activation
)
workspace1_nbytes
=
0
workspace2_nbytes
=
0
for
index
,
name
in
enumerate
(
required_buffers
[::
-
1
]):
buffer_meta
=
buffer_metas
[
name
]
nelement
=
math
.
prod
(
buffer_meta
[
"shape"
])
nbytes
=
nelement
*
buffer_meta
[
"dtype"
].
itemsize
if
index
%
2
==
0
:
workspace1_nbytes
=
max
(
workspace1_nbytes
,
nbytes
)
else
:
workspace2_nbytes
=
max
(
workspace2_nbytes
,
nbytes
)
output_key
=
"down_output"
if
self
.
is_batched
else
"output"
output_shape
=
buffer_metas
[
output_key
][
"shape"
]
return
(
workspace1_nbytes
//
2
,),
(
workspace2_nbytes
//
2
,),
output_shape
def
workspace_shapes
(
self
,
M
:
int
,
N
:
int
,
K
:
int
,
topk
:
int
,
global_num_experts
:
int
,
local_num_experts
:
int
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
activation
:
MoEActivation
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
return
self
.
_workspace_shapes
(
M
,
topk
,
activation
)
def
make_workspaces
(
self
,
M
:
int
,
topk
:
int
,
activation
:
MoEActivation
):
shapes
=
self
.
_workspace_shapes
(
M
,
topk
,
activation
)
workspace1_shape
,
workspace2_shape
,
output_shape
=
shapes
torch_dtype
=
self
.
layer
.
param_dtype
workspace1
,
workspace2
=
current_workspace_manager
().
get_simultaneous
(
(
workspace1_shape
,
torch_dtype
),
(
workspace2_shape
,
torch_dtype
),
)
output
=
_resize_cache
(
workspace1
,
output_shape
)
return
workspace1
,
workspace2
,
output
def
prepare_buffers
(
self
,
workspace1
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
M
:
int
,
topk
:
int
,
activation
:
MoEActivation
,
)
->
dict
[
str
,
torch
.
Tensor
]:
buffer_metas
,
required_buffers
=
self
.
get_buffer_metas
(
M
,
topk
,
activation
)
buffers
=
{}
for
index
,
name
in
enumerate
(
required_buffers
[::
-
1
]):
buffer_meta
=
buffer_metas
[
name
]
workspace
=
workspace1
if
index
%
2
==
0
else
workspace2
workspace
=
workspace
.
view
(
buffer_meta
[
"dtype"
])
buffers
[
name
]
=
_resize_cache
(
workspace
,
buffer_meta
[
"shape"
])
return
buffers
def
apply
(
self
,
output
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
activation
:
MoEActivation
,
global_num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
,
a1q_scale
:
torch
.
Tensor
|
None
,
a2_scale
:
torch
.
Tensor
|
None
,
workspace13
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
,
):
assert
not
apply_router_weight_on_input
self
.
main_apply
(
hidden_states
=
hidden_states
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
workspace1
=
workspace13
,
workspace2
=
workspace2
,
expert_tokens_meta
=
expert_tokens_meta
,
)
def
main_apply
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
workspace1
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
):
raise
NotImplementedError
class
HummingIndexedExperts
(
HummingExpertsBase
):
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
return
TopKWeightAndReduceNoOP
()
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
@
property
def
humming_gemm_type
(
self
)
->
HummingGemmType
:
return
HummingGemmType
.
INDEXED
def
prepare_humming_moe_kwargs
(
self
,
topk_ids
:
torch
.
Tensor
,
expert_map
:
torch
.
Tensor
|
None
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
)
->
tuple
[
dict
[
str
,
Any
],
dict
[
str
,
Any
]]:
valid_shape_m
=
self
.
estimate_local_valid_shape_m
(
topk_ids
)
for
min_shape_m
,
max_shape_m
,
config
in
self
.
w13_tuning_config
:
if
valid_shape_m
>
min_shape_m
and
valid_shape_m
<=
max_shape_m
:
moe_block_size
=
config
[
"block_shape"
][
0
]
break
else
:
raise
ValueError
(
f
"cannot found moe_block_size for shape
{
valid_shape_m
}
"
)
sorted_ids
,
expert_ids
,
num_tokens_padded
=
moe_align_block_size
(
topk_ids
=
topk_ids
,
block_size
=
moe_block_size
,
num_experts
=
self
.
global_num_experts
,
expert_map
=
expert_map
,
ignore_invalid_experts
=
True
,
)
moe_common_kwargs
=
{
"sorted_ids"
:
sorted_ids
,
"expert_ids"
:
expert_ids
,
"num_tokens_padded"
:
num_tokens_padded
,
"compute_config"
:
self
.
compute_config_str
,
"valid_shape_m"
:
valid_shape_m
,
}
top_k
=
topk_ids
.
size
(
1
)
moe_kwargs1
=
{
"top_k"
:
top_k
,
"tuning_config"
:
self
.
w13_tuning_config_str
}
moe_kwargs2
=
{
"top_k"
:
1
,
"tuning_config"
:
self
.
w2_tuning_config_str
}
moe_kwargs1
.
update
(
moe_common_kwargs
)
moe_kwargs2
.
update
(
moe_common_kwargs
)
return
moe_kwargs1
,
moe_kwargs2
def
main_apply
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
workspace1
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
):
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
size
(
-
1
))
buffers
=
self
.
prepare_buffers
(
workspace1
,
workspace2
,
topk_ids
.
size
(
0
),
topk_ids
.
size
(
1
),
self
.
layer
.
activation
,
)
moe_kwargs1
,
moe_kwargs2
=
self
.
prepare_humming_moe_kwargs
(
topk_ids
=
topk_ids
,
expert_map
=
self
.
layer
.
expert_map
,
expert_tokens_meta
=
expert_tokens_meta
,
)
inputs
,
input_scale
=
HummingMethod
.
may_quant_input
(
layer
=
self
.
layer
,
inputs
=
hidden_states
,
quanted_input
=
buffers
.
get
(
"quanted_gate_up_input"
,
None
),
sublayer_name
=
"w13"
,
)
HummingMethod
.
forward_layer
(
layer
=
self
.
layer
,
inputs
=
inputs
,
input_scale
=
input_scale
,
outputs
=
buffers
[
"gate_up_output"
],
sublayer_name
=
"w13"
,
**
moe_kwargs1
,
)
self
.
activation
(
activation
=
self
.
layer
.
activation
,
input
=
buffers
[
"gate_up_output"
],
output
=
buffers
[
"activation_output"
],
)
inputs
,
input_scale
=
HummingMethod
.
may_quant_input
(
layer
=
self
.
layer
,
inputs
=
buffers
[
"activation_output"
],
quanted_input
=
buffers
.
get
(
"quanted_down_input"
,
None
),
sublayer_name
=
"w2"
,
)
HummingMethod
.
forward_layer
(
layer
=
self
.
layer
,
inputs
=
inputs
,
input_scale
=
input_scale
,
outputs
=
buffers
[
"down_output"
].
view
(
-
1
,
hidden_states
.
size
(
-
1
)),
sublayer_name
=
"w2"
,
**
moe_kwargs2
,
)
moe_fused_mul_sum
(
inputs
=
buffers
[
"down_output"
].
view
(
*
topk_ids
.
shape
,
-
1
),
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
expert_map
=
self
.
layer
.
expert_map
,
outputs
=
buffers
[
"output"
],
)
class
HummingGroupedExperts
(
HummingExpertsBase
):
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
return
TopKWeightAndReduceNoOP
()
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
Standard
@
property
def
humming_gemm_type
(
self
)
->
HummingGemmType
:
return
HummingGemmType
.
GROUPED_CONTIGUOUS
def
main_apply
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
workspace1
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
):
valid_shape_m
=
self
.
estimate_local_valid_shape_m
(
topk_ids
)
buffers
=
self
.
prepare_buffers
(
workspace1
,
workspace2
,
topk_ids
.
size
(
0
),
topk_ids
.
size
(
1
),
self
.
layer
.
activation
,
)
hidden_states
,
_
,
expert_first_token_offset
,
inv_perm
,
_
=
moe_permute
(
hidden_states
=
hidden_states
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
n_expert
=
self
.
global_num_experts
,
n_local_expert
=
self
.
num_experts
,
expert_map
=
self
.
layer
.
expert_map
,
)
inputs
,
input_scale
=
HummingMethod
.
may_quant_input
(
layer
=
self
.
layer
,
inputs
=
hidden_states
,
quanted_input
=
buffers
.
get
(
"quanted_gate_up_input"
,
None
),
sublayer_name
=
"w13"
,
)
HummingMethod
.
forward_layer
(
layer
=
self
.
layer
,
inputs
=
inputs
,
input_scale
=
input_scale
,
outputs
=
buffers
[
"gate_up_output"
],
valid_shape_m
=
valid_shape_m
,
expert_layout
=
expert_first_token_offset
,
compute_config
=
self
.
compute_config_str
,
tuning_config
=
self
.
w13_tuning_config_str
,
sublayer_name
=
"w13"
,
)
self
.
activation
(
activation
=
self
.
layer
.
activation
,
input
=
buffers
[
"gate_up_output"
],
output
=
buffers
[
"activation_output"
],
)
inputs
,
input_scale
=
HummingMethod
.
may_quant_input
(
layer
=
self
.
layer
,
inputs
=
buffers
[
"activation_output"
],
quanted_input
=
buffers
.
get
(
"quanted_down_input"
,
None
),
sublayer_name
=
"w2"
,
)
HummingMethod
.
forward_layer
(
layer
=
self
.
layer
,
inputs
=
inputs
,
input_scale
=
input_scale
,
outputs
=
buffers
[
"down_output"
],
valid_shape_m
=
valid_shape_m
,
expert_layout
=
expert_first_token_offset
,
compute_config
=
self
.
compute_config_str
,
tuning_config
=
self
.
w2_tuning_config_str
,
sublayer_name
=
"w2"
,
)
moe_unpermute
(
out
=
buffers
[
"output"
],
permuted_hidden_states
=
buffers
[
"down_output"
].
view
(
*
topk_ids
.
shape
,
-
1
),
topk_weights
=
topk_weights
,
inv_permuted_idx
=
inv_perm
,
expert_first_token_offset
=
expert_first_token_offset
,
)
class
BatchedHummingGroupedExperts
(
HummingExpertsBase
):
def
finalize_weight_and_reduce_impl
(
self
)
->
mk
.
TopKWeightAndReduce
:
return
TopKWeightAndReduceDelegate
()
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
return
mk
.
FusedMoEActivationFormat
.
BatchedExperts
@
property
def
humming_gemm_type
(
self
)
->
HummingGemmType
:
return
HummingGemmType
.
GROUPED_MASKED
def
main_apply
(
self
,
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
workspace1
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
):
assert
expert_tokens_meta
is
not
None
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
size
(
-
1
))
valid_shape_m
=
self
.
estimate_local_valid_shape_m
(
topk_ids
)
expert_num_tokens
=
expert_tokens_meta
.
expert_num_tokens
buffers
=
self
.
prepare_buffers
(
workspace1
,
workspace2
,
topk_ids
.
size
(
0
),
topk_ids
.
size
(
1
),
self
.
layer
.
activation
,
)
inputs
,
input_scale
=
HummingMethod
.
may_quant_input
(
layer
=
self
.
layer
,
inputs
=
hidden_states
,
quanted_input
=
buffers
.
get
(
"quanted_gate_up_input"
,
None
),
sublayer_name
=
"w13"
,
)
HummingMethod
.
forward_layer
(
layer
=
self
.
layer
,
inputs
=
inputs
,
input_scale
=
input_scale
,
outputs
=
buffers
[
"gate_up_output"
],
valid_shape_m
=
valid_shape_m
,
expert_layout
=
expert_num_tokens
,
compute_config
=
self
.
compute_config_str
,
tuning_config
=
self
.
w13_tuning_config_str
,
sublayer_name
=
"w13"
,
)
self
.
activation
(
activation
=
self
.
layer
.
activation
,
input
=
buffers
[
"gate_up_output"
],
output
=
buffers
[
"activation_output"
],
)
inputs
,
input_scale
=
HummingMethod
.
may_quant_input
(
layer
=
self
.
layer
,
inputs
=
buffers
[
"activation_output"
],
quanted_input
=
buffers
.
get
(
"quanted_down_input"
,
None
),
sublayer_name
=
"w2"
,
)
HummingMethod
.
forward_layer
(
layer
=
self
.
layer
,
inputs
=
inputs
,
input_scale
=
input_scale
,
outputs
=
buffers
[
"down_output"
].
view
(
-
1
,
hidden_states
.
size
(
-
1
)),
valid_shape_m
=
valid_shape_m
,
expert_layout
=
expert_num_tokens
,
compute_config
=
self
.
compute_config_str
,
tuning_config
=
self
.
w2_tuning_config_str
,
sublayer_name
=
"w2"
,
)
vllm/model_executor/layers/fused_moe/layer.py
View file @
9f771b3a
...
@@ -1097,7 +1097,11 @@ class FusedMoE(PluggableLayer):
...
@@ -1097,7 +1097,11 @@ class FusedMoE(PluggableLayer):
expert_id
:
int
,
expert_id
:
int
,
return_success
:
bool
=
False
,
return_success
:
bool
=
False
,
)
->
bool
|
None
:
)
->
bool
|
None
:
if
self
.
quant_config
and
self
.
quant_config
.
get_name
()
==
"gpt_oss_mxfp4"
:
quant_config_name
=
self
.
quant_config
and
self
.
quant_config
.
get_name
()
if
quant_config_name
==
"humming"
:
assert
hasattr
(
self
.
quant_method
,
"weight_schema"
)
quant_config_name
=
self
.
quant_method
.
weight_schema
.
quant_method
if
quant_config_name
==
"gpt_oss_mxfp4"
:
# (FIXME) for gpt-oss all experts are combined
# (FIXME) for gpt-oss all experts are combined
if
"bias"
in
weight_name
:
if
"bias"
in
weight_name
:
dim1
=
loaded_weight
.
shape
[
1
]
dim1
=
loaded_weight
.
shape
[
1
]
...
...
vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py
0 → 100644
View file @
9f771b3a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
torch._subclasses.fake_tensor
import
FakeTensor
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
@
triton
.
jit
def
moe_fused_mul_sum_kernel
(
inputs_ptr
,
topk_weights_ptr
,
outputs_ptr
,
top_ids_ptr
,
expert_map_ptr
,
num_tokens
,
stride_m
,
has_expert_map
:
tl
.
constexpr
,
top_k
:
tl
.
constexpr
,
size
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_K
:
tl
.
constexpr
,
):
pid_k
=
tl
.
program_id
(
0
)
pid_m
=
tl
.
program_id
(
1
)
offs_m
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_k
=
pid_k
*
BLOCK_K
+
tl
.
arange
(
0
,
BLOCK_K
)
m_mask
=
offs_m
<
num_tokens
k_mask
=
offs_k
<
size
mask
=
m_mask
[:,
None
]
&
k_mask
[
None
,
:]
a_base
=
inputs_ptr
+
(
offs_m
*
stride_m
)[:,
None
]
+
offs_k
[
None
,
:]
b_base
=
topk_weights_ptr
+
offs_m
*
top_k
acc
=
tl
.
zeros
((
BLOCK_M
,
BLOCK_K
),
dtype
=
tl
.
float32
)
for
n
in
tl
.
static_range
(
top_k
):
b_val
=
tl
.
load
(
b_base
+
n
,
mask
=
m_mask
,
other
=
0.0
).
to
(
tl
.
float32
)
if
has_expert_map
:
id_val
=
tl
.
load
(
top_ids_ptr
+
offs_m
*
top_k
+
n
,
mask
=
m_mask
,
other
=
0
)
expert_mask
=
tl
.
load
(
expert_map_ptr
+
id_val
)
>=
0
a_vec
=
tl
.
load
(
a_base
+
n
*
size
,
mask
=
mask
&
expert_mask
[:,
None
],
other
=
0.0
,
).
to
(
tl
.
float32
)
else
:
a_vec
=
tl
.
load
(
a_base
+
n
*
size
,
mask
=
mask
,
other
=
0.0
,
).
to
(
tl
.
float32
)
acc
+=
a_vec
*
b_val
[:,
None
]
out_ptrs
=
outputs_ptr
+
(
offs_m
*
size
)[:,
None
]
+
offs_k
[
None
,
:]
tl
.
store
(
out_ptrs
,
acc
.
to
(
outputs_ptr
.
dtype
.
element_ty
),
mask
=
mask
,
)
def
_heuristic_config
(
num_tokens
:
int
,
top_k
:
int
,
size
:
int
,
element_size
:
int
,
):
is_fp32
=
element_size
>
2
is_sm90_plus
=
current_platform
.
has_device_capability
(
90
)
is_sm80_before
=
not
current_platform
.
has_device_capability
(
80
)
if
current_platform
.
has_device_capability
(
90
):
# SM90/SM100+: prefer small tiles + many CTAs.
if
is_fp32
:
BLOCK_M
=
1
if
num_tokens
<=
4
else
2
else
:
if
num_tokens
<=
4
:
BLOCK_M
=
1
elif
num_tokens
<=
128
:
BLOCK_M
=
2
else
:
BLOCK_M
=
4
elif
is_fp32
:
if
num_tokens
<=
4
:
BLOCK_M
=
1
elif
num_tokens
<=
32
:
BLOCK_M
=
2
elif
num_tokens
<=
128
:
BLOCK_M
=
4
else
:
BLOCK_M
=
4
else
:
if
num_tokens
<=
4
:
BLOCK_M
=
1
elif
num_tokens
<=
32
:
BLOCK_M
=
2
elif
num_tokens
<=
128
:
BLOCK_M
=
4
elif
num_tokens
<=
1024
:
BLOCK_M
=
16
else
:
BLOCK_M
=
8
if
is_fp32
:
max_block_k
=
256
elif
is_sm80_before
or
is_sm90_plus
:
max_block_k
=
512
else
:
max_block_k
=
1024
BLOCK_K
=
min
(
triton
.
next_power_of_2
(
size
),
max_block_k
)
BLOCK_K
=
max
(
BLOCK_K
,
256
)
total
=
BLOCK_M
*
BLOCK_K
if
is_fp32
:
num_warps
=
max
(
8
,
min
(
16
,
total
//
64
))
else
:
num_warps
=
max
(
4
,
min
(
16
,
total
//
256
))
if
is_sm80_before
:
num_warps
=
min
(
num_warps
,
8
)
num_stages
=
2
elif
is_sm90_plus
:
num_warps
=
min
(
num_warps
,
8
)
num_stages
=
4
if
total
<=
2048
else
2
else
:
num_stages
=
4
if
total
<=
2048
else
2
return
BLOCK_M
,
BLOCK_K
,
num_warps
,
num_stages
def
moe_fused_mul_sum
(
inputs
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
outputs
:
torch
.
Tensor
|
None
=
None
,
topk_ids
:
torch
.
Tensor
|
None
=
None
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""
Fused kernel for MoE (Mixture of Experts) to perform weighted summation
of expert outputs.
Args:
inputs: The output from experts.
Shape: (num_tokens, top_k, hidden_size).
topk_weights: The weights assigned to each expert for each token.
Shape: (num_tokens, top_k).
outputs: Optional pre-allocated output tensor.
Shape: (num_tokens, hidden_size).
topk_ids: Optional indices of the top-k experts. Used when
`expert_map` is provided. Shape: (num_tokens, top_k).
expert_map: Optional mapping for Expert Parallelism. A value < 0
indicates an invalid token/expert pair that will be skipped.
Returns:
The fused weighted sum of expert outputs.
Shape: (num_tokens, hidden_size).
"""
assert
inputs
.
ndim
==
3
assert
topk_weights
.
ndim
==
2
assert
inputs
.
is_contiguous
()
assert
topk_weights
.
is_contiguous
()
assert
inputs
.
dtype
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
)
assert
topk_weights
.
dtype
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
)
num_tokens
,
top_k
,
size
=
inputs
.
shape
output_shape
=
(
num_tokens
,
size
)
if
outputs
is
None
:
outputs
=
torch
.
empty
(
output_shape
,
dtype
=
inputs
.
dtype
,
device
=
inputs
.
device
)
assert
outputs
.
shape
==
output_shape
assert
topk_weights
.
shape
==
(
num_tokens
,
top_k
)
if
not
isinstance
(
inputs
,
FakeTensor
):
BLOCK_M
,
BLOCK_K
,
num_warps
,
num_stages
=
_heuristic_config
(
num_tokens
,
top_k
,
size
,
inputs
.
element_size
(),
)
grid
=
(
triton
.
cdiv
(
size
,
BLOCK_K
),
triton
.
cdiv
(
num_tokens
,
BLOCK_M
))
moe_fused_mul_sum_kernel
[
grid
](
inputs
,
topk_weights
,
outputs
,
topk_ids
,
expert_map
,
num_tokens
,
top_k
*
size
,
expert_map
is
not
None
,
top_k
,
size
,
BLOCK_M
,
BLOCK_K
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
return
outputs
vllm/model_executor/layers/linear.py
View file @
9f771b3a
...
@@ -60,6 +60,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
...
@@ -60,6 +60,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"ModelOptFp8PbWoLinearMethod"
,
"ModelOptFp8PbWoLinearMethod"
,
"QuarkLinearMethod"
,
"QuarkLinearMethod"
,
"ModelOptNvFp4LinearMethod"
,
"ModelOptNvFp4LinearMethod"
,
"HummingLinearMethod"
,
]
]
...
@@ -245,6 +246,7 @@ class LinearBase(PluggableLayer):
...
@@ -245,6 +246,7 @@ class LinearBase(PluggableLayer):
self
,
self
,
input_size
:
int
,
input_size
:
int
,
output_size
:
int
,
output_size
:
int
,
bias
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
torch
.
dtype
|
None
=
None
,
params_dtype
:
torch
.
dtype
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
...
@@ -258,6 +260,7 @@ class LinearBase(PluggableLayer):
...
@@ -258,6 +260,7 @@ class LinearBase(PluggableLayer):
# Keep input parameters
# Keep input parameters
self
.
input_size
=
input_size
self
.
input_size
=
input_size
self
.
output_size
=
output_size
self
.
output_size
=
output_size
self
.
has_bias
=
bias
self
.
skip_bias_add
=
skip_bias_add
self
.
skip_bias_add
=
skip_bias_add
if
params_dtype
is
None
:
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
params_dtype
=
torch
.
get_default_dtype
()
...
@@ -323,6 +326,7 @@ class ReplicatedLinear(LinearBase):
...
@@ -323,6 +326,7 @@ class ReplicatedLinear(LinearBase):
super
().
__init__
(
super
().
__init__
(
input_size
,
input_size
,
output_size
,
output_size
,
bias
,
skip_bias_add
,
skip_bias_add
,
params_dtype
,
params_dtype
,
quant_config
,
quant_config
,
...
@@ -458,6 +462,7 @@ class ColumnParallelLinear(LinearBase):
...
@@ -458,6 +462,7 @@ class ColumnParallelLinear(LinearBase):
super
().
__init__
(
super
().
__init__
(
input_size
,
input_size
,
output_size
,
output_size
,
bias
,
skip_bias_add
,
skip_bias_add
,
params_dtype
,
params_dtype
,
quant_config
,
quant_config
,
...
@@ -483,6 +488,7 @@ class ColumnParallelLinear(LinearBase):
...
@@ -483,6 +488,7 @@ class ColumnParallelLinear(LinearBase):
else
self
.
weight_loader
else
self
.
weight_loader
),
),
)
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
dtype
=
params_dtype
)
torch
.
empty
(
self
.
output_size_per_partition
,
dtype
=
params_dtype
)
...
@@ -817,8 +823,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -817,8 +823,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# for the packing.
# for the packing.
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
if
packed_dim
==
output_dim
:
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
packed_factor
shard_size
=
round
(
shard_size
//
param
.
packed_factor
)
shard_offset
=
shard_offset
//
param
.
packed_factor
shard_offset
=
round
(
shard_offset
//
param
.
packed_factor
)
# Special case for Marlin.
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
param
,
shard_size
,
shard_offset
...
@@ -1252,8 +1258,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -1252,8 +1258,8 @@ class QKVParallelLinear(ColumnParallelLinear):
)
)
if
packed_dim
==
output_dim
:
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
packed_factor
shard_size
=
round
(
shard_size
//
param
.
packed_factor
)
shard_offset
=
shard_offset
//
param
.
packed_factor
shard_offset
=
round
(
shard_offset
//
param
.
packed_factor
)
# Special case for Marlin.
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
...
@@ -1315,8 +1321,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -1315,8 +1321,8 @@ class QKVParallelLinear(ColumnParallelLinear):
# for the packing.
# for the packing.
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
if
packed_dim
==
output_dim
:
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
packed_factor
shard_size
=
round
(
shard_size
//
param
.
packed_factor
)
shard_offset
=
shard_offset
//
param
.
packed_factor
shard_offset
=
round
(
shard_offset
//
param
.
packed_factor
)
# Special case for Marlin.
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
...
@@ -1440,6 +1446,7 @@ class RowParallelLinear(LinearBase):
...
@@ -1440,6 +1446,7 @@ class RowParallelLinear(LinearBase):
super
().
__init__
(
super
().
__init__
(
input_size
,
input_size
,
output_size
,
output_size
,
bias
,
skip_bias_add
,
skip_bias_add
,
params_dtype
,
params_dtype
,
quant_config
,
quant_config
,
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
9f771b3a
...
@@ -22,6 +22,7 @@ QuantizationMethods = Literal[
...
@@ -22,6 +22,7 @@ QuantizationMethods = Literal[
"gptq_marlin"
,
"gptq_marlin"
,
"awq_marlin"
,
"awq_marlin"
,
"gptq"
,
"gptq"
,
"humming"
,
"compressed-tensors"
,
"compressed-tensors"
,
"bitsandbytes"
,
"bitsandbytes"
,
"experts_int8"
,
"experts_int8"
,
...
@@ -126,6 +127,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -126,6 +127,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from
.gguf
import
GGUFConfig
from
.gguf
import
GGUFConfig
from
.gptq
import
GPTQConfig
from
.gptq
import
GPTQConfig
from
.gptq_marlin
import
GPTQMarlinConfig
from
.gptq_marlin
import
GPTQMarlinConfig
from
.humming
import
HummingConfig
from
.inc
import
INCConfig
from
.inc
import
INCConfig
from
.modelopt
import
(
from
.modelopt
import
(
ModelOptFp8Config
,
ModelOptFp8Config
,
...
@@ -162,6 +164,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -162,6 +164,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"mxfp4"
:
Mxfp4Config
,
"mxfp4"
:
Mxfp4Config
,
"gpt_oss_mxfp4"
:
GptOssMxfp4Config
,
"gpt_oss_mxfp4"
:
GptOssMxfp4Config
,
"cpu_awq"
:
CPUAWQConfig
,
"cpu_awq"
:
CPUAWQConfig
,
"humming"
:
HummingConfig
,
"online"
:
OnlineQuantizationConfig
,
"online"
:
OnlineQuantizationConfig
,
}
}
...
...
vllm/model_executor/layers/quantization/humming.py
0 → 100644
View file @
9f771b3a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
json
import
math
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
,
Any
import
regex
as
re
import
torch
from
vllm
import
envs
from
vllm.model_executor.layers.fused_moe.activation
import
MoEActivation
from
vllm.model_executor.layers.fused_moe.config
import
(
FusedMoEConfig
,
FusedMoEQuantConfig
,
FusedMoEQuantDesc
,
)
from
vllm.model_executor.layers.fused_moe.layer
import
(
FusedMoE
,
FusedMoEMethodBase
,
)
from
vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method
import
(
UnquantizedFusedMoEMethod
,
)
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
,
)
from
vllm.model_executor.layers.quantization
import
QuantizationMethods
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
GroupShape
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.parameter
import
(
BasevLLMParameter
,
BlockQuantScaleParameter
,
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
ModelWeightParameter
,
PackedvLLMParameter
,
PerTensorScaleParameter
,
RowvLLMParameter
,
)
from
vllm.model_executor.utils
import
set_weight_attrs
if
TYPE_CHECKING
:
from
vllm.model_executor.models.utils
import
WeightsMapper
try
:
from
humming.dtypes
import
DataType
from
humming.layer
import
HummingMethod
from
humming.schema
import
(
BaseInputSchema
,
BaseWeightSchema
,
HummingInputSchema
,
HummingWeightSchema
,
)
from
humming.utils.weight
import
quantize_weight
from
vllm.model_executor.layers.fused_moe.fused_humming_moe
import
(
BatchedHummingGroupedExperts
,
HummingGroupedExperts
,
HummingIndexedExperts
,
get_humming_moe_gemm_type
,
)
except
ModuleNotFoundError
:
HummingMethod
=
None
def
assert_humming_available
():
assert
HummingMethod
is
not
None
,
(
"humming is not available, please run "
"'pip install git+https://github.com/inclusionAI/humming' to install it."
)
def
prepare_padded_shape
(
shape
,
x
):
padded_shape
=
math
.
ceil
(
shape
/
x
)
*
x
return
padded_shape
,
padded_shape
-
shape
def
prepare_param
(
tensor
,
name
,
extra_attrs
):
extra_attrs
=
extra_attrs
.
copy
()
scale_type
=
extra_attrs
.
pop
(
"scale_type"
,
None
)
param_cls_name_map
=
{
"block"
:
BlockQuantScaleParameter
,
"tensor"
:
PerTensorScaleParameter
,
"group"
:
GroupQuantScaleParameter
,
"channel"
:
ChannelQuantScaleParameter
,
"input_scale"
:
PerTensorScaleParameter
,
}
param_cls
:
type
[
BasevLLMParameter
]
if
"packed_dim"
in
extra_attrs
:
param_cls
=
PackedvLLMParameter
elif
scale_type
in
param_cls_name_map
:
param_cls
=
param_cls_name_map
[
scale_type
]
elif
"output_dim"
in
extra_attrs
and
"input_dim"
in
extra_attrs
:
param_cls
=
ModelWeightParameter
elif
"input_dim"
in
extra_attrs
:
param_cls
=
RowvLLMParameter
elif
"output_dim"
in
extra_attrs
:
param_cls
=
ChannelQuantScaleParameter
else
:
param_cls
=
BasevLLMParameter
kwargs_keys
=
[
"input_dim"
,
"output_dim"
,
"packed_dim"
,
"packed_factor"
,
"weight_loader"
,
]
cls_kwargs
=
{}
for
key
in
extra_attrs
.
copy
():
if
key
in
kwargs_keys
:
cls_kwargs
[
key
]
=
extra_attrs
.
pop
(
key
)
param
=
param_cls
(
data
=
tensor
,
**
cls_kwargs
)
set_weight_attrs
(
param
,
extra_attrs
)
param
.
param_name
=
name
param
.
ignore_warning
=
True
if
scale_type
in
[
"tensor"
,
"input_scale"
]:
param
.
needs_scalar_to_array
=
True
return
param
def
prepare_moe_param
(
tensor
,
name
,
extra_attrs
):
param
=
torch
.
nn
.
Parameter
(
tensor
,
requires_grad
=
False
)
if
"scale_type"
in
extra_attrs
:
extra_attrs
[
"quant_method"
]
=
extra_attrs
[
"scale_type"
]
if
"input_dim"
in
extra_attrs
and
"output_dim"
in
extra_attrs
:
input_dim
=
extra_attrs
[
"input_dim"
]
output_dim
=
extra_attrs
[
"output_dim"
]
extra_attrs
[
"is_transposed"
]
=
input_dim
<
output_dim
set_weight_attrs
(
param
,
extra_attrs
)
param
.
param_name
=
name
return
param
def
may_pad_loaded_weight
(
param
,
loaded_weight
):
pad_shape
=
getattr
(
param
,
"pad_shape"
,
None
)
if
pad_shape
is
None
:
return
loaded_weight
value
=
1
if
loaded_weight
.
dtype
==
torch
.
float8_e8m0fnu
else
0
padding
=
[]
for
x
in
pad_shape
[::
-
1
][:
loaded_weight
.
ndim
]:
padding
+=
[
0
,
x
]
loaded_weight
=
torch
.
nn
.
functional
.
pad
(
input
=
loaded_weight
,
pad
=
padding
,
value
=
value
,
)
return
loaded_weight
def
compressed_tensors_get_config
(
config
:
dict
[
str
,
Any
],
key
:
str
):
assert
key
in
[
"weights"
,
"input_activations"
]
target_group_config
=
None
for
group_config
in
config
[
"config_groups"
].
values
():
if
"Linear"
in
group_config
[
"targets"
]:
if
"weights"
not
in
group_config
:
return
None
if
key
not
in
group_config
or
group_config
[
key
]
is
None
:
return
None
target_group_config
=
group_config
[
key
].
copy
()
break
if
target_group_config
is
None
:
return
None
target_group_config
[
"quant_method"
]
=
config
[
"quant_method"
]
if
config
[
"quant_method"
]
==
"compressed-tensors"
:
target_group_config
[
"format"
]
=
config
[
"format"
]
elif
config
[
"quant_method"
]
==
"modelopt"
:
target_group_config
[
"quant_algo"
]
=
config
[
"quant_algo"
]
return
target_group_config
class
HummingConfig
(
QuantizationConfig
):
packed_modules_mapping
=
{}
def
__init__
(
self
,
full_config
:
dict
[
str
,
Any
]
|
None
=
None
):
assert_humming_available
()
self
.
full_config
:
dict
[
str
,
Any
]
=
full_config
or
{}
@
classmethod
def
get_name
(
cls
)
->
QuantizationMethods
:
return
"humming"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
list
[
torch
.
dtype
]:
return
[
torch
.
bfloat16
,
torch
.
half
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
75
@
classmethod
def
get_config_filenames
(
cls
)
->
list
[
str
]:
return
[]
@
classmethod
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"HummingConfig"
:
return
cls
(
full_config
=
config
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
,
hf_config
=
None
)
->
QuantizationMethods
|
None
:
return
"humming"
if
user_quant
==
"humming"
else
None
def
apply_vllm_mapper
(
self
,
hf_to_vllm_mapper
:
"WeightsMapper"
):
self
.
hf_to_vllm_mapper
=
hf_to_vllm_mapper
def
is_layer_skipped
(
self
,
config
:
dict
[
str
,
Any
],
prefix
:
str
):
keys
=
[
"ignored_layers"
,
"ignore"
,
"modules_to_not_convert"
]
ignored_layers
=
self
.
get_from_keys_or
(
config
,
keys
,
[])
or
[]
if
hasattr
(
self
,
"hf_to_vllm_mapper"
):
ignored_layers
=
self
.
hf_to_vllm_mapper
.
apply_list
(
ignored_layers
)
if
any
(
module_name
in
prefix
for
module_name
in
ignored_layers
):
return
True
if
"lm_head"
in
prefix
:
return
True
for
regex
in
config
.
get
(
"dynamic"
,
{}):
if
regex
[:
1
]
!=
"-"
:
continue
if
re
.
match
(
regex
[
2
:],
prefix
):
return
True
return
False
def
get_layer_weight_schema
(
self
,
config
:
dict
[
str
,
Any
],
prefix
:
str
):
if
self
.
is_layer_skipped
(
config
,
prefix
):
return
None
if
config
[
"quant_method"
]
in
[
"compressed-tensors"
,
"modelopt"
]:
group_config
=
compressed_tensors_get_config
(
config
,
"weights"
)
if
group_config
is
None
:
return
None
config
=
group_config
layer_config
=
config
layer_dynamic
=
config
.
get
(
"dynamic"
,
{})
if
not
isinstance
(
layer_dynamic
,
dict
):
layer_dynamic
=
{}
for
regex
,
override_config
in
layer_dynamic
.
items
():
if
regex
[:
1
]
!=
"+"
:
continue
if
re
.
match
(
regex
[
2
:],
prefix
):
layer_config
=
config
.
copy
()
layer_config
.
update
(
override_config
)
break
if
"quant_method"
in
layer_config
:
return
BaseWeightSchema
.
from_config
(
layer_config
)
return
None
def
get_layer_input_schema
(
self
,
config
:
dict
[
str
,
Any
],
prefix
:
str
):
if
self
.
is_layer_skipped
(
config
,
prefix
):
return
None
if
config
[
"quant_method"
]
in
[
"compressed-tensors"
,
"modelopt"
]:
group_config
=
compressed_tensors_get_config
(
config
,
"input_activations"
)
if
group_config
is
None
:
return
None
config
=
group_config
if
config
.
get
(
"quant_method"
,
None
)
in
BaseInputSchema
.
INPUT_SCHEMA_MAP
:
return
BaseInputSchema
.
from_config
(
config
)
return
None
def
get_quant_config_for_layer
(
self
,
prefix
:
str
,
layer_type
:
str
)
->
"HummingLayerQuantizationConfig | None"
:
weight_schema
:
BaseWeightSchema
|
None
=
None
force_weight_schema
:
HummingWeightSchema
|
None
=
None
if
self
.
full_config
:
weight_schema
=
self
.
get_layer_weight_schema
(
self
.
full_config
,
prefix
)
is_online_quant
=
False
online_quant_config
=
envs
.
VLLM_HUMMING_ONLINE_QUANT_CONFIG
or
{}
if
not
self
.
full_config
or
online_quant_config
.
get
(
"force_requant"
,
False
):
online_quant_config
[
"quant_method"
]
=
"humming"
schema
=
self
.
get_layer_weight_schema
(
online_quant_config
,
prefix
)
if
not
self
.
full_config
:
weight_schema
=
schema
is_online_quant
=
True
else
:
force_weight_schema
=
schema
if
weight_schema
is
not
None
:
if
weight_schema
.
quant_method
==
"gpt_oss_mxfp4"
and
layer_type
!=
"moe"
:
return
None
input_schema
=
None
force_input_schema
=
None
if
self
.
full_config
:
input_schema
=
self
.
get_layer_input_schema
(
self
.
full_config
,
prefix
)
if
envs
.
VLLM_HUMMING_INPUT_QUANT_CONFIG
:
quant_config
=
envs
.
VLLM_HUMMING_INPUT_QUANT_CONFIG
.
copy
()
quant_config
[
"quant_method"
]
=
"humming"
force_input_schema
=
self
.
get_layer_input_schema
(
quant_config
,
prefix
)
if
input_schema
is
None
:
input_schema
=
force_input_schema
if
force_weight_schema
is
not
None
and
force_input_schema
is
None
:
force_input_schema
=
HummingInputSchema
()
return
HummingLayerQuantizationConfig
(
weight_schema
=
weight_schema
,
input_schema
=
input_schema
,
force_weight_schema
=
force_weight_schema
,
force_input_schema
=
force_input_schema
,
is_online_quant
=
is_online_quant
,
)
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
"QuantizeMethodBase | None"
:
layer_type
=
"other"
if
isinstance
(
layer
,
FusedMoE
):
layer_type
=
"moe"
elif
isinstance
(
layer
,
LinearBase
):
layer_type
=
"linear"
# TODO: remove this after humming moe backend is ready
quant_method
=
self
.
full_config
.
get
(
"quant_method"
,
None
)
moe_activation
=
getattr
(
layer
,
"activation"
,
None
)
if
quant_method
==
"mxfp4"
and
moe_activation
==
MoEActivation
.
SWIGLUOAI
:
self
.
full_config
[
"quan_method"
]
=
"gpt_oss_mxfp4"
quant_config
=
self
.
get_quant_config_for_layer
(
prefix
,
layer_type
)
if
quant_config
is
None
:
if
isinstance
(
layer
,
FusedMoE
):
return
UnquantizedFusedMoEMethod
(
layer
.
moe_config
)
elif
isinstance
(
layer
,
LinearBase
):
return
UnquantizedLinearMethod
()
elif
isinstance
(
layer
,
LinearBase
):
return
HummingLinearMethod
(
quant_config
)
elif
isinstance
(
layer
,
FusedMoE
):
return
HummingMoEMethod
(
quant_config
,
layer
.
moe_config
)
return
None
class
HummingLayerQuantizationConfig
(
HummingConfig
):
def
__init__
(
self
,
weight_schema
:
"BaseWeightSchema"
,
input_schema
:
"BaseInputSchema | None"
=
None
,
force_weight_schema
:
"HummingWeightSchema | None"
=
None
,
force_input_schema
:
"HummingInputSchema | None"
=
None
,
is_online_quant
:
bool
=
False
,
):
self
.
weight_schema
=
weight_schema
if
input_schema
is
None
:
input_schema
=
HummingInputSchema
()
self
.
input_schema
=
input_schema
self
.
force_weight_schema
=
force_weight_schema
self
.
force_input_schema
=
force_input_schema
self
.
is_online_quant
=
is_online_quant
@
classmethod
def
from_config
(
cls
,
config
):
weight_schema
=
BaseWeightSchema
.
from_config
(
config
)
return
cls
(
weight_schema
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
QuantizeMethodBase
|
None
:
raise
NotImplementedError
class
HummingLinearMethod
(
LinearMethodBase
):
def
__init__
(
self
,
quant_config
:
HummingLayerQuantizationConfig
):
self
.
quant_config
=
quant_config
self
.
weight_schema
=
quant_config
.
weight_schema
self
.
input_schema
=
quant_config
.
input_schema
self
.
force_weight_schema
=
quant_config
.
force_weight_schema
self
.
force_input_schema
=
quant_config
.
force_input_schema
self
.
is_online_quant
=
self
.
quant_config
.
is_online_quant
def
prepare_weight_loader
(
self
,
layer
:
torch
.
nn
.
Module
,
weight_loader
:
Callable
):
def
new_weight_loader
(
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
shard_id
:
str
|
int
|
None
=
None
,
):
name
=
param
.
param_name
float_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
is_unquantized
=
name
==
"weight"
and
loaded_weight
.
dtype
in
float_dtypes
if
is_unquantized
and
self
.
is_online_quant
:
# online quant (fp16/bf16 -> quant_type)
assert
isinstance
(
self
.
weight_schema
,
HummingWeightSchema
)
f16_dtype
=
DataType
.
from_torch_dtype
(
layer
.
param_dtype
)
has_global_scale
=
"TENSOR"
in
str
(
self
.
weight_schema
.
weight_scale_type
)
tensor_list
=
quantize_weight
(
weight
=
loaded_weight
,
dtype
=
self
.
weight_schema
.
b_dtype
,
scale_dtype
=
self
.
weight_schema
.
bs_dtype
or
f16_dtype
,
group_size
=
self
.
weight_schema
.
weight_scale_group_size
,
has_zero_point
=
self
.
weight_schema
.
has_zero_point
,
has_global_scale
=
has_global_scale
,
is_fp_zero_point
=
self
.
weight_schema
.
is_fp_zero_point
,
pack
=
True
,
)
key_list
=
[
"weight"
,
"weight_scale"
,
"zero_point"
,
"global_scale"
]
for
key
,
tensor
in
zip
(
key_list
,
tensor_list
):
if
tensor
is
None
or
tensor
.
nelement
()
==
0
:
continue
param
=
getattr
(
layer
,
key
)
param
.
weight_loader
(
param
,
tensor
,
shard_id
)
return
None
elif
is_unquantized
and
not
self
.
is_online_quant
:
# fallback to unquantized linear
# some model skip some layer when quantizing model, but
# don't mark the layer as unquantized.
if
not
layer
.
is_fallback
:
layer
.
is_fallback
=
True
for
name
,
_
in
list
(
layer
.
named_parameters
()):
if
name
!=
"bias"
:
delattr
(
layer
,
name
)
delattr
(
layer
,
"locks"
)
self
.
__class__
=
UnquantizedLinearMethod
# type: ignore
tensor
=
torch
.
empty
(
(
layer
.
output_partition_sizes_sum
,
layer
.
input_size_per_partition
,
),
dtype
=
layer
.
param_dtype
,
device
=
param
.
device
,
)
extra_weight_attrs
=
layer
.
extra_weight_attrs
.
copy
()
orig_weight_loader
=
extra_weight_attrs
.
pop
(
"weight_loader"
)
layer
.
weight
=
ModelWeightParameter
(
data
=
tensor
,
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
orig_weight_loader
,
)
layer
.
weight
.
tp_size
=
layer
.
tp_size
layer
.
weight
.
tp_rank
=
layer
.
tp_rank
set_weight_attrs
(
layer
.
weight
,
extra_weight_attrs
)
param
=
layer
.
weight
if
shard_id
is
not
None
:
return
layer
.
weight
.
weight_loader
(
param
,
loaded_weight
,
shard_id
)
return
layer
.
weight
.
weight_loader
(
param
,
loaded_weight
)
# weight processing logic for specific quantization schema
loaded_weight
=
self
.
weight_schema
.
process_loaded_weight
(
tensor
=
loaded_weight
,
name
=
name
,
)
if
shard_id
is
not
None
:
return
weight_loader
(
param
,
loaded_weight
,
shard_id
)
return
weight_loader
(
param
,
loaded_weight
)
return
new_weight_loader
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
list
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
is_fallback
=
False
layer
.
param_dtype
=
params_dtype
layer
.
input_size
=
input_size
layer
.
output_size
=
output_size
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_partition_sizes_sum
=
sum
(
output_partition_sizes
)
layer
.
output_partition_sizes
=
output_partition_sizes
layer
.
extra_weight_attrs
=
extra_weight_attrs
.
copy
()
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
,
default_weight_loader
)
new_weight_loader
=
self
.
prepare_weight_loader
(
layer
,
weight_loader
)
extra_weight_attrs
[
"weight_loader"
]
=
new_weight_loader
for
key
in
[
"weight_block_size"
,
"block_structure"
]:
block_size
=
getattr
(
self
.
weight_schema
,
key
,
None
)
if
block_size
is
not
None
:
layer
.
weight_block_size
=
block_size
weight_tensor_attrs
=
self
.
weight_schema
.
get_tensors_attrs
(
shape_n
=
layer
.
output_partition_sizes_sum
,
shape_k
=
layer
.
input_size_per_partition
,
param_dtype
=
params_dtype
,
stack_size
=
len
(
layer
.
output_partition_sizes
),
)
input_tensor_attrs
=
self
.
input_schema
.
get_tensors_attrs
(
shape_k
=
layer
.
input_size_per_partition
,
param_dtype
=
params_dtype
,
stack_size
=
len
(
layer
.
output_partition_sizes
),
)
tensors_attrs
=
weight_tensor_attrs
|
input_tensor_attrs
for
name
,
attrs
in
tensors_attrs
.
items
():
tensor
=
torch
.
empty
(
attrs
[
"shape"
],
dtype
=
attrs
[
"dtype"
])
extra_attrs
=
attrs
.
get
(
"extra_attrs"
,
{}).
copy
()
extra_attrs
.
update
(
extra_weight_attrs
)
param
=
prepare_param
(
tensor
,
name
,
extra_attrs
)
setattr
(
layer
,
name
,
param
)
locks
=
torch
.
zeros
(
1024
,
dtype
=
torch
.
int32
)
layer
.
register_buffer
(
"locks"
,
locks
)
if
self
.
force_input_schema
is
not
None
:
self
.
input_schema
=
self
.
force_input_schema
if
not
hasattr
(
layer
,
"weight"
):
param
=
prepare_param
(
torch
.
tensor
(
0
),
"weight"
,
extra_weight_attrs
)
layer
.
weight
=
param
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
if
layer
.
is_fallback
:
return
None
# convert from checkpoint format to humming format
if
not
isinstance
(
self
.
weight_schema
,
HummingWeightSchema
):
self
.
weight_schema
,
tensors
=
self
.
weight_schema
.
convert_humming
(
tensors
=
layer
.
state_dict
(),
shape_n_stacks
=
layer
.
output_partition_sizes
,
shape_k_stacks
=
[
layer
.
input_size_per_partition
],
param_dtype
=
layer
.
param_dtype
,
)
self
.
input_schema
,
_
=
self
.
input_schema
.
convert_humming
(
tensors
=
layer
.
state_dict
(),
shape_n_stacks
=
layer
.
output_partition_sizes
,
shape_k_stacks
=
[
layer
.
input_size_per_partition
],
param_dtype
=
layer
.
param_dtype
,
)
for
name
,
_
in
list
(
layer
.
named_parameters
()):
delattr
(
layer
,
name
)
for
name
,
tensor
in
tensors
.
items
():
param
=
torch
.
nn
.
Parameter
(
tensor
,
requires_grad
=
False
)
setattr
(
layer
,
name
,
param
)
del
tensors
# force requant (origin quant setting -> fp16/bf16 -> new_quant setting)
assert
isinstance
(
self
.
weight_schema
,
HummingWeightSchema
)
force_requant
=
self
.
force_weight_schema
is
not
None
if
force_requant
and
self
.
weight_schema
!=
self
.
force_weight_schema
:
tensors
=
self
.
weight_schema
.
requant_tensors
(
tensors
=
layer
.
state_dict
(),
target_weight_schema
=
self
.
force_weight_schema
,
param_dtype
=
layer
.
param_dtype
,
)
self
.
weight_schema
=
self
.
force_weight_schema
for
name
,
_
in
list
(
layer
.
named_parameters
()):
if
name
!=
"bias"
:
delattr
(
layer
,
name
)
for
name
,
tensor
in
tensors
.
items
():
param
=
torch
.
nn
.
Parameter
(
tensor
,
requires_grad
=
False
)
setattr
(
layer
,
name
,
param
)
del
tensors
# prepare layer config from humming kernel
HummingMethod
.
prepare_layer_meta
(
layer
=
layer
,
shape_n
=
layer
.
output_partition_sizes_sum
,
shape_k
=
layer
.
input_size_per_partition
,
weight_schema
=
self
.
weight_schema
,
input_schema
=
self
.
input_schema
,
pad_n_to_multiple
=
256
,
pad_k_to_multiple
=
128
,
has_bias
=
layer
.
has_bias
,
torch_dtype
=
layer
.
param_dtype
,
)
# preprocess weight for inference
HummingMethod
.
transform_humming_layer
(
layer
)
# compute_config: kernel configs that do not directly affect weights
# but significantly impact kernel behavior or computation precision.
# see https://github.com/inclusionAI/humming/blob/main/docs/config.md
compute_config
=
{
"use_batch_invariant"
:
envs
.
VLLM_BATCH_INVARIANT
,
"use_f16_accum"
:
envs
.
VLLM_HUMMING_USE_F16_ACCUM
,
"gemm_type"
:
"dense"
,
}
self
.
compute_config
=
json
.
dumps
(
compute_config
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
flatten_inputs
=
x
.
view
(
-
1
,
x
.
size
(
-
1
))
output
=
HummingMethod
.
forward_layer
(
layer
=
layer
,
inputs
=
flatten_inputs
,
compute_config
=
self
.
compute_config
,
)
output
=
output
.
view
(
*
x
.
shape
[:
-
1
],
output
.
size
(
-
1
))
return
output
class
HummingMoEMethod
(
FusedMoEMethodBase
):
def
__init__
(
self
,
quant_config
:
HummingLayerQuantizationConfig
,
moe
:
"FusedMoEConfig"
)
->
None
:
super
().
__init__
(
moe
)
self
.
quant_config
=
quant_config
self
.
moe
=
moe
self
.
weight_schema
=
quant_config
.
weight_schema
self
.
input_schema
=
quant_config
.
input_schema
self
.
force_weight_schema
=
quant_config
.
force_weight_schema
self
.
force_input_schema
=
quant_config
.
force_input_schema
def
prepare_weight_loader
(
self
,
layer
,
weight_loader
):
def
new_weight_loader
(
param
:
torch
.
nn
.
Parameter
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
|
None
=
None
,
return_success
:
bool
=
False
,
):
name
=
param
.
param_name
float_dtypes
=
[
torch
.
float16
,
torch
.
bfloat16
,
torch
.
float32
]
is_unquantized
=
name
==
"weight"
and
loaded_weight
.
dtype
in
float_dtypes
# online quant (fp16/bf16 -> quant_type)
if
is_unquantized
:
assert
isinstance
(
self
.
weight_schema
,
HummingWeightSchema
)
f16_dtype
=
DataType
.
from_torch_dtype
(
layer
.
param_dtype
)
has_global_scale
=
"TENSOR"
in
str
(
self
.
weight_schema
.
weight_scale_type
)
tensor_list
=
quantize_weight
(
weight
=
loaded_weight
,
dtype
=
self
.
weight_schema
.
b_dtype
,
scale_dtype
=
self
.
weight_schema
.
bs_dtype
or
f16_dtype
,
group_size
=
self
.
weight_schema
.
weight_scale_group_size
,
has_zero_point
=
self
.
weight_schema
.
has_zero_point
,
has_global_scale
=
has_global_scale
,
is_fp_zero_point
=
self
.
weight_schema
.
is_fp_zero_point
,
pack
=
True
,
)
key_list
=
[
"weight"
,
"weight_scale"
,
"zero_point"
,
"global_scale"
]
success
=
True
for
key
,
tensor
in
zip
(
key_list
,
tensor_list
):
if
tensor
is
None
or
tensor
.
nelement
()
==
0
:
continue
sublayer_name
=
"w2"
if
shard_id
==
"w2"
else
"w13"
param
=
getattr
(
layer
,
sublayer_name
+
"_"
+
key
)
part_subccess
=
param
.
weight_loader
(
param
=
param
,
loaded_weight
=
tensor
.
cpu
(),
weight_name
=
shard_id
+
"_"
+
key
,
shard_id
=
shard_id
,
expert_id
=
expert_id
,
return_success
=
return_success
,
)
success
=
success
and
part_subccess
return
success
if
return_success
else
None
# weight processing logic for specific quantization schema
loaded_weight
=
self
.
weight_schema
.
process_loaded_weight
(
tensor
=
loaded_weight
,
name
=
name
,
)
return
weight_loader
(
param
,
loaded_weight
,
weight_name
,
shard_id
=
shard_id
,
expert_id
=
expert_id
,
return_success
=
return_success
,
)
return
new_weight_loader
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
layer
.
num_experts
=
num_experts
layer
.
param_dtype
=
params_dtype
layer
.
intermediate_size
=
intermediate_size_per_partition
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
,
default_weight_loader
)
weight_loader
=
self
.
prepare_weight_loader
(
layer
,
weight_loader
)
extra_weight_attrs
[
"weight_loader"
]
=
weight_loader
# sublayer: a layer contains multiple sets of weights for quantized GEMM
# (e.g., weight, weight_scale, etc.).
# The weight names of sublayer start with the prefix "{sublayer_name}_"
layer
.
sublayer_configs
=
{
"w13"
:
{
"shape_n"
:
intermediate_size_per_partition
*
2
,
"shape_k"
:
hidden_size
,
"tensors_attrs"
:
self
.
weight_schema
.
get_padded_tensors_attrs
(
shape_n
=
intermediate_size_per_partition
*
2
,
shape_k
=
hidden_size
,
num_experts
=
num_experts
,
param_dtype
=
params_dtype
,
has_bias
=
self
.
moe
.
has_bias
,
),
},
"w2"
:
{
"shape_n"
:
hidden_size
,
"shape_k"
:
intermediate_size_per_partition
,
"tensors_attrs"
:
self
.
weight_schema
.
get_padded_tensors_attrs
(
shape_n
=
hidden_size
,
shape_k
=
intermediate_size_per_partition
,
num_experts
=
num_experts
,
param_dtype
=
params_dtype
,
has_bias
=
self
.
moe
.
has_bias
,
),
},
}
for
sublayer_name
,
configs
in
layer
.
sublayer_configs
.
items
():
for
name
,
attrs
in
configs
[
"tensors_attrs"
].
items
():
tensor
=
torch
.
empty
(
attrs
[
"shape"
],
dtype
=
attrs
[
"dtype"
])
param
=
torch
.
nn
.
Parameter
(
tensor
,
requires_grad
=
False
)
extra_attrs
=
attrs
.
get
(
"extra_attrs"
,
{}).
copy
()
extra_attrs
.
update
(
extra_weight_attrs
)
param
=
prepare_moe_param
(
tensor
,
name
,
extra_attrs
)
setattr
(
layer
,
f
"
{
sublayer_name
}
_
{
name
}
"
,
param
)
if
self
.
force_input_schema
is
not
None
:
self
.
input_schema
=
self
.
force_input_schema
locks
=
torch
.
zeros
(
1024
,
dtype
=
torch
.
int32
)
layer
.
register_buffer
(
"locks"
,
locks
)
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
FusedMoEQuantConfig
:
self
.
process_weights_after_loading
(
layer
)
input_schema
=
self
.
input_schemas
[
"w13"
]
weight_schema
=
self
.
weight_schemas
[
"w13"
]
a_dtype
=
input_schema
.
a_dtype
if
a_dtype
is
None
or
a_dtype
.
num_bits
==
16
:
a_quant_desc
=
FusedMoEQuantDesc
(
dtype
=
None
)
else
:
shape
=
GroupShape
(
row
=
1
,
col
=-
1
)
a_quant_desc
=
FusedMoEQuantDesc
(
dtype
=
str
(
a_dtype
),
shape
=
shape
)
weight_scale_group_size
=
weight_schema
.
weight_scale_group_size
weight_scale_group_size_n
=
weight_schema
.
weight_scale_group_size_n
weight_group_shape
:
tuple
[
int
,
...]
=
()
if
weight_scale_group_size_n
>
1
:
weight_group_shape
=
GroupShape
(
row
=
weight_scale_group_size
,
col
=
weight_scale_group_size_n
,
)
elif
weight_scale_group_size
==
0
:
weight_group_shape
=
GroupShape
(
row
=-
1
,
col
=
1
)
else
:
weight_group_shape
=
GroupShape
(
row
=
weight_scale_group_size
,
col
=
1
)
w1_quant_desc
=
FusedMoEQuantDesc
(
dtype
=
str
(
weight_schema
.
b_dtype
),
shape
=
weight_group_shape
,
scale
=
getattr
(
layer
,
"w13_weight_scale"
,
None
),
alpha_or_gscale
=
getattr
(
layer
,
"w13_global_scale"
,
None
),
zp
=
getattr
(
layer
,
"w13_zero_point"
,
None
),
bias
=
getattr
(
layer
,
"w13_bias"
,
None
),
)
w2_quant_desc
=
FusedMoEQuantDesc
(
dtype
=
str
(
weight_schema
.
b_dtype
),
shape
=
weight_group_shape
,
scale
=
getattr
(
layer
,
"w2_weight_scale"
,
None
),
alpha_or_gscale
=
getattr
(
layer
,
"w2_global_scale"
,
None
),
zp
=
getattr
(
layer
,
"w2_zero_point"
,
None
),
bias
=
getattr
(
layer
,
"w2_bias"
,
None
),
)
return
FusedMoEQuantConfig
(
_a1
=
a_quant_desc
,
_a2
=
a_quant_desc
,
_w1
=
w1_quant_desc
,
_w2
=
w2_quant_desc
,
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
if
getattr
(
self
,
"processed"
,
False
):
return
self
.
processed
=
True
self
.
weight_schemas
=
{}
self
.
input_schemas
=
{}
for
sublayer_name
,
configs
in
layer
.
sublayer_configs
.
items
():
input_schema
=
self
.
input_schema
weight_schema
=
self
.
weight_schema
# convert from checkpoint format to humming format
if
not
isinstance
(
weight_schema
,
HummingWeightSchema
):
tensors
:
dict
[
str
,
torch
.
Tensor
]
=
dict
(
(
key
.
removeprefix
(
sublayer_name
+
"_"
),
value
)
for
key
,
value
in
layer
.
state_dict
().
items
()
if
key
.
startswith
(
sublayer_name
+
"_"
)
)
shape_k_stacks
=
[
configs
[
"shape_k"
]]
shape_n_stacks
=
[
configs
[
"shape_n"
]]
if
sublayer_name
==
"w13"
:
shape_n_stacks
=
[
configs
[
"shape_n"
]
//
2
]
*
2
weight_schema
,
tensors
=
weight_schema
.
convert_humming
(
tensors
=
tensors
,
shape_n_stacks
=
shape_n_stacks
,
shape_k_stacks
=
shape_k_stacks
,
param_dtype
=
layer
.
param_dtype
,
num_experts
=
layer
.
num_experts
,
)
input_schema
,
_
=
input_schema
.
convert_humming
(
tensors
=
tensors
,
shape_n_stacks
=
shape_n_stacks
,
shape_k_stacks
=
shape_k_stacks
,
param_dtype
=
layer
.
param_dtype
,
num_experts
=
layer
.
num_experts
,
)
for
name
,
_
in
list
(
layer
.
named_parameters
()):
if
not
name
.
startswith
(
sublayer_name
+
"_"
):
continue
delattr
(
layer
,
name
)
for
name
,
tensor
in
tensors
.
items
():
name
=
f
"
{
sublayer_name
}
_
{
name
}
"
param
=
torch
.
nn
.
Parameter
(
tensor
,
requires_grad
=
False
)
setattr
(
layer
,
name
,
param
)
self
.
weight_schemas
[
sublayer_name
]
=
weight_schema
self
.
input_schemas
[
sublayer_name
]
=
input_schema
# force requant (origin quant setting -> fp16/bf16 -> new_quant setting)
assert
isinstance
(
weight_schema
,
HummingWeightSchema
)
force_requant
=
self
.
force_weight_schema
is
not
None
if
force_requant
and
weight_schema
!=
self
.
force_weight_schema
:
tensors
=
dict
(
(
key
.
removeprefix
(
sublayer_name
+
"_"
),
value
)
for
key
,
value
in
layer
.
state_dict
().
items
()
if
key
.
startswith
(
sublayer_name
+
"_"
)
)
tensors
=
weight_schema
.
requant_tensors
(
tensors
=
tensors
,
target_weight_schema
=
self
.
force_weight_schema
,
param_dtype
=
layer
.
param_dtype
,
)
weight_schema
=
self
.
force_weight_schema
for
name
,
_
in
list
(
layer
.
named_parameters
()):
if
not
name
.
startswith
(
sublayer_name
+
"_"
):
continue
if
name
==
sublayer_name
+
"_bias"
:
continue
delattr
(
layer
,
name
)
for
name
,
tensor
in
tensors
.
items
():
name
=
f
"
{
sublayer_name
}
_
{
name
}
"
param
=
torch
.
nn
.
Parameter
(
tensor
,
requires_grad
=
False
)
setattr
(
layer
,
name
,
param
)
del
tensors
# prepare layer config from humming kernel
HummingMethod
.
prepare_layer_meta
(
layer
=
layer
,
shape_n
=
configs
[
"shape_n"
],
shape_k
=
configs
[
"shape_k"
],
pad_n_to_multiple
=
256
,
pad_k_to_multiple
=
128
,
input_schema
=
input_schema
,
weight_schema
=
weight_schema
,
has_bias
=
self
.
moe
.
has_bias
,
num_experts
=
layer
.
num_experts
,
torch_dtype
=
layer
.
param_dtype
,
sublayer_name
=
sublayer_name
,
)
# preprocess weight for inference
HummingMethod
.
transform_humming_layer
(
layer
,
sublayer_name
=
sublayer_name
)
# use moe modular
experts
:
HummingIndexedExperts
|
HummingGroupedExperts
if
get_humming_moe_gemm_type
()
==
"indexed"
:
experts
=
HummingIndexedExperts
(
layer
,
self
)
else
:
experts
=
HummingGroupedExperts
(
layer
,
self
)
self
.
experts
=
experts
def
select_gemm_impl
(
self
,
prepare_finalize
,
layer
:
torch
.
nn
.
Module
,
):
from
vllm.model_executor.layers.fused_moe
import
modular_kernel
as
mk
activation_format
=
prepare_finalize
.
activation_format
if
activation_format
==
mk
.
FusedMoEActivationFormat
.
BatchedExperts
:
return
BatchedHummingGroupedExperts
(
layer
,
self
,
prepare_finalize
)
elif
get_humming_moe_gemm_type
()
==
"indexed"
:
return
HummingIndexedExperts
(
layer
,
self
,
prepare_finalize
)
else
:
return
HummingGroupedExperts
(
layer
,
self
,
prepare_finalize
)
def
apply
(
self
,
layer
:
FusedMoE
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
shared_experts_input
:
torch
.
Tensor
|
None
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
workspace1
,
workspace2
,
output
=
self
.
experts
.
make_workspaces
(
M
=
topk_ids
.
size
(
0
),
topk
=
topk_ids
.
size
(
1
),
activation
=
layer
.
activation
,
)
assert
workspace1
.
data_ptr
()
==
output
.
data_ptr
()
self
.
experts
.
main_apply
(
hidden_states
=
x
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
workspace1
=
workspace1
,
workspace2
=
workspace2
,
expert_tokens_meta
=
None
,
)
return
output
vllm/model_executor/layers/quantization/utils/humming_moe_utils.py
0 → 100644
View file @
9f771b3a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
,
)
def
humming_moe_align
(
configs
:
list
[
int
],
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
len
(
configs
)
>
0
and
len
(
configs
)
%
3
==
0
# NOTE: we choose moe_block_size based on
# num_tokens * top_k (= topk_ids.nelement())
shape_m
=
topk_ids
.
nelement
()
for
i
in
range
(
len
(
configs
)
//
3
):
if
shape_m
>
configs
[
i
*
3
]
and
shape_m
<=
configs
[
i
*
3
+
1
]:
block_size
=
configs
[
i
*
3
+
2
]
break
else
:
raise
ValueError
(
f
"Could not find a matching block_size for shape_m=
{
shape_m
}
"
)
return
moe_align_block_size
(
topk_ids
=
topk_ids
,
block_size
=
block_size
,
num_experts
=
num_experts
,
expert_map
=
expert_map
,
pad_sorted_ids
=
False
,
ignore_invalid_experts
=
True
,
)
vllm/model_executor/parameter.py
View file @
9f771b3a
...
@@ -605,8 +605,8 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size)
...
@@ -605,8 +605,8 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size)
def
_adjust_shard_indexes_for_packing
(
def
_adjust_shard_indexes_for_packing
(
shard_size
,
shard_offset
,
packed_factor
,
marlin_tile_size
shard_size
,
shard_offset
,
packed_factor
,
marlin_tile_size
):
):
shard_size
=
shard_size
//
packed_factor
shard_size
=
round
(
shard_size
//
packed_factor
)
shard_offset
=
shard_offset
//
packed_factor
shard_offset
=
round
(
shard_offset
//
packed_factor
)
if
marlin_tile_size
is
not
None
:
if
marlin_tile_size
is
not
None
:
return
_adjust_shard_indexes_for_marlin
(
return
_adjust_shard_indexes_for_marlin
(
shard_size
=
shard_size
,
shard_size
=
shard_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment