Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9f771b3a
Unverified
Commit
9f771b3a
authored
Apr 24, 2026
by
Jinzhen Lin
Committed by
GitHub
Apr 24, 2026
Browse files
[Quantization] add humming quantization kernel (#34556)
parent
c9d3c6e6
Changes
10
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1948 additions
and
9 deletions
+1948
-9
vllm/config/model.py
vllm/config/model.py
+4
-0
vllm/envs.py
vllm/envs.py
+32
-0
vllm/model_executor/layers/fused_moe/fused_humming_moe.py
vllm/model_executor/layers/fused_moe/fused_humming_moe.py
+690
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+5
-1
vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py
vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py
+202
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+13
-6
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+3
-0
vllm/model_executor/layers/quantization/humming.py
vllm/model_executor/layers/quantization/humming.py
+962
-0
vllm/model_executor/layers/quantization/utils/humming_moe_utils.py
...l_executor/layers/quantization/utils/humming_moe_utils.py
+35
-0
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+2
-2
No files found.
vllm/config/model.py
View file @
9f771b3a
...
@@ -953,8 +953,12 @@ class ModelConfig:
...
@@ -953,8 +953,12 @@ class ModelConfig:
"mxfp4"
,
"mxfp4"
,
"gpt_oss_mxfp4"
,
"gpt_oss_mxfp4"
,
"cpu_awq"
,
"cpu_awq"
,
"humming"
,
"gguf"
,
"gguf"
,
]
]
# if the user specifies humming, we should always use humming
if
self
.
quantization
==
"humming"
:
overrides
=
[
"humming"
]
+
overrides
quantization_methods
=
[
quantization_methods
=
[
q
for
q
in
supported_quantization
if
q
not
in
overrides
q
for
q
in
supported_quantization
if
q
not
in
overrides
]
]
...
...
vllm/envs.py
View file @
9f771b3a
...
@@ -152,6 +152,10 @@ if TYPE_CHECKING:
...
@@ -152,6 +152,10 @@ if TYPE_CHECKING:
VLLM_RAY_EXTRA_ENV_VARS_TO_COPY
:
str
=
""
VLLM_RAY_EXTRA_ENV_VARS_TO_COPY
:
str
=
""
VLLM_MARLIN_USE_ATOMIC_ADD
:
bool
=
False
VLLM_MARLIN_USE_ATOMIC_ADD
:
bool
=
False
VLLM_MARLIN_INPUT_DTYPE
:
Literal
[
"int8"
,
"fp8"
]
|
None
=
None
VLLM_MARLIN_INPUT_DTYPE
:
Literal
[
"int8"
,
"fp8"
]
|
None
=
None
VLLM_HUMMING_ONLINE_QUANT_CONFIG
:
dict
[
str
,
Any
]
|
None
=
None
VLLM_HUMMING_INPUT_QUANT_CONFIG
:
dict
[
str
,
Any
]
|
None
=
None
VLLM_HUMMING_USE_F16_ACCUM
:
bool
=
False
VLLM_HUMMING_MOE_GEMM_TYPE
:
Literal
[
"indexed"
,
"grouped"
,
"auto"
]
|
None
=
None
VLLM_MXFP4_USE_MARLIN
:
bool
|
None
=
None
VLLM_MXFP4_USE_MARLIN
:
bool
|
None
=
None
VLLM_DEEPEPLL_NVFP4_DISPATCH
:
bool
=
False
VLLM_DEEPEPLL_NVFP4_DISPATCH
:
bool
=
False
VLLM_V1_USE_OUTLINES_CACHE
:
bool
=
False
VLLM_V1_USE_OUTLINES_CACHE
:
bool
=
False
...
@@ -285,6 +289,15 @@ def maybe_convert_bool(value: str | None) -> bool | None:
...
@@ -285,6 +289,15 @@ def maybe_convert_bool(value: str | None) -> bool | None:
return
bool
(
int
(
value
))
return
bool
(
int
(
value
))
def
maybe_convert_json_str_or_file
(
value
:
str
|
None
)
->
dict
[
str
,
Any
]
|
None
:
if
value
is
None
:
return
None
if
os
.
path
.
exists
(
value
):
with
open
(
value
)
as
f
:
return
json
.
load
(
f
)
return
json
.
loads
(
value
)
def
disable_compile_cache
()
->
bool
:
def
disable_compile_cache
()
->
bool
:
return
bool
(
int
(
os
.
getenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"0"
)))
return
bool
(
int
(
os
.
getenv
(
"VLLM_DISABLE_COMPILE_CACHE"
,
"0"
)))
...
@@ -1193,6 +1206,25 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1193,6 +1206,25 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MARLIN_INPUT_DTYPE"
:
env_with_choices
(
"VLLM_MARLIN_INPUT_DTYPE"
:
env_with_choices
(
"VLLM_MARLIN_INPUT_DTYPE"
,
None
,
[
"int8"
,
"fp8"
]
"VLLM_MARLIN_INPUT_DTYPE"
,
None
,
[
"int8"
,
"fp8"
]
),
),
# The online quantization dtype for humming kernel
"VLLM_HUMMING_ONLINE_QUANT_CONFIG"
:
lambda
:
maybe_convert_json_str_or_file
(
os
.
environ
.
get
(
"VLLM_HUMMING_ONLINE_QUANT_CONFIG"
,
None
)
),
# The activation dtype config for humming kernel
"VLLM_HUMMING_INPUT_QUANT_CONFIG"
:
lambda
:
maybe_convert_json_str_or_file
(
os
.
environ
.
get
(
"VLLM_HUMMING_INPUT_QUANT_CONFIG"
,
None
)
),
# Whether to use fp16 accumulator mma
"VLLM_HUMMING_USE_F16_ACCUM"
:
lambda
:
maybe_convert_bool
(
os
.
environ
.
get
(
"VLLM_HUMMING_USE_F16_ACCUM"
,
"0"
)
),
# Whether to use indexed gemm for humming moe
# if 1, force use indexed gemm
# if 0, force use grouped gemm
# if None, choose better gemm type automatically
"VLLM_HUMMING_MOE_GEMM_TYPE"
:
lambda
:
maybe_convert_bool
(
os
.
environ
.
get
(
"VLLM_HUMMING_MOE_GEMM_TYPE"
,
None
)
),
# Whether to use DeepEPLL kernels for NVFP4 quantization and dispatch method
# Whether to use DeepEPLL kernels for NVFP4 quantization and dispatch method
# only supported on Blackwell GPUs and with
# only supported on Blackwell GPUs and with
# https://github.com/deepseek-ai/DeepEP/pull/341
# https://github.com/deepseek-ai/DeepEP/pull/341
...
...
vllm/model_executor/layers/fused_moe/fused_humming_moe.py
0 → 100644
View file @
9f771b3a
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/fused_moe/layer.py
View file @
9f771b3a
...
@@ -1097,7 +1097,11 @@ class FusedMoE(PluggableLayer):
...
@@ -1097,7 +1097,11 @@ class FusedMoE(PluggableLayer):
expert_id
:
int
,
expert_id
:
int
,
return_success
:
bool
=
False
,
return_success
:
bool
=
False
,
)
->
bool
|
None
:
)
->
bool
|
None
:
if
self
.
quant_config
and
self
.
quant_config
.
get_name
()
==
"gpt_oss_mxfp4"
:
quant_config_name
=
self
.
quant_config
and
self
.
quant_config
.
get_name
()
if
quant_config_name
==
"humming"
:
assert
hasattr
(
self
.
quant_method
,
"weight_schema"
)
quant_config_name
=
self
.
quant_method
.
weight_schema
.
quant_method
if
quant_config_name
==
"gpt_oss_mxfp4"
:
# (FIXME) for gpt-oss all experts are combined
# (FIXME) for gpt-oss all experts are combined
if
"bias"
in
weight_name
:
if
"bias"
in
weight_name
:
dim1
=
loaded_weight
.
shape
[
1
]
dim1
=
loaded_weight
.
shape
[
1
]
...
...
vllm/model_executor/layers/fused_moe/moe_fused_mul_sum.py
0 → 100644
View file @
9f771b3a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
torch._subclasses.fake_tensor
import
FakeTensor
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
@triton.jit
def moe_fused_mul_sum_kernel(
    inputs_ptr,
    topk_weights_ptr,
    outputs_ptr,
    top_ids_ptr,
    expert_map_ptr,
    num_tokens,
    stride_m,
    has_expert_map: tl.constexpr,
    top_k: tl.constexpr,
    size: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Weighted sum over the ``top_k`` slices of each input row.

    Each program handles a ``BLOCK_M x BLOCK_K`` output tile: grid axis 0
    walks the ``size`` dimension and grid axis 1 walks the rows. For every
    ``n`` in ``range(top_k)`` the tile at
    ``inputs_ptr + row * stride_m + n * size + k`` is multiplied by the
    per-row weight at ``topk_weights_ptr + row * top_k + n`` and accumulated
    in float32. The result is cast to the output element type and stored at
    ``outputs_ptr + row * size + k``.

    When ``has_expert_map`` is set, the id at
    ``top_ids_ptr + row * top_k + n`` is looked up in ``expert_map_ptr``;
    slices whose mapped value is negative are loaded as zeros and therefore
    contribute nothing.
    """
    pid_k = tl.program_id(0)
    pid_m = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    # Out-of-range rows/columns of the last tiles are masked off.
    m_mask = offs_m < num_tokens
    k_mask = offs_k < size
    mask = m_mask[:, None] & k_mask[None, :]

    # Row offsets are scaled by stride_m (= top_k * size at the call site) and
    # by size; widen to 64 bit first so the products cannot wrap around for
    # large inputs.
    rows64 = offs_m.to(tl.int64)

    a_base = inputs_ptr + (rows64 * stride_m)[:, None] + offs_k[None, :]
    b_base = topk_weights_ptr + offs_m * top_k

    acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)

    # top_k is a compile-time constant, so this loop is unrolled.
    for n in tl.static_range(top_k):
        b_val = tl.load(b_base + n, mask=m_mask, other=0.0).to(tl.float32)
        if has_expert_map:
            # Masked-off rows read id 0 so the map lookup stays in bounds.
            id_val = tl.load(top_ids_ptr + offs_m * top_k + n, mask=m_mask, other=0)
            expert_mask = tl.load(expert_map_ptr + id_val) >= 0
            a_vec = tl.load(
                a_base + n * size,
                mask=mask & expert_mask[:, None],
                other=0.0,
            ).to(tl.float32)
        else:
            a_vec = tl.load(
                a_base + n * size,
                mask=mask,
                other=0.0,
            ).to(tl.float32)
        acc += a_vec * b_val[:, None]

    out_ptrs = outputs_ptr + (rows64 * size)[:, None] + offs_k[None, :]
    tl.store(
        out_ptrs,
        acc.to(outputs_ptr.dtype.element_ty),
        mask=mask,
    )
def _heuristic_config(
    num_tokens: int,
    top_k: int,
    size: int,
    element_size: int,
):
    """Choose launch parameters for ``moe_fused_mul_sum_kernel``.

    Args:
        num_tokens: Number of rows to process.
        top_k: Accepted for signature symmetry; not used by the heuristic.
        size: Length of the reduced rows (the kernel's K dimension).
        element_size: Byte width of one input element; anything wider than
            two bytes is handled as the 32-bit case.

    Returns:
        The tuple ``(BLOCK_M, BLOCK_K, num_warps, num_stages)``.
    """
    wide_elements = element_size > 2
    sm90_or_newer = current_platform.has_device_capability(90)
    older_than_sm80 = not current_platform.has_device_capability(80)

    # Row tile height.
    if num_tokens <= 4:
        BLOCK_M = 1
    elif sm90_or_newer:
        # SM90/SM100+: prefer small tiles + many CTAs.
        BLOCK_M = 2 if (wide_elements or num_tokens <= 128) else 4
    elif num_tokens <= 32:
        BLOCK_M = 2
    elif wide_elements or num_tokens <= 128:
        BLOCK_M = 4
    elif num_tokens <= 1024:
        BLOCK_M = 16
    else:
        BLOCK_M = 8

    # Column tile width: a power of two clamped to [256, max_block_k].
    if wide_elements:
        max_block_k = 256
    elif older_than_sm80 or sm90_or_newer:
        max_block_k = 512
    else:
        max_block_k = 1024
    BLOCK_K = max(min(triton.next_power_of_2(size), max_block_k), 256)

    tile_elems = BLOCK_M * BLOCK_K

    # Warp count scales with the tile size, within per-dtype bounds.
    min_warps, elems_per_warp = (8, 64) if wide_elements else (4, 256)
    num_warps = max(min_warps, min(16, tile_elems // elems_per_warp))
    if older_than_sm80 or sm90_or_newer:
        num_warps = min(num_warps, 8)

    # Pre-SM80 always uses two stages; otherwise only small tiles get four.
    num_stages = 2 if (older_than_sm80 or tile_elems > 2048) else 4

    return BLOCK_M, BLOCK_K, num_warps, num_stages
def
moe_fused_mul_sum
(
inputs
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
outputs
:
torch
.
Tensor
|
None
=
None
,
topk_ids
:
torch
.
Tensor
|
None
=
None
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
"""
Fused kernel for MoE (Mixture of Experts) to perform weighted summation
of expert outputs.
Args:
inputs: The output from experts.
Shape: (num_tokens, top_k, hidden_size).
topk_weights: The weights assigned to each expert for each token.
Shape: (num_tokens, top_k).
outputs: Optional pre-allocated output tensor.
Shape: (num_tokens, hidden_size).
topk_ids: Optional indices of the top-k experts. Used when
`expert_map` is provided. Shape: (num_tokens, top_k).
expert_map: Optional mapping for Expert Parallelism. A value < 0
indicates an invalid token/expert pair that will be skipped.
Returns:
The fused weighted sum of expert outputs.
Shape: (num_tokens, hidden_size).
"""
assert
inputs
.
ndim
==
3
assert
topk_weights
.
ndim
==
2
assert
inputs
.
is_contiguous
()
assert
topk_weights
.
is_contiguous
()
assert
inputs
.
dtype
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
)
assert
topk_weights
.
dtype
in
(
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
)
num_tokens
,
top_k
,
size
=
inputs
.
shape
output_shape
=
(
num_tokens
,
size
)
if
outputs
is
None
:
outputs
=
torch
.
empty
(
output_shape
,
dtype
=
inputs
.
dtype
,
device
=
inputs
.
device
)
assert
outputs
.
shape
==
output_shape
assert
topk_weights
.
shape
==
(
num_tokens
,
top_k
)
if
not
isinstance
(
inputs
,
FakeTensor
):
BLOCK_M
,
BLOCK_K
,
num_warps
,
num_stages
=
_heuristic_config
(
num_tokens
,
top_k
,
size
,
inputs
.
element_size
(),
)
grid
=
(
triton
.
cdiv
(
size
,
BLOCK_K
),
triton
.
cdiv
(
num_tokens
,
BLOCK_M
))
moe_fused_mul_sum_kernel
[
grid
](
inputs
,
topk_weights
,
outputs
,
topk_ids
,
expert_map
,
num_tokens
,
top_k
*
size
,
expert_map
is
not
None
,
top_k
,
size
,
BLOCK_M
,
BLOCK_K
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
return
outputs
vllm/model_executor/layers/linear.py
View file @
9f771b3a
...
@@ -60,6 +60,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
...
@@ -60,6 +60,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"ModelOptFp8PbWoLinearMethod"
,
"ModelOptFp8PbWoLinearMethod"
,
"QuarkLinearMethod"
,
"QuarkLinearMethod"
,
"ModelOptNvFp4LinearMethod"
,
"ModelOptNvFp4LinearMethod"
,
"HummingLinearMethod"
,
]
]
...
@@ -245,6 +246,7 @@ class LinearBase(PluggableLayer):
...
@@ -245,6 +246,7 @@ class LinearBase(PluggableLayer):
self
,
self
,
input_size
:
int
,
input_size
:
int
,
output_size
:
int
,
output_size
:
int
,
bias
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
torch
.
dtype
|
None
=
None
,
params_dtype
:
torch
.
dtype
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
quant_config
:
QuantizationConfig
|
None
=
None
,
...
@@ -258,6 +260,7 @@ class LinearBase(PluggableLayer):
...
@@ -258,6 +260,7 @@ class LinearBase(PluggableLayer):
# Keep input parameters
# Keep input parameters
self
.
input_size
=
input_size
self
.
input_size
=
input_size
self
.
output_size
=
output_size
self
.
output_size
=
output_size
self
.
has_bias
=
bias
self
.
skip_bias_add
=
skip_bias_add
self
.
skip_bias_add
=
skip_bias_add
if
params_dtype
is
None
:
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
params_dtype
=
torch
.
get_default_dtype
()
...
@@ -323,6 +326,7 @@ class ReplicatedLinear(LinearBase):
...
@@ -323,6 +326,7 @@ class ReplicatedLinear(LinearBase):
super
().
__init__
(
super
().
__init__
(
input_size
,
input_size
,
output_size
,
output_size
,
bias
,
skip_bias_add
,
skip_bias_add
,
params_dtype
,
params_dtype
,
quant_config
,
quant_config
,
...
@@ -458,6 +462,7 @@ class ColumnParallelLinear(LinearBase):
...
@@ -458,6 +462,7 @@ class ColumnParallelLinear(LinearBase):
super
().
__init__
(
super
().
__init__
(
input_size
,
input_size
,
output_size
,
output_size
,
bias
,
skip_bias_add
,
skip_bias_add
,
params_dtype
,
params_dtype
,
quant_config
,
quant_config
,
...
@@ -483,6 +488,7 @@ class ColumnParallelLinear(LinearBase):
...
@@ -483,6 +488,7 @@ class ColumnParallelLinear(LinearBase):
else
self
.
weight_loader
else
self
.
weight_loader
),
),
)
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
dtype
=
params_dtype
)
torch
.
empty
(
self
.
output_size_per_partition
,
dtype
=
params_dtype
)
...
@@ -817,8 +823,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -817,8 +823,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
# for the packing.
# for the packing.
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
if
packed_dim
==
output_dim
:
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
packed_factor
shard_size
=
round
(
shard_size
//
param
.
packed_factor
)
shard_offset
=
shard_offset
//
param
.
packed_factor
shard_offset
=
round
(
shard_offset
//
param
.
packed_factor
)
# Special case for Marlin.
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
param
,
shard_size
,
shard_offset
...
@@ -1252,8 +1258,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -1252,8 +1258,8 @@ class QKVParallelLinear(ColumnParallelLinear):
)
)
if
packed_dim
==
output_dim
:
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
packed_factor
shard_size
=
round
(
shard_size
//
param
.
packed_factor
)
shard_offset
=
shard_offset
//
param
.
packed_factor
shard_offset
=
round
(
shard_offset
//
param
.
packed_factor
)
# Special case for Marlin.
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
...
@@ -1315,8 +1321,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -1315,8 +1321,8 @@ class QKVParallelLinear(ColumnParallelLinear):
# for the packing.
# for the packing.
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
if
packed_dim
==
output_dim
:
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
packed_factor
shard_size
=
round
(
shard_size
//
param
.
packed_factor
)
shard_offset
=
shard_offset
//
param
.
packed_factor
shard_offset
=
round
(
shard_offset
//
param
.
packed_factor
)
# Special case for Marlin.
# Special case for Marlin.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
...
@@ -1440,6 +1446,7 @@ class RowParallelLinear(LinearBase):
...
@@ -1440,6 +1446,7 @@ class RowParallelLinear(LinearBase):
super
().
__init__
(
super
().
__init__
(
input_size
,
input_size
,
output_size
,
output_size
,
bias
,
skip_bias_add
,
skip_bias_add
,
params_dtype
,
params_dtype
,
quant_config
,
quant_config
,
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
9f771b3a
...
@@ -22,6 +22,7 @@ QuantizationMethods = Literal[
...
@@ -22,6 +22,7 @@ QuantizationMethods = Literal[
"gptq_marlin"
,
"gptq_marlin"
,
"awq_marlin"
,
"awq_marlin"
,
"gptq"
,
"gptq"
,
"humming"
,
"compressed-tensors"
,
"compressed-tensors"
,
"bitsandbytes"
,
"bitsandbytes"
,
"experts_int8"
,
"experts_int8"
,
...
@@ -126,6 +127,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -126,6 +127,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from
.gguf
import
GGUFConfig
from
.gguf
import
GGUFConfig
from
.gptq
import
GPTQConfig
from
.gptq
import
GPTQConfig
from
.gptq_marlin
import
GPTQMarlinConfig
from
.gptq_marlin
import
GPTQMarlinConfig
from
.humming
import
HummingConfig
from
.inc
import
INCConfig
from
.inc
import
INCConfig
from
.modelopt
import
(
from
.modelopt
import
(
ModelOptFp8Config
,
ModelOptFp8Config
,
...
@@ -162,6 +164,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
...
@@ -162,6 +164,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"mxfp4"
:
Mxfp4Config
,
"mxfp4"
:
Mxfp4Config
,
"gpt_oss_mxfp4"
:
GptOssMxfp4Config
,
"gpt_oss_mxfp4"
:
GptOssMxfp4Config
,
"cpu_awq"
:
CPUAWQConfig
,
"cpu_awq"
:
CPUAWQConfig
,
"humming"
:
HummingConfig
,
"online"
:
OnlineQuantizationConfig
,
"online"
:
OnlineQuantizationConfig
,
}
}
...
...
vllm/model_executor/layers/quantization/humming.py
0 → 100644
View file @
9f771b3a
This diff is collapsed.
Click to expand it.
vllm/model_executor/layers/quantization/utils/humming_moe_utils.py
0 → 100644
View file @
9f771b3a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
,
)
def
humming_moe_align
(
configs
:
list
[
int
],
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
assert
len
(
configs
)
>
0
and
len
(
configs
)
%
3
==
0
# NOTE: we choose moe_block_size based on
# num_tokens * top_k (= topk_ids.nelement())
shape_m
=
topk_ids
.
nelement
()
for
i
in
range
(
len
(
configs
)
//
3
):
if
shape_m
>
configs
[
i
*
3
]
and
shape_m
<=
configs
[
i
*
3
+
1
]:
block_size
=
configs
[
i
*
3
+
2
]
break
else
:
raise
ValueError
(
f
"Could not find a matching block_size for shape_m=
{
shape_m
}
"
)
return
moe_align_block_size
(
topk_ids
=
topk_ids
,
block_size
=
block_size
,
num_experts
=
num_experts
,
expert_map
=
expert_map
,
pad_sorted_ids
=
False
,
ignore_invalid_experts
=
True
,
)
vllm/model_executor/parameter.py
View file @
9f771b3a
...
@@ -605,8 +605,8 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size)
...
@@ -605,8 +605,8 @@ def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size)
def
_adjust_shard_indexes_for_packing
(
def
_adjust_shard_indexes_for_packing
(
shard_size
,
shard_offset
,
packed_factor
,
marlin_tile_size
shard_size
,
shard_offset
,
packed_factor
,
marlin_tile_size
):
):
shard_size
=
shard_size
//
packed_factor
shard_size
=
round
(
shard_size
//
packed_factor
)
shard_offset
=
shard_offset
//
packed_factor
shard_offset
=
round
(
shard_offset
//
packed_factor
)
if
marlin_tile_size
is
not
None
:
if
marlin_tile_size
is
not
None
:
return
_adjust_shard_indexes_for_marlin
(
return
_adjust_shard_indexes_for_marlin
(
shard_size
=
shard_size
,
shard_size
=
shard_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment