change / sglang · Commits · 3ded6235

Unverified commit 3ded6235, authored May 23, 2025 by Chunyuan WU, committed by GitHub on May 23, 2025:

Add fp8 fused_experts kernel for CPU in sgl-kernel and add UT (#6404)
parent 4ba1eea8

Showing 7 changed files with 752 additions and 157 deletions (+752, -157):
sgl-kernel/csrc/cpu/gemm.h                   +26   -0
sgl-kernel/csrc/cpu/moe.cpp                  +76   -28
sgl-kernel/csrc/cpu/moe_fp8.cpp              +291  -12
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp  +4    -1
sgl-kernel/setup_cpu.py                      +0    -116
test/srt/cpu/test_moe.py                     +259  -0
test/srt/cpu/utils.py                        +96   -0
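The new fp8 w8a16 path is exposed through torch.ops.sgl_kernel.fused_experts_cpu, which gains a use_fp8_w8a16 flag and a block_size argument. Below is a minimal sketch of the new call, mirroring the fp8 case in the added test/srt/cpu/test_moe.py; the shapes, the scale values, and the [BLOCK_N, BLOCK_K] quantization block sizes are illustrative assumptions, not values prescribed by the kernel.

    import torch
    import sgl_kernel  # importing registers the sgl_kernel ops

    kernel = torch.ops.sgl_kernel

    # Illustrative shapes: M tokens, hidden size K, intermediate size N, E experts.
    M, N, K, E, topk = 2, 512, 256, 8, 4
    BLOCK_N, BLOCK_K = 128, 128  # assumed quantization block sizes

    a = torch.randn(M, K, dtype=torch.bfloat16)
    w1 = torch.randn(E, 2 * N, K).to(torch.float8_e4m3fn)  # gate/up projection, fp8 weights
    w2 = torch.randn(E, K, N).to(torch.float8_e4m3fn)      # down projection, fp8 weights
    w1s = torch.rand(E, 2 * N // BLOCK_N, K // BLOCK_K)    # block-wise weight scales
    w2s = torch.rand(E, K // BLOCK_N, N // BLOCK_K)

    score = torch.softmax(torch.randn(M, E, dtype=torch.bfloat16), dim=-1, dtype=torch.float32)
    topk_weight, topk_ids = torch.topk(score, topk)

    # Weights must be pre-packed (VNNI layout) before calling the kernel.
    w1 = kernel.convert_weight_packed(w1)
    w2 = kernel.convert_weight_packed(w2)

    out = kernel.fused_experts_cpu(
        a, w1, w2, topk_weight, topk_ids.to(torch.int32),
        False,               # inplace
        False,               # use_int8_w8a8
        True,                # use_fp8_w8a16 (new in this commit)
        w1s, w2s,
        [BLOCK_N, BLOCK_K],  # block_size (new in this commit)
        None, None,          # a1_scale, a2_scale
        True,                # is_vnni
    )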
sgl-kernel/csrc/cpu/gemm.h

@@ -85,6 +85,32 @@ void fused_experts_int8_kernel_impl(
     int64_t topk,
     int64_t num_tokens_post_pad);
 
+// moe implementations for fp8 w8a16
+template <typename scalar_t>
+void fused_experts_fp8_kernel_impl(
+    scalar_t* __restrict__ output,
+    scalar_t* __restrict__ ic0,
+    scalar_t* __restrict__ ic1,
+    scalar_t* __restrict__ ic2,
+    scalar_t* __restrict__ A_tmp,
+    const scalar_t* __restrict__ input,
+    const at::Float8_e4m3fn* __restrict__ packed_w1,
+    const at::Float8_e4m3fn* __restrict__ packed_w2,
+    const float* __restrict__ w1s,
+    const float* __restrict__ w2s,
+    int64_t block_size_N,
+    int64_t block_size_K,
+    const float* __restrict__ topk_weights,
+    const int32_t* __restrict__ sorted_ids,
+    const int32_t* __restrict__ expert_ids,
+    const int32_t* __restrict__ offsets,
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t E,
+    int64_t topk,
+    int64_t num_tokens_post_pad);
+
 // shared expert implementation for int8 w8a8
 template <typename scalar_t>
 void shared_expert_int8_kernel_impl(
sgl-kernel/csrc/cpu/moe.cpp
@@ -932,6 +932,40 @@ void shared_expert_kernel_impl(
 
 } // anonymous namespace
 
+// common checks
+static inline void check_moe_scales(
+    bool use_int8_w8a8,
+    bool use_fp8_w8a16,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<std::vector<int64_t>> block_size,
+    const std::optional<at::Tensor>& a1_scale,
+    const std::optional<at::Tensor>& a2_scale) {
+  if (use_int8_w8a8) {
+    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
+    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
+    TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
+    TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
+  }
+  if (use_fp8_w8a16) {
+    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
+    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
+    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
+    TORCH_CHECK(block_size.value().size() == 2, "expect block_size.size() to be 2.");
+  }
+}
+
+#define CHECK_MOE_SCALES_FP8(DIM0, DIM1)               \
+  auto w1s = w1_scale.value();                         \
+  auto w2s = w2_scale.value();                         \
+  auto block_size_val = block_size.value();            \
+  int64_t block_size_N = block_size_val[0];            \
+  int64_t block_size_K = block_size_val[1];            \
+  TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \
+  TORCH_CHECK(w1s.size(DIM1) == K / block_size_K);     \
+  TORCH_CHECK(w2s.size(DIM0) == K / block_size_N);     \
+  TORCH_CHECK(w2s.size(DIM1) == N / block_size_K)
+
 // hidden_states: [M, K]
 // w1: [E, 2N, K]
 // w2: [E, K, N]
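For the fused-experts path the macro is invoked as CHECK_MOE_SCALES_FP8(1, 2), because the scale tensors carry a leading expert dimension; the shared-expert path uses dims (0, 1). A Python sketch of the same shape constraints, using the weight layouts documented just above (w1: [E, 2N, K], w2: [E, K, N]):

    # Equivalent shape checks for fp8 w8a16 block scales, written as asserts.
    def check_fp8_scale_shapes(w1_scale, w2_scale, N, K, block_size):
        block_size_N, block_size_K = block_size
        # fused experts: scales are per expert, so the checked dims are 1 and 2
        assert w1_scale.shape[1] == 2 * N // block_size_N
        assert w1_scale.shape[2] == K // block_size_K
        assert w2_scale.shape[1] == K // block_size_N
        assert w2_scale.shape[2] == N // block_size_K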
@@ -946,8 +980,10 @@ at::Tensor fused_experts_cpu(
     at::Tensor& topk_ids,
     bool inplace,
     bool use_int8_w8a8,
+    bool use_fp8_w8a16,
     const std::optional<at::Tensor>& w1_scale,
     const std::optional<at::Tensor>& w2_scale,
+    const std::optional<std::vector<int64_t>> block_size,
     const std::optional<at::Tensor>& a1_scale,
     const std::optional<at::Tensor>& a2_scale,
     bool is_vnni) {
@@ -990,12 +1026,8 @@ at::Tensor fused_experts_cpu(
   CHECK_EQ(packed_w1.size(2), packed_K);
   CHECK_EQ(packed_w2.size(2), packed_N);
 
-  if (use_int8_w8a8) {
-    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
-    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
-    TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
-    TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
-  }
+  // check scales
+  check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale);
 
   at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
@@ -1047,6 +1079,9 @@ at::Tensor fused_experts_cpu(
   //   5. Aq_tmp : [M, K] or [M * topk, N]
   //   6. As_tmp : [M * topk]
   //
+  // for fp8 w8a16:
+  //   7. intermediate_cache1 : [M * topk, 2N]
+  //
   int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
       num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
       num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
@@ -1054,6 +1089,9 @@ at::Tensor fused_experts_cpu(
   if (use_int8_w8a8) {
    buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
   }
+  if (use_fp8_w8a16) {
+    buffer_size_nbytes += M * topk * 2 * N * 2;
+  }
 
   auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
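The extra term reserves the [M * topk, 2N] cache (2 bytes per element) that holds the first GEMM's pre-activation output in the fp8 path. A rough sketch of the same arithmetic for one of the shapes used in the new test; BLOCK_M, BLOCK_N, and the thread count here are placeholder assumptions, not the kernel's actual constants:

    # Rough buffer-size arithmetic mirroring the C++ expression above.
    M, N, K, topk = 121, 512, 256, 4
    BLOCK_M, BLOCK_N, num_threads = 32, 32, 8  # assumed placeholders

    buffer_size_nbytes = (
        M * topk * N * 2
        + M * topk * K * 2
        + num_threads * BLOCK_M * K * 2            # use_int8_w8a8 is false, so 2 bytes/elt
        + num_threads * 2 * BLOCK_M * BLOCK_N * 4  # sizeof(float)
    )
    buffer_size_nbytes += M * topk * 2 * N * 2     # extra [M * topk, 2N] cache for fp8 w8a16
    print(buffer_size_nbytes)                      # total bytes requested for buffer2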
@@ -1095,6 +1133,35 @@ at::Tensor fused_experts_cpu(
           E,
           topk,
           num_tokens_post_pad);
-    } else {
+    } else if (use_fp8_w8a16) {
+      // here we just ignore C_tmp as it is not used
+      scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K));
+      scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(A_tmp + num_threads * BLOCK_M * K));
+      CHECK_MOE_SCALES_FP8(1, 2);
+      fused_experts_fp8_kernel_impl(
+          out_hidden_states.data_ptr<scalar_t>(),
+          intermediate_cache0,
+          intermediate_cache1,
+          intermediate_cache2,
+          A_tmp,
+          hidden_states.data_ptr<scalar_t>(),
+          packed_w1.data_ptr<at::Float8_e4m3fn>(),
+          packed_w2.data_ptr<at::Float8_e4m3fn>(),
+          w1s.data_ptr<float>(),
+          w2s.data_ptr<float>(),
+          block_size_N,
+          block_size_K,
+          topk_weights.data_ptr<float>(),
+          sorted_ids,
+          expert_ids,
+          offsets,
+          M,
+          N,
+          K,
+          E,
+          topk,
+          num_tokens_post_pad);
+    } else {
       scalar_t* __restrict__ A_tmp = intermediate_cache2 + M * topk * K;
       float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K));
@@ -1176,17 +1243,8 @@ at::Tensor shared_expert_cpu(
   CHECK_EQ(packed_w1.size(1), packed_K);
   CHECK_EQ(packed_w2.size(1), packed_N);
 
-  if (use_int8_w8a8) {
-    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
-    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
-    TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
-    TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
-  }
-  if (use_fp8_w8a16) {
-    TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
-    TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
-    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
-  }
+  // check scales
+  check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale);
 
   at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
@@ -1244,17 +1302,7 @@ at::Tensor shared_expert_cpu(
     } else if (use_fp8_w8a16) {
       scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
-      auto w1s = w1_scale.value();
-      auto w2s = w2_scale.value();
-      auto block_size_val = block_size.value();
-      TORCH_CHECK(block_size_val.size() == 2, "shared_expert: expect block_size.size() to be 2.");
-      int64_t block_size_N = block_size_val[0];
-      int64_t block_size_K = block_size_val[1];
-      TORCH_CHECK(w1s.size(0) == 2 * N / block_size_N);
-      TORCH_CHECK(w1s.size(1) == K / block_size_K);
-      TORCH_CHECK(w2s.size(0) == K / block_size_N);
-      TORCH_CHECK(w2s.size(1) == N / block_size_K);
+      CHECK_MOE_SCALES_FP8(0, 1);
       shared_expert_fp8_kernel_impl<scalar_t>(
           out_hidden_states.data_ptr<scalar_t>(),
           intermediate_cache0,
sgl-kernel/csrc/cpu/moe_fp8.cpp
@@ -4,6 +4,76 @@
 
 namespace {
 
+template <typename scalar_t>
+inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
+  using Vec = at::vec::Vectorized<scalar_t>;
+  // no remainder
+#pragma GCC unroll 4
+  for (int64_t d = 0; d < size; d += Vec::size()) {
+    Vec data = Vec::loadu(input + d);
+    data.store(out + d);
+  }
+}
+
+template <typename scalar_t>
+inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  constexpr int kVecSize = bVec::size();
+  const fVec weight_vec = fVec(weight);
+  int64_t d;
+#pragma GCC unroll 4
+  for (d = 0; d <= size - kVecSize; d += kVecSize) {
+    bVec x = bVec::loadu(input + d);
+    fVec x0, x1;
+    std::tie(x0, x1) = at::vec::convert_to_float(x);
+    x0 = x0 * weight_vec;
+    x1 = x1 * weight_vec;
+    bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
+    out_vec.store(out + d);
+  }
+  for (; d < size; ++d) {
+    out[d] = static_cast<scalar_t>(input[d] * weight);
+  }
+}
+
+// acc from [topk, K] to [K]
+template <typename scalar_t>
+inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  constexpr int kVecSize = bVec::size();
+  if (topk == 1) {
+    // do copy for topk = 1
+    copy_stub(out, input, K);
+  } else {
+    // do sum for topk != 1
+    int64_t d;
+#pragma GCC unroll 4
+    for (d = 0; d <= K - kVecSize; d += kVecSize) {
+      fVec sum_fvec0 = fVec(0.f);
+      fVec sum_fvec1 = fVec(0.f);
+      for (int t = 0; t < topk; ++t) {
+        bVec x_bvec = bVec::loadu(input + t * K + d);
+        fVec x_fvec0, x_fvec1;
+        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
+        sum_fvec0 += x_fvec0;
+        sum_fvec1 += x_fvec1;
+      }
+      bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
+      out_bvec.store(out + d);
+    }
+    for (; d < K; ++d) {
+      float sum_val = 0.f;
+      for (int t = 0; t < topk; ++t) {
+        sum_val += static_cast<float>(input[t * K + d]);
+      }
+      out[d] = static_cast<scalar_t>(sum_val);
+    }
+  }
+}
+
 // out = input + input2 * scale
 template <typename scalar_t>
 inline void add_mul_stub(
@@ -65,6 +135,215 @@ inline void silu_and_mul_stub(
 
 } // anonymous namespace
 
+template <typename scalar_t>
+void fused_experts_fp8_kernel_impl(
+    scalar_t* __restrict__ output,
+    scalar_t* __restrict__ ic0,
+    scalar_t* __restrict__ ic1,
+    scalar_t* __restrict__ ic2,
+    scalar_t* __restrict__ A_tmp,
+    const scalar_t* __restrict__ input,
+    const at::Float8_e4m3fn* __restrict__ packed_w1,
+    const at::Float8_e4m3fn* __restrict__ packed_w2,
+    const float* __restrict__ w1s,
+    const float* __restrict__ w2s,
+    int64_t block_size_N,
+    int64_t block_size_K,
+    const float* __restrict__ topk_weights,
+    const int32_t* __restrict__ sorted_ids,
+    const int32_t* __restrict__ expert_ids,
+    const int32_t* __restrict__ offsets,
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t E,
+    int64_t topk,
+    int64_t num_tokens_post_pad) {
+  constexpr int64_t BLOCK_M = block_size_m();
+  constexpr int64_t BLOCK_N = block_size_n();
+
+  // stage 1: intermediate_cache0 = hidden_states @ w1
+  const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
+  const int64_t NB = div_up(2 * N, BLOCK_N);
+
+  int64_t scale_size_N = div_up(2 * N, block_size_N);
+  int64_t scale_size_K = div_up(K, block_size_K);
+  int64_t blocks_n_per_group = block_size_N / BLOCK_N;
+
+  const int64_t stride_e = 2 * N * K;
+  const int64_t stride_n = K;
+
+  // here we only parallel on half of 2N to fuse silu_and_mul with gemm
+  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+    // get local pointers
+    int tid = at::get_thread_num();
+    scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
+
+    alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
+    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
+
+    bool is_brgemm_used = false;
+
+    for (int64_t i = begin; i < end; ++i) {
+      int64_t mb = i / NB;
+      int64_t nb = i % NB;
+
+      int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
+
+      // B shape [K, n_size] in vnni format
+      int32_t expert_id = expert_ids[mb];
+      const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n;
+      const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
+
+      // 1.a load A
+      const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
+      int64_t m_size = offsets[mb + 1] - offsets[mb];
+
+      const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
+      is_brgemm_used = is_brgemm_used || use_brgemm;
+
+      for (int64_t m = 0; m < m_size; ++m) {
+        int32_t index = A_ids[m] / topk;
+        copy_stub(A + m * K, input + index * K, K);
+      }
+
+      const int64_t offset = offsets[mb];
+      tinygemm_kernel<scalar_t>(
+          /* A            */ A,
+          /* B            */ B,
+          /* C            */ ic0 + offset * 2 * N + nb * BLOCK_N,
+          /* Btmp         */ Btmp,
+          /* Ctmp         */ Ctmp,
+          /* scale        */ Bs,
+          /* M            */ m_size,
+          /* N            */ n_size,
+          /* K            */ K,
+          /* lda          */ K,
+          /* ldb          */ n_size,
+          /* ldc          */ 2 * N,
+          /* brg          */ use_brgemm,
+          /* block_size_K */ block_size_K);
+    }
+
+    if (is_brgemm_used) {
+      at::native::cpublas::brgemm_release();
+    }
+  });
+
+  // stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
+  at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
+    for (int64_t m = begin; m < end; ++m) {
+      silu_and_mul_stub(ic1 + m * N, ic0 + m * 2 * N, ic0 + m * 2 * N + N, N);
+    }
+  });
+
+  // stage 2: intermediate_cache2 = intermediate_cache1 @ w2
+  //   w2 : [E, K, N] as [E, OC, IC]
+  const int64_t OC = K;  // rename K as OC
+  const int64_t IC = N;  // rename N as IC
+  const int64_t MB2 = MB;
+  const int64_t NB2 = div_up(OC, BLOCK_N);
+  scale_size_N = div_up(K, block_size_N);
+  scale_size_K = div_up(N, block_size_K);
+  const int64_t stride_e2 = OC * IC;
+  const int64_t stride_oc = IC;
+
+  // parallel on [MB2, NB2]
+  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
+    alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N];
+    alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
+    alignas(64) float Ctmp[BLOCK_M * BLOCK_K];
+
+    bool is_brgemm_used = false;
+
+    for (int64_t i = begin; i < end; ++i) {
+      int64_t mb = i / NB2;
+      int64_t nb = i % NB2;
+
+      int64_t m_size = offsets[mb + 1] - offsets[mb];
+      int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
+
+      const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
+      is_brgemm_used = is_brgemm_used || use_brgemm;
+
+      // A ptr from ic1 of [M * topk, N] in sorted order
+      //   so as to avoid copy A to tmp buffer again
+      const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
+      const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
+
+      // B shape [IC, n_size] in vnni format
+      int32_t expert_id = expert_ids[mb];
+      const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
+      const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
+
+      tinygemm_kernel<scalar_t>(
+          /* A            */ A,
+          /* B            */ B,
+          /* C            */ C,
+          /* Btmp         */ Btmp,
+          /* Ctmp         */ Ctmp,
+          /* scale        */ Bs,
+          /* M            */ m_size,
+          /* N            */ n_size,
+          /* K            */ IC,
+          /* lda          */ IC,
+          /* ldb          */ n_size,
+          /* ldc          */ BLOCK_N,
+          /* brg          */ use_brgemm,
+          /* block_size_K */ block_size_K);
+
+      // 2.b copy from C to ic2 in original order
+      //   and also mul topk_weights in float32
+      for (int64_t m = 0; m < m_size; ++m) {
+        int32_t index = A_ids[m];
+        float weight = topk_weights[index];
+        copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
+      }
+    }
+
+    if (is_brgemm_used) {
+      at::native::cpublas::brgemm_release();
+    }
+  });
+
+  // stage 3: out = intermediate_cache2.sum(dim=1)
+  //   from [M, topk, K] to [M, K]
+  at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
+    for (int64_t m = begin; m < end; ++m) {
+      sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
+    }
+  });
+}
+
+#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE)              \
+  template void fused_experts_fp8_kernel_impl<TYPE>(    \
+      TYPE* __restrict__ output,                        \
+      TYPE* __restrict__ ic0,                           \
+      TYPE* __restrict__ ic1,                           \
+      TYPE* __restrict__ ic2,                           \
+      TYPE* __restrict__ A_tmp,                         \
+      const TYPE* __restrict__ input,                   \
+      const at::Float8_e4m3fn* __restrict__ packed_w1,  \
+      const at::Float8_e4m3fn* __restrict__ packed_w2,  \
+      const float* __restrict__ w1s,                    \
+      const float* __restrict__ w2s,                    \
+      int64_t block_size_N,                             \
+      int64_t block_size_K,                             \
+      const float* __restrict__ topk_weights,           \
+      const int32_t* __restrict__ sorted_ids,           \
+      const int32_t* __restrict__ expert_ids,           \
+      const int32_t* __restrict__ offsets,              \
+      int64_t M,                                        \
+      int64_t N,                                        \
+      int64_t K,                                        \
+      int64_t E,                                        \
+      int64_t topk,                                     \
+      int64_t num_tokens_post_pad)
+
+INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16);
+INSTANTIATE_MOE_FP8_TEMPLATE(at::Half);
+
 template <typename scalar_t>
 void shared_expert_fp8_kernel_impl(
     scalar_t* __restrict__ output,
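Both GEMM stages pick the scale row for a weight tile via w1s/w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K: the scales are laid out per expert as a row-major [scale_size_N, scale_size_K] grid, and every group of blocks_n_per_group GEMM tiles shares one scale row. A small Python sketch of that index math (the sizes are illustrative assumptions; BLOCK_N stands in for the kernel's block_size_n()):

    # Sketch of the block-scale indexing used by both GEMM stages above.
    def div_up(a, b):
        return (a + b - 1) // b

    N, K = 512, 256
    block_size_N, block_size_K = 128, 128    # quantization block sizes (assumed)
    BLOCK_N = 32                             # GEMM tile width (assumed)

    scale_size_N = div_up(2 * N, block_size_N)    # rows of the w1 scale grid
    scale_size_K = div_up(K, block_size_K)        # columns of the scale grid
    blocks_n_per_group = block_size_N // BLOCK_N  # GEMM tiles sharing one scale row

    def scale_row_offset(expert_id, nb):
        # flat offset into w1s for the tile at column-block nb of expert expert_id
        return expert_id * scale_size_N * scale_size_K + (nb // blocks_n_per_group) * scale_size_K

    print(scale_row_offset(expert_id=3, nb=5))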
@@ -100,8 +379,8 @@ void shared_expert_fp8_kernel_impl(
     for (int64_t i = begin; i < end; ++i) {
       int64_t mb = i / NB;
       int64_t nb = i % NB;
-      int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M);
-      int64_t nb_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
+      int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
+      int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
 
       tinygemm_kernel<scalar_t>(
           /* A */ input + mb * BLOCK_M * K,
@@ -110,11 +389,11 @@ void shared_expert_fp8_kernel_impl(
           /* Btmp         */ Btmp,
           /* Ctmp         */ Ctmp,
           /* scale        */ w1s + (nb / blocks_n_per_group) * scale_size_K,
-          /* M            */ mb_size,
-          /* N            */ nb_size,
+          /* M            */ m_size,
+          /* N            */ n_size,
           /* K            */ K,
           /* lda          */ K,
-          /* ldb          */ nb_size,
+          /* ldb          */ n_size,
           /* ldc          */ 2 * N,
           /* brg          */ use_brgemm,
           /* block_size_K */ block_size_K);
@@ -149,8 +428,8 @@ void shared_expert_fp8_kernel_impl(
     for (int64_t i = begin; i < end; ++i) {
      int64_t mb = i / NB2;
      int64_t nb = i % NB2;
-      int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M);
-      int64_t nb_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
+      int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
+      int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
 
       // 2.a gemm: C = A @ B
       tinygemm_kernel<scalar_t>(
@@ -160,11 +439,11 @@ void shared_expert_fp8_kernel_impl(
           /* Btmp         */ Btmp,
           /* Ctmp         */ Ctmp,
           /* scale        */ w2s + (nb / blocks_n_per_group) * scale_size_K,
-          /* M            */ mb_size,
-          /* N            */ nb_size,
+          /* M            */ m_size,
+          /* N            */ n_size,
           /* K            */ IC,
           /* lda          */ IC,
-          /* ldb          */ nb_size,
+          /* ldb          */ n_size,
           /* ldc          */ BLOCK_N,
           /* brg          */ use_brgemm,
           /* block_size_K */ block_size_K);
@@ -172,8 +451,8 @@ void shared_expert_fp8_kernel_impl(
       // 2.b copy from C to output and add fused_experts_out
       scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
       const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
-      for (int64_t m = 0; m < mb_size; ++m) {
-        add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, nb_size);
+      for (int64_t m = 0; m < m_size; ++m) {
+        add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
       }
     }
   });
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -130,8 +130,10 @@ at::Tensor fused_experts_cpu(
     at::Tensor& topk_ids,
     bool inplace,
     bool use_int8_w8a8,
+    bool use_fp8_w8a16,
     const std::optional<at::Tensor>& w1_scale,
     const std::optional<at::Tensor>& w2_scale,
+    const std::optional<std::vector<int64_t>> block_size,
     const std::optional<at::Tensor>& a1_scale,
     const std::optional<at::Tensor>& a2_scale,
     bool is_vnni);
@@ -260,7 +262,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   // moe
   m.def(
       "fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
-      "inplace, bool use_int8_w8a8, Tensor? w1_scale, Tensor? w2_scale, Tensor? a1_scale, Tensor? a2_scale, bool "
+      "inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, int[]? block_size, Tensor? "
+      "a1_scale, Tensor? a2_scale, bool "
       "is_vnni) -> Tensor");
   m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
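Because the two new arguments are inserted in the middle of the schema, existing positional callers have to be updated too; the bf16/int8 paths now pass use_fp8_w8a16=False and block_size=None, as the updated test below does. A hedged sketch of a non-fp8 call under the new schema (the tensors a, packed_w1, packed_w2, topk_weights, topk_ids are assumed to be prepared as in test/srt/cpu/test_moe.py):

    out = torch.ops.sgl_kernel.fused_experts_cpu(
        a, packed_w1, packed_w2, topk_weights, topk_ids,
        True,    # inplace
        False,   # use_int8_w8a8
        False,   # use_fp8_w8a16  (new argument)
        None,    # w1_scale
        None,    # w2_scale
        None,    # block_size     (new argument)
        None,    # a1_scale
        None,    # a2_scale
        True,    # is_vnni
    )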
sgl-kernel/setup_cpu.py (deleted, 100644 → 0)

# Copyright 2025 SGLang Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import platform
import shutil
import sys
from pathlib import Path

import torch
from setuptools import find_packages, setup
from setuptools.command.build_py import build_py
from torch.utils.cpp_extension import BuildExtension, CppExtension

root = Path(__file__).parent.resolve()
arch = platform.machine().lower()
if arch in ("x86_64", "amd64"):
    plat_name = "manylinux2014_x86_64"
elif arch in ("aarch64", "arm64"):
    plat_name = "manylinux2014_aarch64"
elif arch.startswith("ppc"):
    plat_name = "manylinux2014_ppc64le"
else:
    plat_name = f"manylinux2014_{arch}"

if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv:
    sys.argv.extend(["--plat-name", plat_name])


def _get_version():
    with open(root / "pyproject.toml") as f:
        for line in f:
            if line.startswith("version"):
                return line.split("=")[1].strip().strip('"')


cpu_fp8_ftz = os.getenv("SGLANG_CPU_FP8_CVT_FTZ", "1") == "1"

operator_namespace = "sgl_kernel"
include_dirs = [
    "../../include",
]

sources = [
    "csrc/cpu/activation.cpp",
    "csrc/cpu/bmm.cpp",
    "csrc/cpu/decode.cpp",
    "csrc/cpu/extend.cpp",
    "csrc/cpu/gemm.cpp",
    "csrc/cpu/gemm_fp8.cpp",
    "csrc/cpu/gemm_int8.cpp",
    "csrc/cpu/moe.cpp",
    "csrc/cpu/moe_fp8.cpp",
    "csrc/cpu/moe_int8.cpp",
    "csrc/cpu/norm.cpp",
    "csrc/cpu/qkv_proj.cpp",
    "csrc/cpu/topk.cpp",
    "csrc/cpu/interface.cpp",
    "csrc/cpu/shm.cpp",
    "csrc/cpu/rope.cpp",
    "csrc/cpu/torch_extension_cpu.cpp",
]

extra_compile_args = {
    "cxx": [
        "-O3",
        "-Wno-unknown-pragmas",
        "-march=native",
        "-fopenmp",
    ]
}
if cpu_fp8_ftz:
    extra_compile_args["cxx"].append("-DSGLANG_CPU_FP8_CVT_FTZ")

libraries = ["c10", "torch", "torch_python"]

cmdclass = {
    "build_ext": BuildExtension.with_options(use_ninja=True),
}
Extension = CppExtension

extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]

ext_modules = [
    Extension(
        name="sgl_kernel.common_ops",
        sources=sources,
        include_dirs=include_dirs,
        extra_compile_args=extra_compile_args,
        libraries=libraries,
        extra_link_args=extra_link_args,
        py_limited_api=False,
    ),
]

setup(
    name="sgl-kernel",
    version=_get_version(),
    packages=find_packages(where="python"),
    package_dir={"": "python"},
    ext_modules=ext_modules,
    cmdclass=cmdclass,
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
test/srt/cpu/test_moe.py (new file, 0 → 100644)

import itertools
import math
import unittest

# TODO: use interface in cpu.py
import sgl_kernel
import torch

kernel = torch.ops.sgl_kernel
from utils import (
    BLOCK_K,
    BLOCK_N,
    factor_for_scale,
    fp8_max,
    fp8_min,
    native_fp8_fused_moe,
    precision,
    scaled_weight,
    torch_naive_fused_moe,
    torch_w8a8_per_column_fused_moe,
)

from sglang.test.test_utils import CustomTestCase


def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
    G = 1
    topk_group = 1
    B, D = a.shape
    topk_weights = torch.empty(B, topk, dtype=torch.float32)
    topk_ids = torch.empty(B, topk, dtype=torch.int32)
    topk_weights, topk_ids = kernel.grouped_topk_cpu(
        a, score, topk, renormalize, G, topk_group
    )

    packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
    packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2

    inplace = True
    return kernel.fused_experts_cpu(
        a,
        packed_w1,
        packed_w2,
        topk_weights,
        topk_ids,
        inplace,
        False,
        False,
        None,
        None,
        None,
        None,
        None,
        prepack,
    )


class TestFusedExperts(CustomTestCase):
    M = [2, 114]
    N = [32]
    K = [32]
    E = [4]
    topk = [2]
    renormalize = [False, True]

    M_int8 = [1, 39]
    N_int8 = [128]
    K_int8 = [256]
    E_int8 = [8]
    topk_int8 = [3]

    M_fp8 = [2, 121]
    N_fp8 = [512]
    K_fp8 = [256]
    E_fp8 = [8]
    topk_fp8 = [4]

    def _bf16_moe(self, m, n, k, e, topk, renormalize):
        dtype = torch.bfloat16
        prepack = True
        a = torch.randn((m, k), device="cpu", dtype=dtype) / 10
        w1 = torch.randn((e, 2 * n, k), device="cpu", dtype=dtype) / 10
        w2 = torch.randn((e, k, n), device="cpu", dtype=dtype) / 10
        score = torch.randn((m, e), device="cpu", dtype=dtype)

        torch_output = torch_naive_fused_moe(a, w1, w2, score, topk, renormalize)
        fused_output = fused_moe(a, w1, w2, score, topk, renormalize, prepack)
        atol = rtol = precision[torch_output.dtype]
        self.assertTrue(
            torch.allclose(torch_output, fused_output, atol=atol, rtol=rtol)
        )

    def test_bf16_moe(self):
        for params in itertools.product(
            self.M,
            self.N,
            self.K,
            self.E,
            self.topk,
            self.renormalize,
        ):
            with self.subTest(
                m=params[0],
                n=params[1],
                k=params[2],
                e=params[3],
                topk=params[4],
                renormalize=params[5],
            ):
                self._bf16_moe(*params)

    def _int8_moe(self, M, N, K, E, topk):
        dtype = torch.bfloat16
        prepack = True

        # Initialize int8 quantization parameters
        int8_factor_for_scale = 1e-2
        int8_max = 127
        int8_min = -128

        # Input tensor
        # M * K
        a = torch.randn((M, K), dtype=dtype) / math.sqrt(K)

        # Generate int8 weights
        w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2
        w1 = (w1_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2
        w2 = (w2_fp32 * int8_max).clamp(min=int8_min, max=int8_max).to(torch.int8)

        # Generate scale for each column (per-column quantization)
        w1_s = torch.rand(E, 2 * N, device=w1_fp32.device) * int8_factor_for_scale
        w2_s = torch.rand(E, K, device=w2_fp32.device) * int8_factor_for_scale

        # Calculate routing
        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        ref_out = torch_w8a8_per_column_fused_moe(
            a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk
        )

        inplace = True
        packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
        packed_w2 = kernel.convert_weight_packed(w2) if prepack else w2
        out = kernel.fused_experts_cpu(
            a,
            packed_w1,
            packed_w2,
            topk_weight,
            topk_ids.to(torch.int32),
            inplace,
            True,
            False,
            w1_s,
            w2_s,
            None,
            None,
            None,
            prepack,
        )

        atol = rtol = precision[ref_out.dtype]
        # Increase the tolerance for large input shapes
        if M > 35:
            atol = rtol = 0.02
        self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))

    def test_int8_moe(self):
        for params in itertools.product(
            self.M_int8,
            self.N_int8,
            self.K_int8,
            self.E_int8,
            self.topk_int8,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
            ):
                self._int8_moe(*params)

    def _fp8_moe(self, M, N, K, E, topk):
        dtype = torch.bfloat16

        a = torch.randn(M, K, dtype=dtype) / math.sqrt(K)

        w1_fp32 = torch.randn(E, 2 * N, K)
        w1 = (w1_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w2_fp32 = torch.randn(E, K, N)
        w2 = (w2_fp32 * fp8_max).clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        w1s = torch.randn(E, 2 * N // BLOCK_N, K // BLOCK_K) * factor_for_scale
        w2s = torch.randn(E, K // BLOCK_N, N // BLOCK_K) * factor_for_scale

        w1_scaled = scaled_weight(w1, w1s)
        w2_scaled = scaled_weight(w2, w2s)

        score = torch.randn((M, E), dtype=dtype)
        score = torch.softmax(score, dim=-1, dtype=torch.float32)
        topk_weight, topk_ids = torch.topk(score, topk)

        w1 = kernel.convert_weight_packed(w1)
        w2 = kernel.convert_weight_packed(w2)

        ref_out = native_fp8_fused_moe(a, w1_scaled, w2_scaled, topk_weight, topk_ids, topk)

        out = kernel.fused_experts_cpu(
            a,
            w1,
            w2,
            topk_weight,
            topk_ids.to(torch.int32),
            False,
            False,
            True,
            w1s,
            w2s,
            [BLOCK_N, BLOCK_K],
            None,
            None,
            True,
        )

        atol = rtol = precision[dtype]
        self.assertTrue(torch.allclose(ref_out.bfloat16(), out, atol=atol, rtol=rtol))

    def test_fp8_moe(self):
        for params in itertools.product(
            self.M_fp8,
            self.N_fp8,
            self.K_fp8,
            self.E_fp8,
            self.topk_fp8,
        ):
            with self.subTest(
                M=params[0],
                N=params[1],
                K=params[2],
                E=params[3],
                topk=params[4],
            ):
                self._fp8_moe(*params)


if __name__ == "__main__":
    unittest.main()
test/srt/cpu/utils.py

@@ -148,3 +148,99 @@ def scaled_weight(weight, scales):
         .contiguous()
         .view(E, N, K)
     )
+
+
+def torch_naive_fused_moe(a, w1, w2, score, topk, renormalize):
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
+    score = torch.softmax(score, dim=-1, dtype=torch.float32)
+    topk_weight, topk_ids = torch.topk(score, topk)
+    if renormalize:
+        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
+    topk_weight = topk_weight.view(-1)
+    topk_ids = topk_ids.view(-1)
+    for i in range(w1.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            out[mask] = SiluAndMul(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
+    return (
+        out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
+    ).sum(dim=1)
+
+
+def torch_w8a8_per_column_fused_moe(a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, topk):
+    """This function performs fused moe with per-column int8 quantization using native torch."""
+    B, D = a.shape
+    # Perform per-token quantization
+    a_q, a_s = per_token_quant_int8(a)
+    # Repeat tokens to match topk
+    a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
+    # Also repeat the scale
+    a_s = a_s.view(B, -1, 1).repeat(1, topk, 1).reshape(-1, 1)  # [B*topk, 1]
+    out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)
+    # Calculate routing
+    topk_weight = topk_weight.view(-1)
+    topk_ids = topk_ids.view(-1)
+    # Process each expert
+    for i in range(w1.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            # First MLP layer: note that a_s is now per-token
+            inter_out = native_w8a8_per_token_matmul(
+                a_q[mask],
+                w1[i],
+                a_s[mask],
+                w1_s[i],
+                bias=None,
+                output_dtype=torch.float32,
+            )
+            # Activation function
+            act_out = SiluAndMul(inter_out)
+            # Quantize activation output with per-token
+            act_out_q, act_out_s = per_token_quant_int8(act_out)
+            # Second MLP layer
+            out[mask] = native_w8a8_per_token_matmul(
+                act_out_q,
+                w2[i],
+                act_out_s,
+                w2_s[i],
+                bias=None,
+                output_dtype=torch.float32,
+            )
+    # Apply routing weights and sum
+    return (
+        (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
+        .sum(dim=1)
+        .to(a.dtype)
+    )
+
+
+def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
+    B, D = a.shape
+    a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D).float()
+    out = torch.zeros(B * topk, w2.shape[1], dtype=torch.float32, device=a.device)
+    # Calculate routing
+    topk_weight = topk_weight.view(-1)
+    topk_ids = topk_ids.view(-1)
+    for i in range(w1.shape[0]):
+        mask = topk_ids == i
+        if mask.sum():
+            ic0 = torch.matmul(a[mask], w1[i].transpose(0, 1))
+            ic1 = SiluAndMul(ic0)
+            out[mask] = torch.matmul(ic1, w2[i].transpose(0, 1))
+    return (
+        (out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype))
+        .sum(dim=1)
+        .to(a.dtype)
+    )