Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
359d2930
Unverified
Commit
359d2930
authored
Sep 24, 2025
by
Nikhil Gupta
Committed by
GitHub
Sep 24, 2025
Browse files
[fix]: add Arm 4bit fused moe support (#23809)
Signed-off-by:
Nikhil Gupta
<
nikhil.gupta2@arm.com
>
parent
9df8da54
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
488 additions
and
11 deletions
+488
-11
cmake/cpu_extension.cmake
cmake/cpu_extension.cmake
+2
-1
csrc/cpu/torch_bindings.cpp
csrc/cpu/torch_bindings.cpp
+10
-0
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
+156
-0
csrc/ops.h
csrc/ops.h
+6
-0
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
+9
-6
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+0
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+305
-2
No files found.
cmake/cpu_extension.cmake
View file @
359d2930
...
@@ -258,7 +258,8 @@ set(VLLM_EXT_SRC
...
@@ -258,7 +258,8 @@ set(VLLM_EXT_SRC
"csrc/cpu/layernorm.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/cpu/torch_bindings.cpp"
)
"csrc/cpu/torch_bindings.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp"
)
if
(
AVX512_FOUND AND NOT AVX512_DISABLED
)
if
(
AVX512_FOUND AND NOT AVX512_DISABLED
)
set
(
VLLM_EXT_SRC
set
(
VLLM_EXT_SRC
...
...
csrc/cpu/torch_bindings.cpp
View file @
359d2930
...
@@ -88,8 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -88,8 +88,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int tp_rank, int blocksparse_local_blocks,"
" int tp_rank, int blocksparse_local_blocks,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_vert_stride, int blocksparse_block_size,"
" int blocksparse_head_sliding_step) -> ()"
);
" int blocksparse_head_sliding_step) -> ()"
);
ops
.
impl
(
"paged_attention_v1"
,
torch
::
kCPU
,
&
paged_attention_v1
);
ops
.
impl
(
"paged_attention_v1"
,
torch
::
kCPU
,
&
paged_attention_v1
);
ops
.
def
(
"dynamic_4bit_int_moe("
"Tensor x, Tensor topk_ids, Tensor topk_weights,"
"Tensor w13_packed, Tensor w2_packed, int H, int I, int I2,"
"int group_size, bool apply_router_weight_on_input, int activation_kind"
") -> Tensor"
);
ops
.
impl
(
"dynamic_4bit_int_moe"
,
torch
::
kCPU
,
&
dynamic_4bit_int_moe_cpu
);
// PagedAttention V2.
// PagedAttention V2.
ops
.
def
(
ops
.
def
(
"paged_attention_v2("
"paged_attention_v2("
...
...
csrc/moe/dynamic_4bit_int_moe_cpu.cpp
0 → 100644
View file @
359d2930
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/all.h>
// _dyn_quant_matmul_4bit is only available on AArch64.
#if defined(__aarch64__)
#include <ATen/ops/_dyn_quant_matmul_4bit.h>
#endif
inline
torch
::
Tensor
mm
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
packed_w
,
int64_t
group_size_eff
,
int64_t
in_features
,
int64_t
out_features
)
{
#if defined(__aarch64__)
return
at
::
_ops
::
_dyn_quant_matmul_4bit
::
call
(
a
,
packed_w
,
group_size_eff
,
in_features
,
out_features
);
#else
TORCH_CHECK
(
false
,
"dynamic 4-bit int MoE path requires AArch64 (ARM64); "
"_dyn_quant_matmul_4bit is unavailable on this architecture"
);
return
{};
#endif
}
enum
ActivationKind
:
int64_t
{
SwiGLU_Gu
=
0
,
// act = SiLU(g) * u
SwiGLUOAI
=
1
,
// act = SiLU(u) * g
SiLU
=
2
// SiLU
};
torch
::
Tensor
dynamic_4bit_int_moe_cpu
(
torch
::
Tensor
x
,
torch
::
Tensor
topk_ids
,
torch
::
Tensor
topk_weights
,
torch
::
Tensor
w13_packed
,
torch
::
Tensor
w2_packed
,
int64_t
H
,
int64_t
I
,
int64_t
I2
,
int64_t
group_size
,
bool
apply_router_weight_on_input
,
int64_t
activation_kind
)
{
TORCH_CHECK
(
x
.
dim
()
==
2
,
"x must be 2D"
);
TORCH_CHECK
(
topk_ids
.
dim
()
==
2
&&
topk_weights
.
dim
()
==
2
,
"topk tensors must be [T, K]"
);
TORCH_CHECK
(
w13_packed
.
size
(
0
)
==
w2_packed
.
size
(
0
),
"w13_packed and w2_packed must have same number of experts in dim 0"
);
TORCH_CHECK
(
I2
==
2
*
I
,
"I2 must equal 2*I"
);
const
int64_t
T
=
x
.
size
(
0
);
const
int64_t
K
=
topk_ids
.
size
(
1
);
const
int64_t
E
=
w13_packed
.
size
(
0
);
const
int64_t
N
=
T
*
K
;
auto
x_c
=
x
.
contiguous
();
auto
ids_c
=
topk_ids
.
contiguous
();
auto
gates_c
=
topk_weights
.
to
(
at
::
kFloat
).
contiguous
();
// bucketing tokens -> experts
c10
::
SmallVector
<
int64_t
,
64
>
counts
(
E
,
0
);
// Small vector uses stack allocation
{
const
auto
*
ids_ptr
=
ids_c
.
data_ptr
<
int64_t
>
();
for
(
int64_t
i
=
0
;
i
<
N
;
++
i
)
{
const
int64_t
e_id
=
ids_ptr
[
i
];
TORCH_CHECK
(
0
<=
e_id
&&
e_id
<
E
,
"expert id out of range"
);
counts
[
e_id
]
++
;
}
}
c10
::
SmallVector
<
int64_t
,
65
>
offsets
(
E
+
1
,
0
);
// ( E +1 )
for
(
int64_t
e
=
0
;
e
<
E
;
++
e
)
offsets
[
e
+
1
]
=
offsets
[
e
]
+
counts
[
e
];
auto
expert_tokens
=
at
::
empty
({
offsets
[
E
]},
ids_c
.
options
());
auto
expert_gates
=
at
::
empty
({
offsets
[
E
]},
gates_c
.
options
());
{
c10
::
SmallVector
<
int64_t
,
64
>
cursor
(
E
,
0
);
const
auto
*
ids_ptr
=
ids_c
.
data_ptr
<
int64_t
>
();
const
auto
*
gts_ptr
=
gates_c
.
data_ptr
<
float
>
();
auto
*
tok_ptr
=
expert_tokens
.
data_ptr
<
int64_t
>
();
auto
*
gate_ptr
=
expert_gates
.
data_ptr
<
float
>
();
for
(
int64_t
t
=
0
;
t
<
T
;
++
t
)
{
const
int64_t
base
=
t
*
K
;
for
(
int64_t
k
=
0
;
k
<
K
;
++
k
)
{
const
int64_t
idx
=
base
+
k
;
const
int64_t
e
=
ids_ptr
[
idx
];
const
int64_t
p
=
offsets
[
e
]
+
(
cursor
[
e
]
++
);
tok_ptr
[
p
]
=
t
;
gate_ptr
[
p
]
=
gts_ptr
[
idx
];
}
}
}
const
int64_t
g_eff_13
=
(
group_size
!=
-
1
)
?
group_size
:
H
;
const
int64_t
g_eff_2
=
(
group_size
!=
-
1
)
?
group_size
:
I
;
// Per-expert outputs filled in parallel
std
::
vector
<
torch
::
Tensor
>
y_list
(
E
);
y_list
.
resize
(
E
);
at
::
parallel_for
(
0
,
E
,
1
,
[
&
](
int64_t
e_begin
,
int64_t
e_end
)
{
for
(
int64_t
e
=
e_begin
;
e
<
e_end
;
++
e
)
{
const
int64_t
te
=
counts
[
e
];
if
(
te
==
0
)
{
y_list
[
e
]
=
at
::
empty
({
0
,
H
},
x_c
.
options
());
continue
;
}
const
int64_t
start
=
offsets
[
e
];
auto
sel_tokens
=
expert_tokens
.
narrow
(
/*dim=*/
0
,
/*start=*/
start
,
/*length=*/
te
);
auto
gates_e
=
expert_gates
.
narrow
(
/*dim=*/
0
,
/*start=*/
start
,
/*length=*/
te
);
auto
x_e
=
x_c
.
index_select
(
/*dim=*/
0
,
sel_tokens
);
if
(
apply_router_weight_on_input
)
{
x_e
=
x_e
.
mul
(
gates_e
.
unsqueeze
(
1
));
}
auto
w13_e
=
w13_packed
.
select
(
/*dim=*/
0
,
e
);
auto
w2_e
=
w2_packed
.
select
(
/*dim=*/
0
,
e
);
// W13
auto
y13
=
mm
(
x_e
,
w13_e
,
g_eff_13
,
/*in_features=*/
H
,
/*out_features=*/
I2
);
auto
g_part
=
y13
.
narrow
(
/*dim=*/
1
,
/*start=*/
0
,
/*length=*/
I
);
auto
u_part
=
y13
.
narrow
(
/*dim=*/
1
,
/*start=*/
I
,
/*length=*/
I
);
torch
::
Tensor
act
;
if
(
activation_kind
==
ActivationKind
::
SwiGLUOAI
)
{
// SwiGLUOAI
constexpr
double
kAlpha
=
1.702
;
// GPT-OSS default
constexpr
double
kLimit
=
7.0
;
// GPT-OSS default
auto
gate_c
=
at
::
clamp_max
(
g_part
,
kLimit
);
auto
up_c
=
at
::
clamp
(
u_part
,
-
kLimit
,
kLimit
);
auto
glu
=
gate_c
.
mul
(
at
::
sigmoid
(
gate_c
.
mul
(
kAlpha
)));
act
=
up_c
.
add
(
1.0
).
mul
(
glu
);
}
else
{
// SiLU , SwiGLU_GU, vLLM maps silu to SiluAndMul()
act
=
at
::
silu
(
g_part
).
mul
(
u_part
);
}
// W2
auto
y
=
mm
(
act
,
w2_e
,
g_eff_2
,
/*in_features=*/
I
,
/*out_features=*/
H
);
if
(
!
apply_router_weight_on_input
)
{
y
=
y
.
mul
(
gates_e
.
unsqueeze
(
1
));
}
// Store per-expert result
y_list
[
e
]
=
y
;
}
});
// Concatenate all expert outputs to match expert_tokens order
auto
Y_all
=
at
::
cat
(
y_list
,
/*dim=*/
0
);
auto
out
=
at
::
zeros
({
T
,
H
},
x
.
options
());
out
=
at
::
index_add
(
out
,
/*dim=*/
0
,
/*index=*/
expert_tokens
,
/*source=*/
Y_all
);
return
out
;
}
csrc/ops.h
View file @
359d2930
...
@@ -328,6 +328,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
...
@@ -328,6 +328,12 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
const
std
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
std
::
optional
<
torch
::
Tensor
>&
has_initial_state
,
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
);
const
torch
::
Tensor
&
ssm_states
,
int64_t
pad_slot_id
);
torch
::
Tensor
dynamic_4bit_int_moe_cpu
(
torch
::
Tensor
x
,
torch
::
Tensor
topk_ids
,
torch
::
Tensor
topk_weights
,
torch
::
Tensor
w13_packed
,
torch
::
Tensor
w2_packed
,
int64_t
H
,
int64_t
I
,
int64_t
I2
,
int64_t
group_size
,
bool
apply_router_weight_on_input
,
int64_t
activation_kind
);
using
fptr_t
=
int64_t
;
using
fptr_t
=
int64_t
;
fptr_t
init_custom_ar
(
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
,
fptr_t
init_custom_ar
(
const
std
::
vector
<
int64_t
>&
fake_ipc_ptrs
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
torch
::
Tensor
&
rank_data
,
int64_t
rank
,
...
...
vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
View file @
359d2930
...
@@ -98,13 +98,16 @@ def select_experts(
...
@@ -98,13 +98,16 @@ def select_experts(
e_score_correction_bias
=
e_score_correction_bias
)
e_score_correction_bias
=
e_score_correction_bias
)
elif
custom_routing_function
is
None
:
elif
custom_routing_function
is
None
:
assert
scoring_func
==
"softmax"
assert
scoring_func
==
"softmax"
topk_
weights
=
torch
.
nn
.
functional
.
softmax
(
router_logits
,
topk_
logit_vals
,
topk_idx
=
torch
.
topk
(
router_logits
,
dim
=
1
,
k
=
top_k
,
dtype
=
torch
.
float32
)
dim
=-
1
,
topk_weights
,
topk_ids
=
torch
.
topk
(
topk_weights
,
top_k
,
dim
=-
1
)
sorted
=
False
)
if
renormalize
:
if
renormalize
:
topk_weights
/=
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
topk_vals
=
torch
.
softmax
(
topk_logit_vals
,
dim
=-
1
)
return
topk_weights
,
topk_ids
.
to
(
torch
.
int32
)
else
:
logZ
=
torch
.
logsumexp
(
router_logits
,
dim
=-
1
,
keepdim
=
True
)
topk_vals
=
(
topk_logit_vals
-
logZ
).
exp
()
return
topk_vals
.
to
(
torch
.
float32
),
topk_idx
.
to
(
torch
.
int32
)
else
:
else
:
return
custom_routing_function
(
hidden_states
=
hidden_states
,
return
custom_routing_function
(
hidden_states
=
hidden_states
,
gating_output
=
router_logits
,
gating_output
=
router_logits
,
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
359d2930
...
@@ -69,8 +69,6 @@ else:
...
@@ -69,8 +69,6 @@ else:
if
is_rocm_aiter_moe_enabled
():
if
is_rocm_aiter_moe_enabled
():
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
# noqa: E501
from
vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe
import
(
# noqa: E501
rocm_aiter_grouped_topk
as
grouped_topk
)
rocm_aiter_grouped_topk
as
grouped_topk
)
elif
current_platform
.
is_cpu
():
pass
else
:
else
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
grouped_topk
from
vllm.model_executor.layers.fused_moe.fused_moe
import
grouped_topk
if
current_platform
.
is_tpu
():
if
current_platform
.
is_tpu
():
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
359d2930
...
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
...
@@ -22,6 +22,7 @@ from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
FusedMoEQuantConfig
,
fp8_w8a8_moe_quant_config
,
int4_w4a16_moe_quant_config
,
int8_w8a8_moe_quant_config
,
int4_w4a16_moe_quant_config
,
int8_w8a8_moe_quant_config
,
int8_w8a16_moe_quant_config
,
nvfp4_moe_quant_config
)
int8_w8a16_moe_quant_config
,
nvfp4_moe_quant_config
)
from
vllm.model_executor.layers.fused_moe.cpu_fused_moe
import
select_experts
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
from
vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe
import
(
is_valid_flashinfer_cutlass_fused_moe
)
is_valid_flashinfer_cutlass_fused_moe
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16
import
(
# noqa
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16
import
(
# noqa
...
@@ -47,7 +48,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -47,7 +48,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
CpuArchEnum
,
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
from
vllm.utils.deep_gemm
import
is_deep_gemm_e8m0_used
...
@@ -63,7 +64,7 @@ __all__ = [
...
@@ -63,7 +64,7 @@ __all__ = [
"CompressedTensorsMoEMethod"
,
"CompressedTensorsW8A8Fp8MoEMethod"
,
"CompressedTensorsMoEMethod"
,
"CompressedTensorsW8A8Fp8MoEMethod"
,
"CompressedTensorsW8A8Int8MoEMethod"
,
"CompressedTensorsW8A8Int8MoEMethod"
,
"CompressedTensorsWNA16MarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
,
"CompressedTensorsWNA16MarlinMoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
,
"CompressedTensorsW4A4MoeMethod"
"CompressedTensorsW4A4MoeMethod"
,
"CompressedTensorsW4A8Int8MoEMethod"
]
]
...
@@ -139,6 +140,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
...
@@ -139,6 +140,10 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
elif
quant_config
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
elif
quant_config
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8MoEMethod
(
quant_config
,
return
CompressedTensorsW8A8Int8MoEMethod
(
quant_config
,
layer
.
moe_config
)
layer
.
moe_config
)
elif
quant_config
.
_is_dynamic_token_w4a8_int
(
weight_quant
,
input_quant
):
return
CompressedTensorsW4A8Int8MoEMethod
(
quant_config
,
layer
.
moe_config
)
else
:
else
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Unsupported FusedMoe scheme:
{
weight_quant
}
,
{
input_quant
}
"
)
f
"Unsupported FusedMoe scheme:
{
weight_quant
}
,
{
input_quant
}
"
)
...
@@ -1769,3 +1774,301 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
...
@@ -1769,3 +1774,301 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
expert_map
=
expert_map
,
expert_map
=
expert_map
,
quant_config
=
self
.
moe_quant_config
,
quant_config
=
self
.
moe_quant_config
,
)
)
class
CompressedTensorsW4A8Int8MoEMethod
(
CompressedTensorsMoEMethod
):
"""
CPU-only MoE method using dynamic 4-bit matmul kernels on Arm Platform
- Weights: int4 (stored as int8 values in [-8,7], packed to uint8 nibbles)
- Scales: Fp32 for Channelwise , bf16 for groupwise quantization
- Bias: Same data type as original weights
- Activations: FP32/Bf16 dynamic per-token (A8 Int),
quantized inside the kernel
"""
def
__init__
(
self
,
quant_config
:
"CompressedTensorsConfig"
,
# type: ignore # noqa E501
moe
:
FusedMoEConfig
):
super
().
__init__
(
moe
)
self
.
has_bias
=
self
.
moe
.
has_bias
self
.
quant_config
=
quant_config
# Validate scheme: weights=W4 (channel or group),
# activations=dynamic TOKEN (A8)
wq
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
aq
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
# Must be dynamic per-token activations
if
aq
.
strategy
!=
QuantizationStrategy
.
TOKEN
or
not
aq
.
dynamic
:
raise
ValueError
(
"W4A8-int MoE needs dynamic per-token activation quantization."
)
# Weight can be channel-wise (group_size=None) or group-wise
self
.
group_size
=
wq
.
group_size
if
(
wq
.
group_size
is
not
None
)
else
-
1
if
wq
.
num_bits
!=
4
:
raise
ValueError
(
"This method only supports 4-bit weights (num_bits=4)."
)
# CPU only
if
not
current_platform
.
is_cpu
():
raise
ValueError
(
"CompressedTensorsW4A8Int8MoEMethod is CPU-only."
)
# Arm: check _dyn ops availability
if
current_platform
.
get_cpu_architecture
()
==
CpuArchEnum
.
ARM
:
try
:
_
=
torch
.
ops
.
aten
.
_dyn_quant_matmul_4bit
_
=
torch
.
ops
.
aten
.
_dyn_quant_pack_4bit_weight
except
AttributeError
as
err
:
raise
RuntimeError
(
f
"""PyTorch
{
torch
.
__version__
}
lacks _dyn_quant_* 4bit ops;
install a newer build."""
)
from
err
self
.
static_input_scales
=
False
# always dynamic per token
# ---- parameter creation ----
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
# Shapes per local rank (TP/EP):
# w13: [E, 2*I_local, H] int8 (int4 values in [-8,7])
# w2 : [E, H, I_local] int8
# Scales:
# channel-wise: group_size=-1 -> per-output-row, single scale per row
# group-wise : group_size=g ->
# per-output-row, (in_features/g) scales
E
=
num_experts
H
=
hidden_size
IN
=
intermediate_size_per_partition
g
=
self
.
group_size
# Per-row scale columns
def
_n_scale_cols
(
in_features
:
int
)
->
int
:
return
1
if
g
==
-
1
else
(
in_features
//
g
)
# Register unpacked int4-as-int8 weights the loader will fill.
w13
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
E
,
2
*
IN
,
H
,
dtype
=
torch
.
int8
),
requires_grad
=
False
)
set_weight_attrs
(
w13
,
extra_weight_attrs
)
layer
.
register_parameter
(
"w13_weight"
,
w13
)
w2
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
E
,
H
,
IN
,
dtype
=
torch
.
int8
),
requires_grad
=
False
)
set_weight_attrs
(
w2
,
extra_weight_attrs
)
layer
.
register_parameter
(
"w2_weight"
,
w2
)
# Register scales
# KleidiAI groupwise kernels accepts float32 scales
# KleidiAI groupwise kernels accepts bfloat16 scales
scale_dtype
=
torch
.
float32
if
g
==
-
1
else
torch
.
bfloat16
w13_s
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
E
,
2
*
IN
,
_n_scale_cols
(
H
),
dtype
=
scale_dtype
),
requires_grad
=
False
)
set_weight_attrs
(
w13_s
,
{
"quant_method"
:
"channel"
if
g
==
-
1
else
"group"
,
**
extra_weight_attrs
})
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_s
)
w2_s
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
E
,
H
,
_n_scale_cols
(
IN
),
dtype
=
scale_dtype
),
requires_grad
=
False
)
set_weight_attrs
(
w2_s
,
{
"quant_method"
:
"channel"
if
g
==
-
1
else
"group"
,
**
extra_weight_attrs
})
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_s
)
if
self
.
has_bias
:
w13_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
E
,
2
*
IN
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_bias"
,
w13_bias
)
set_weight_attrs
(
w13_bias
,
extra_weight_attrs
)
w2_bias
=
torch
.
nn
.
Parameter
(
torch
.
zeros
(
num_experts
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_bias"
,
w2_bias
)
set_weight_attrs
(
w2_bias
,
extra_weight_attrs
)
# Placeholders for packed weights (will be replaced after packing)
layer
.
register_parameter
(
"w13_weight_packed"
,
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
))
set_weight_attrs
(
layer
.
w13_weight_packed
,
extra_weight_attrs
)
layer
.
register_parameter
(
"w2_weight_packed"
,
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
))
set_weight_attrs
(
layer
.
w2_weight_packed
,
extra_weight_attrs
)
# dims for 4 bit fused matmuls
layer
.
w13_in_features
=
H
layer
.
w13_out_features
=
2
*
IN
layer
.
w2_in_features
=
IN
layer
.
w2_out_features
=
H
layer
.
group_size
=
g
# post-load packing to dyn-4bit KleidiAI kernel's format
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
E
=
layer
.
w13_weight
.
shape
[
0
]
H
=
layer
.
w13_in_features
I2
=
layer
.
w13_out_features
IN
=
layer
.
w2_in_features
g
=
layer
.
group_size
def
_pack_matrix
(
int4_as_int8_2d
:
torch
.
Tensor
,
scales_2d
:
torch
.
Tensor
,
bias_1d
:
Optional
[
torch
.
Tensor
],
in_features
:
int
,
out_features
:
int
)
->
torch
.
Tensor
:
# int4 values are stored as int8 in [-8,7].
# Shift to unsigned nibble and pack pairs along input-dim.
tmp
=
int4_as_int8_2d
.
add
(
8
)
# [out, in]
uint8_nibbles
=
((
tmp
[:,
1
::
2
]
<<
4
)
|
tmp
[:,
::
2
]).
to
(
torch
.
uint8
)
# [out, in//2]
# KleidiAI groupwise kernels accepts float32 scales
# KleidiAI groupwise kernels accepts bfloat16 scales
scale_dtype
=
torch
.
float32
if
g
==
-
1
else
torch
.
bfloat16
scales
=
scales_2d
.
to
(
scale_dtype
)
bias
=
None
if
bias_1d
is
None
else
bias_1d
.
to
(
torch
.
float32
)
return
torch
.
ops
.
aten
.
_dyn_quant_pack_4bit_weight
(
uint8_nibbles
,
scales
,
bias
,
g
if
g
!=
-
1
else
in_features
,
in_features
,
out_features
)
# Pack per expert
w13_packed_list
=
[]
w2_packed_list
=
[]
has_w13_bias
=
hasattr
(
layer
,
"w13_bias"
)
and
layer
.
w13_bias
is
not
None
has_w2_bias
=
hasattr
(
layer
,
"w2_bias"
)
and
layer
.
w2_bias
is
not
None
for
e
in
range
(
E
):
w13_packed_list
.
append
(
_pack_matrix
(
layer
.
w13_weight
[
e
],
# [2I, H]
layer
.
w13_weight_scale
[
e
],
# [2I, H/g or 1]
layer
.
w13_bias
[
e
]
if
has_w13_bias
else
None
,
# [2I]
H
,
I2
))
w2_packed_list
.
append
(
_pack_matrix
(
# w2 shape is [H, IN]; we need [out, in] == [H, IN].
layer
.
w2_weight
[
e
],
# [H, IN]
layer
.
w2_weight_scale
[
e
],
# [H, IN/g or 1]
layer
.
w2_bias
[
e
]
if
has_w2_bias
else
None
,
# [H]
IN
,
layer
.
w2_out_features
# in_features=IN, out_features=H
))
# each packed tensor has identical shape per expert; stack on dim 0
w13_packed
=
torch
.
stack
(
w13_packed_list
,
dim
=
0
)
w2_packed
=
torch
.
stack
(
w2_packed_list
,
dim
=
0
)
replace_parameter
(
layer
,
"w13_weight_packed"
,
torch
.
nn
.
Parameter
(
w13_packed
,
requires_grad
=
False
))
replace_parameter
(
layer
,
"w2_weight_packed"
,
torch
.
nn
.
Parameter
(
w2_packed
,
requires_grad
=
False
))
# free raw tensors/scales/bias now that they're packed into the payload.
replace_parameter
(
layer
,
"w13_weight"
,
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
))
replace_parameter
(
layer
,
"w2_weight"
,
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
))
replace_parameter
(
layer
,
"w13_weight_scale"
,
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
))
replace_parameter
(
layer
,
"w2_weight_scale"
,
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
))
if
has_w13_bias
:
replace_parameter
(
layer
,
"w13_bias"
,
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
))
if
has_w2_bias
:
replace_parameter
(
layer
,
"w2_bias"
,
torch
.
nn
.
Parameter
(
torch
.
empty
(
0
),
requires_grad
=
False
))
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
FusedMoEQuantConfig
]:
# CPU dynamic 4-bit MoE path does not use modular kernels or
# fused_experts; quant config is not needed.
return
None
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
routed_scaling_factor
:
float
=
1.0
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
assert
not
enable_eplb
,
"EPLB not supported for W4A8-int MoE yet."
assert
activation
in
(
"silu"
,
"swigluoai"
,
"swiglu"
),
"Only SiLU/SwiGLUGU/SwiGLUUG are supported."
assert
expert_map
is
None
,
"""expert_map/EP not implemented
for CPU dyn-4bit MoE."""
def
_act_kind
(
s
:
str
)
->
int
:
# 0 = SwiGLU_Gu (SiLU(g)*u), 1 = SwiGLU_Ug (SiLU(u)*g), 2 = SiLU
if
s
==
"swiglu"
:
return
0
if
s
==
"swigluoai"
:
return
1
if
s
==
"silu"
:
return
2
raise
ValueError
(
f
"Unknown activation '
{
s
}
'"
)
# Apply topk softmax on router output
topk_weights
,
topk_ids
=
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
routed_scaling_factor
=
routed_scaling_factor
,
e_score_correction_bias
=
e_score_correction_bias
,
)
return
torch
.
ops
.
_C
.
dynamic_4bit_int_moe
(
x
,
topk_ids
.
to
(
torch
.
long
),
topk_weights
,
layer
.
w13_weight_packed
,
layer
.
w2_weight_packed
,
layer
.
w2_out_features
,
layer
.
w2_in_features
,
layer
.
w13_out_features
,
layer
.
group_size
,
apply_router_weight_on_input
,
int
(
_act_kind
(
activation
)))
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment