Commit 5ad296bd (unverified)
Optimize prefill performance on cpu backend (#8750)
Authored Aug 29, 2025 by Ma Mingfei; committed by GitHub, Aug 28, 2025
Parent: 9f81d741

Showing 9 changed files with 681 additions and 274 deletions (+681, -274)
Changed files:
  sgl-kernel/csrc/cpu/common.h       +118  -1
  sgl-kernel/csrc/cpu/gemm.cpp       +30   -14
  sgl-kernel/csrc/cpu/gemm.h         +4    -3
  sgl-kernel/csrc/cpu/gemm_fp8.cpp   +34   -31
  sgl-kernel/csrc/cpu/gemm_int8.cpp  +70   -12
  sgl-kernel/csrc/cpu/moe.cpp        +39   -69
  sgl-kernel/csrc/cpu/moe_fp8.cpp    +51   -46
  sgl-kernel/csrc/cpu/moe_int8.cpp   +334  -96
  sgl-kernel/csrc/cpu/qkv_proj.cpp   +1    -2
sgl-kernel/csrc/cpu/common.h

@@ -105,7 +105,19 @@ namespace {
 #define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

-// parallel routines
+// [NB] Parallel Routines
+//
+// * at::parallel_for - applies for most of generic use cases, this will be compiled
+//     against openmp in default torch release.
+//
+// * parallel_for - same function as above, can choose payload partition scheme in
+//     balance211.
+//
+// * parallel_2d - parallel for 2 dimensions, used in GEMM, etc.
+//     this one will do payload balance across 2 dimensions.
+//

 // grain size for each thread
 constexpr int GRAIN_SIZE = 1024;

 template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>

@@ -113,6 +125,17 @@ inline T div_up(T x, T y) {
   return (x + y - 1) / y;
 }

+// you can only use at::get_thread_num() with at::parallel_for()
+// as it is lazy initialized, otherwise it will always return 0.
+inline int get_thread_num() {
+#if defined(_OPENMP)
+  return omp_get_thread_num();
+#else
+  return 0;
+#endif
+}
+
 // balance payload across each thread
 template <typename T>
 inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
 #if 0

@@ -153,6 +176,100 @@ inline void parallel_for(int n, const func_t& f) {
 #endif
 }

+// for 1d parallel, use `actual_nth`
+// for 2d parallel, use even nths, e.g. 43->42
+int inline adjust_num_threads(int m) {
+  int actual_nth = at::get_num_threads();
+  if (m == 1) {
+    return actual_nth;
+  }
+  return std::max(1, (actual_nth >> 1) * 2);
+}
+
+template <typename func_t>
+inline void parallel_2d(int m, int n, const func_t& f) {
+  // make sure we have even num_threads
+  int nth = adjust_num_threads(m);
+
+  // [NOTE] thread blocking:
+  //
+  //   1) prefer square block per thread
+  //   2) use even number of CPU cores
+  //   3) use all `num_threads` cores
+  //
+  //   we have:
+  //     TM * TN = T
+  //     BM / TM = BN / TN
+  //   then:
+  //     TM = ((BM / BN) * T) ^ 0.5
+  //
+  float r = float(m) / n;
+  int nth_m = std::ceil(std::sqrt(r * nth));
+  int nth_n = 1;
+  for (; nth_m > 0; --nth_m) {
+    nth_n = nth / nth_m;
+    if (nth_m * nth_n == nth) {
+      break;
+    }
+  }
+
+#if defined(_OPENMP)
+#pragma omp parallel num_threads(nth)
+  {
+    int ith = omp_get_thread_num();
+    int ith_m = ith / nth_n;
+    int ith_n = ith % nth_n;
+
+    int thread_block_m = div_up(m, nth_m);
+    int thread_block_n = div_up(n, nth_n);
+
+    int begin_m = ith_m * thread_block_m;
+    int end_m = std::min(m, begin_m + thread_block_m);
+    int begin_n = ith_n * thread_block_n;
+    int end_n = std::min(n, begin_n + thread_block_n);
+
+    f(begin_m, end_m, begin_n, end_n);
+  }
+#else
+  f(0, m, 0, n);
+#endif
+}
+
+// limit max cache blocks
+// when we need to do pre-unpack for weights, e.g. fp8
+#define MAX_CACHE_BLOCK_SIZE 4
+
+template <typename T>
+inline int get_cache_blocks(int chunk_size) {
+  // L2 2MB and ratio of 50%
+  const int L2_size = 2048 * 1024 >> 1;
+  return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
+}
+
+template <>
+inline int get_cache_blocks<at::Float8_e4m3fn>(int chunk_size) {
+  // fp8 uses bf16 as accumulate type
+  int cache_block_size = get_cache_blocks<at::BFloat16>(chunk_size);
+  return std::min(MAX_CACHE_BLOCK_SIZE, cache_block_size);
+}
+
+// 2d sequential loop in range : [mb0, mb1), [nb0, nb1)
+template <typename T, typename func_t>
+inline void loop_2d(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1, int64_t chunk_size, const func_t& f) {
+  // get number of blocks for L2 in most inner loop
+  int64_t cache_blocks_nb = get_cache_blocks<T>(chunk_size);
+
+  // loop order: [NB / cache_blocks_nb, MB, cache_blocks_nb]
+  // TODO: implement reverse order of [MB / cache_blocks_mb, NB, cache_blocks_mb]
+  for (int64_t nbb = nb0; nbb < nb1; nbb += cache_blocks_nb) {
+    for (int64_t mb = mb0; mb < mb1; ++mb) {
+      for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, nb1); ++nb) {
+        f(mb, nb, nb - nbb);
+      }
+    }
+  }
+}
+
 // data indexing for dimension collapse
 template <typename T>
 inline T data_index_init(T offset) {
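The [NOTE] thread blocking comment above derives the thread grid from TM * TN = T and BM / TM = BN / TN, giving TM = ((BM / BN) * T) ^ 0.5, and then searches downward for an exact factorization of the (even) thread count. A minimal standalone sketch of that selection, not part of the commit, using a hypothetical 8 x 4 block grid on 8 threads:

// Standalone sketch (not from the commit) of the parallel_2d thread-grid selection:
// pick nth_m near sqrt((m / n) * nth) and shrink it until it divides the even
// thread count exactly, so every thread gets one rectangular tile of blocks.
#include <algorithm>
#include <cmath>
#include <cstdio>

static int adjust_num_threads_demo(int m, int actual_nth) {
  if (m == 1) return actual_nth;
  return std::max(1, (actual_nth >> 1) * 2);  // round down to an even count
}

int main() {
  const int m = 8, n = 4;                 // hypothetical block grid: 8 x 4
  const int nth = adjust_num_threads_demo(m, /*actual_nth=*/8);

  float r = float(m) / n;                 // TM = sqrt((BM / BN) * T)
  int nth_m = std::ceil(std::sqrt(r * nth));
  int nth_n = 1;
  for (; nth_m > 0; --nth_m) {
    nth_n = nth / nth_m;
    if (nth_m * nth_n == nth) break;      // require an exact factorization
  }
  // With m=8, n=4, nth=8: r=2, sqrt(16)=4 -> nth_m=4, nth_n=2.
  std::printf("thread grid: %d x %d\n", nth_m, nth_n);
  return 0;
}

With those numbers the search settles on a 4 x 2 thread grid, so each thread owns a 2 x 2 tile of blocks, matching the "prefer square block per thread" goal stated in the comment.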
sgl-kernel/csrc/cpu/gemm.cpp

@@ -254,7 +254,7 @@ void tinygemm_kernel(
     return;
   }

-  // pattern: 1-4-16
+  // pattern: 1-4-16, N = 16, 32, 48, 64
   constexpr int64_t BLOCK_M = 4;
   constexpr int64_t BLOCK_N = 64;
   const int64_t MB = div_up(M, BLOCK_M);

@@ -268,35 +268,59 @@ void tinygemm_kernel(
   switch (mb_size << 4 | nb_size >> 4) {
     // mb_size = 1
     case 0x11: LAUNCH_TINYGEMM_KERNEL_NN(1, 16); break;
     case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
     case 0x13: LAUNCH_TINYGEMM_KERNEL_NN(1, 48); break;
     case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
     // mb_size = 2
     case 0x21: LAUNCH_TINYGEMM_KERNEL_NN(2, 16); break;
     case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
     case 0x23: LAUNCH_TINYGEMM_KERNEL_NN(2, 48); break;
     case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
     // mb_size = 3
     case 0x31: LAUNCH_TINYGEMM_KERNEL_NN(3, 16); break;
     case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
     case 0x33: LAUNCH_TINYGEMM_KERNEL_NN(3, 48); break;
     case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
     // mb_size = 4
     case 0x41: LAUNCH_TINYGEMM_KERNEL_NN(4, 16); break;
     case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
     case 0x43: LAUNCH_TINYGEMM_KERNEL_NN(4, 48); break;
     case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
     default:
-      TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
+      TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
   }
 }

@@ -318,20 +342,15 @@ void weight_packed_linear_kernel_impl(
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);

-  // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx c) N is small
-  const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>) || (N < 64);
+  const bool use_brgemm = can_use_brgemm<scalar_t>(M);

   // parallel on [MB, NB]
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
-    at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-      int64_t mb{0}, nb{0};
-      data_index_init(begin, mb, MB, nb, NB);
+    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
       // for brgemm, use float32 for accumulate
       alignas(64) float Ctmp[BLOCK_M * BLOCK_N];

-      for (int64_t i = begin; i < end; ++i) {
-        UNUSED(i);
+      loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
         int64_t mb_start = mb * BLOCK_M;
         int64_t mb_size = std::min(M - mb_start, BLOCK_M);
         int64_t nb_start = nb * BLOCK_N;

@@ -350,10 +369,7 @@ void weight_packed_linear_kernel_impl(
             /* ldb */ nb_size,
             /* ldc */ out_strideM,
             /* brg */ use_brgemm);
-
-        // move to the next index
-        data_index_step(mb, MB, nb, NB);
-      }
+      });

       if (use_brgemm) {
         at::native::cpublas::brgemm_release();
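For reference, the switch key in tinygemm_kernel packs the residual block sizes into one byte: mb_size (1 to 4) goes in the high nibble and nb_size / 16 (N = 16, 32, 48, 64) in the low nibble, which is why the cases read 0x11 through 0x44. A tiny standalone sketch, not part of the commit, that prints the mapping:

// Sketch (not from the commit): how (mb_size << 4 | nb_size >> 4) forms the case label.
#include <cstdint>
#include <cstdio>

int main() {
  for (int64_t mb_size = 1; mb_size <= 4; ++mb_size) {
    for (int64_t nb_size = 16; nb_size <= 64; nb_size += 16) {
      int key = int(mb_size << 4 | nb_size >> 4);  // high nibble: mb, low nibble: nb / 16
      std::printf("mb_size=%lld nb_size=%lld -> case 0x%02X\n",
                  (long long)mb_size, (long long)nb_size, key);
    }
  }
  return 0;
}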
sgl-kernel/csrc/cpu/gemm.h

@@ -27,10 +27,10 @@ template <>
 inline bool can_use_brgemm<at::Half>(int M) {
   return true;
 }

-// TODO: add u8s8 brgemm, this requires PyTorch 2.7
+// this requires PyTorch 2.7 or above
 template <>
 inline bool can_use_brgemm<int8_t>(int M) {
-  return false;
+  return M > 4;
 }

 template <>

@@ -198,4 +198,5 @@ void tinygemm_kernel(
     int64_t ldb,
     int64_t ldc,
     bool brg,
-    int64_t block_size_K);
+    int64_t block_size_K,
+    bool do_unpack = true);
sgl-kernel/csrc/cpu/gemm_fp8.cpp

@@ -2,9 +2,6 @@
 #include "gemm.h"
 #include "vec.h"

-// we use 4x32 for BLOCK_M
-#define BLOCK_SIZE_M_SCALE 4
-
 namespace {

 template <typename scalar_t>

@@ -250,7 +247,8 @@ struct brgemm {
       int K,
       int lda,
       int ldb,
-      int ldc) {
+      int ldc,
+      bool do_unpack = true) {
     TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
   }
 };

@@ -270,18 +268,21 @@ struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
       int K,
       int lda,
       int ldb,
-      int ldc) {
+      int ldc,
+      bool do_unpack = true) {
     constexpr int BLOCK_N = block_size_n();

     // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
     const int ldb_tmp = BLOCK_N;

+    if (do_unpack) {
       for (int k = 0; k < K; k += BLOCK_K) {
         int kb_size = std::min(BLOCK_K, K - k);
         int idx = k >> 7;  // k / BLOCK_K where BLOCK_K = 128
         unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
       }
+    }

     at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);

@@ -312,9 +313,11 @@ void tinygemm_kernel(
     int64_t ldb,
     int64_t ldc,
     bool brg,
-    int64_t block_size_K) {
+    int64_t block_size_K,
+    bool do_unpack = true) {
   if (brg) {
-    brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
+    brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
+        A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack);
     return;
   }

@@ -366,7 +369,7 @@ void fp8_scaled_mm_kernel_impl(
     int64_t block_size_N,
     int64_t block_size_K,
     int64_t buffer_size_per_thread) {
-  constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
+  constexpr int64_t BLOCK_M = block_size_m();
   constexpr int64_t BLOCK_N = block_size_n();
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);

@@ -378,16 +381,12 @@ void fp8_scaled_mm_kernel_impl(
   // parallel on [MB, NB]
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
-    at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-      int64_t mb{0}, nb{0};
-      data_index_init(begin, mb, MB, nb, NB);
-      int tid = at::get_thread_num();
+    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
+      int tid = get_thread_num();
       scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
-      float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));
+      float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K));

-      for (int64_t i = begin; i < end; ++i) {
-        UNUSED(i);
+      loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
         const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
         int64_t mb_start = mb * BLOCK_M;

@@ -395,11 +394,14 @@ void fp8_scaled_mm_kernel_impl(
         int64_t nb_start = nb * BLOCK_N;
         int64_t nb_size = std::min(N - nb_start, BLOCK_N);

+        // only do unpacking for the first row
+        bool do_unpack = (mb == mb0);
+
         tinygemm_kernel<scalar_t, has_bias>(
             /* A     */ mat1 + mb_start * mat1_strideM,
             /* B     */ mat2 + nb_start * K,  // nb * BLOCK_N * K
             /* C     */ out + mb_start * out_strideM + nb_start,
-            /* Btmp  */ Btmp,
+            /* Btmp  */ Btmp + nb_offset * BLOCK_N * K,
             /* Ctmp  */ Ctmp,
             /* scale */ scale_ptr,
             /* bias  */ bias + nb_start,

@@ -410,11 +412,9 @@ void fp8_scaled_mm_kernel_impl(
             /* ldb   */ nb_size,
             /* ldc   */ out_strideM,
             /* brg   */ use_brgemm,
-            /* block_size_K */ block_size_K);
-
-        // move to the next index
-        data_index_step(mb, MB, nb, NB);
-      }
+            /* block_size_K */ block_size_K,
+            /* do_unpack */ do_unpack);
+      });
     });

     if (use_brgemm) {
       at::native::cpublas::brgemm_release();

@@ -441,8 +441,10 @@ void tinygemm_kernel(
     int64_t ldb,
     int64_t ldc,
     bool brg,
-    int64_t block_size_K) {
-  tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
+    int64_t block_size_K,
+    bool do_unpack) {
+  tinygemm_kernel<scalar_t, false>(
+      A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack);
 }

 #define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \

@@ -460,7 +462,8 @@ void tinygemm_kernel(
       int64_t ldb,           \
       int64_t ldc,           \
       bool brg,              \
-      int64_t block_size_K)
+      int64_t block_size_K,  \
+      bool do_unpack)

 INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
 INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);

@@ -495,7 +498,7 @@ at::Tensor fp8_scaled_mm_cpu(
   int64_t block_size_N = block_size[0];
   int64_t block_size_K = block_size[1];

-  constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
+  constexpr int64_t BLOCK_M = block_size_m();
   constexpr int64_t BLOCK_N = block_size_n();
   TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
   TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");

@@ -523,7 +526,7 @@ at::Tensor fp8_scaled_mm_cpu(
   //   Btmp : [T, BLOCK_N * K]
   //   Ctmp : [T, BLOCK_M * BLOCK_N]
   int num_threads = at::get_num_threads();
-  int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
+  int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
   auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());

   AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
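The do_unpack flag threaded through this file works together with loop_2d and MAX_CACHE_BLOCK_SIZE from common.h: the per-thread Btmp buffer now holds up to MAX_CACHE_BLOCK_SIZE dequantized fp8 weight tiles, and because loop_2d revisits the same small group of n-blocks for every m-block, only the first m-block (mb == mb0) has to unpack; later m-blocks reuse the cached tile at Btmp + nb_offset * BLOCK_N * K. A simplified sketch of that reuse pattern, not part of the commit, with hypothetical unpack_tile/compute_tile stubs standing in for the real unpack_B and GEMM calls:

// Sketch (not from the commit) of the weight-tile reuse enabled by loop_2d's
// [NB / cache_blocks, MB, cache_blocks] ordering: each n-block is dequantized once
// per cache group and then reused for every m-block.
#include <algorithm>
#include <cstdint>

constexpr int64_t MAX_CACHE_BLOCK_SIZE = 4;  // mirrors the new #define in common.h

// hypothetical stand-ins for the real unpack_B / tinygemm calls
inline void unpack_tile(float* /*dst*/, int64_t /*nb*/) {}
inline void compute_tile(const float* /*w*/, int64_t /*mb*/, int64_t /*nb*/) {}

void run(int64_t MB, int64_t NB, int64_t tile_size, float* Btmp) {
  for (int64_t nbb = 0; nbb < NB; nbb += MAX_CACHE_BLOCK_SIZE) {
    for (int64_t mb = 0; mb < MB; ++mb) {
      int64_t nb_end = std::min(nbb + MAX_CACHE_BLOCK_SIZE, NB);
      for (int64_t nb = nbb; nb < nb_end; ++nb) {
        int64_t nb_offset = nb - nbb;
        bool do_unpack = (mb == 0);          // only the first m-block dequantizes
        float* cached = Btmp + nb_offset * tile_size;
        if (do_unpack) unpack_tile(cached, nb);
        compute_tile(cached, mb, nb);        // later m-blocks reuse the cached tile
      }
    }
  }
}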
sgl-kernel/csrc/cpu/gemm_int8.cpp

@@ -4,6 +4,61 @@
 namespace {

+template <typename scalar_t, bool has_bias, int BLOCK_N>
+struct scale_C {
+  static inline void apply(
+      scalar_t* __restrict__ C,
+      const int32_t* __restrict__ Ctmp,
+      const int32_t* __restrict__ Bcomp,
+      const float* __restrict__ bias,
+      float As,
+      const float* __restrict__ Bs) {
+    TORCH_CHECK(false, "scale_C: scalar path not implemented!");
+  }
+};
+
+#if defined(CPU_CAPABILITY_AVX512)
+template <bool has_bias, int BLOCK_N>
+struct scale_C<at::BFloat16, has_bias, BLOCK_N> {
+  static inline void apply(
+      at::BFloat16* __restrict__ C,
+      const int32_t* __restrict__ Ctmp,
+      const int32_t* __restrict__ Bcomp,
+      const float* __restrict__ bias,
+      float As,
+      const float* __restrict__ Bs) {
+    constexpr int COLS = BLOCK_N / 16;
+    static_assert(COLS % 2 == 0);
+
+    __m512 vc[COLS];
+    __m512 vd0 = _mm512_set1_ps(As);
+
+    auto compute = [&](auto col) {
+      __m512 vd1 = _mm512_loadu_ps(Bs + col * 16);
+      __m512i vcomp = _mm512_loadu_si512(Bcomp + col * 16);
+      __m512i vc32 = _mm512_loadu_si512(Ctmp + col * 16);
+      vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp));
+      if constexpr (has_bias) {
+        __m512 vbias = _mm512_loadu_ps(bias + col * 16);
+        vc[col] = _mm512_fmadd_ps(_mm512_mul_ps(vc[col], vd0), vd1, vbias);
+      } else {
+        vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vd0), vd1);
+      }
+    };
+    Unroll<COLS>{}(compute);
+
+    auto storec = [&](auto col) {
+      // for COLS = 2, 4 use 512bit store
+      if constexpr (col % 2 == 0) {
+        _mm512_storeu_si512(
+            reinterpret_cast<__m512i*>((C + col * 16)),
+            (__m512i)(_mm512_cvtne2ps_pbh(vc[col + 1], vc[col + 0])));
+      }
+    };
+    Unroll<COLS>{}(storec);
+  }
+};
+#endif
+
 template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
 struct tinygemm_kernel_nn {
   static inline void apply(

@@ -169,6 +224,17 @@ void tinygemm_kernel(
   // B compensation
   const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);

+  if (brg) {
+    constexpr int BLOCK_N = block_size_n();
+    at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
+
+    // apply compensation and scale
+    for (int64_t m = 0; m < M; ++m) {
+      scale_C<scalar_t, has_bias, BLOCK_N>::apply(C + m * ldc, Ctmp + m * BLOCK_N, Bcomp, bias, As[m], Bs);
+    }
+    return;
+  }
+
   // pattern: 1-4-16
   constexpr int64_t BLOCK_M = 4;
   constexpr int64_t BLOCK_N = 64;

@@ -233,22 +299,17 @@ void int8_scaled_mm_kernel_impl(
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);

-  // TODO: brgemm u8s8 depends on PyTorch 2.7 release.
-  const bool use_brgemm = false;
+  const bool use_brgemm = can_use_brgemm<int8_t>(M);

   // K + 4 after compensation
   const int64_t packed_row_size = get_row_size<int8_t>(K);

   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
-    at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-      int64_t mb{0}, nb{0};
-      data_index_init(begin, mb, MB, nb, NB);
+    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
       // for brgemm, use int32_t for accumulate
       alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];

-      for (int i = begin; i < end; ++i) {
-        UNUSED(i);
+      loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
         int mb_start = mb * BLOCK_M;
         int mb_size = std::min(M - mb_start, BLOCK_M);
         int nb_start = nb * BLOCK_N;

@@ -269,10 +330,7 @@ void int8_scaled_mm_kernel_impl(
             /* ldb */ nb_size,
             /* ldc */ N,
             /* brg */ use_brgemm);
-
-        // move to the next index
-        data_index_step(mb, MB, nb, NB);
-      }
+      });
     });

     if (use_brgemm) {
       at::native::cpublas::brgemm_release();
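As a plain-scalar reference for the new AVX-512 scale_C epilogue above (a sketch, not code from the commit): the u8s8 product is accumulated in int32, the per-column compensation row stored after the packed weights is subtracted, and the result is rescaled by the per-row activation scale As and per-column weight scales Bs, with an optional bias, following the C = As * (C - Bcomp) * Bs formula noted in the comments:

// Scalar reference (not from the commit) for the int8 dequant epilogue.
#include <cstdint>

template <bool has_bias>
void scale_C_ref(float* C, const int32_t* Ctmp, const int32_t* Bcomp,
                 const float* bias, float As, const float* Bs, int BLOCK_N) {
  for (int n = 0; n < BLOCK_N; ++n) {
    // C = As * (C - Bcomp) * Bs, optionally + bias
    float v = float(Ctmp[n] - Bcomp[n]) * As * Bs[n];
    if constexpr (has_bias) {
      C[n] = v + bias[n];
    } else {
      C[n] = v;
    }
  }
}

The vectorized version does the same work 16 columns per register and converts the result to bf16 on store.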
sgl-kernel/csrc/cpu/moe.cpp

@@ -579,36 +579,31 @@ void fused_experts_kernel_impl(
   const int64_t stride_e = 2 * N * K;
   const int64_t stride_n = K;

+  int64_t avg_M = std::max(int64_t(1), M * topk / E);
+  const bool use_brgemm = can_use_brgemm<scalar_t>(avg_M);
+
   // here we only parallel on half of 2N to fuse silu_and_mul with gemm
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
     float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
     float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;

-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
-      // nb0 from top half and nb1 from bottom half
-      int64_t nb0 = nb, nb1 = nb + NB;
-      int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
+    loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
+      // nb_upper from top half and nb_lower from bottom half
+      int64_t nb_upper = nb, nb_lower = nb + NB;
+      int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);

       // B shape [K, n_size] in vnni format
       int32_t expert_id = expert_ids[mb];
-      const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
-      const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
+      const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
+      const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;

       // 1.a load A
       const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
       int64_t m_size = offsets[mb + 1] - offsets[mb];

-      const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
-
       for (int64_t m = 0; m < m_size; ++m) {
         int32_t index = A_ids[m] / topk;
         copy_stub(A + m * K, input + index * K, K);

@@ -659,9 +654,9 @@ void fused_experts_kernel_impl(
             /* ldb */ n_size,
             /* ldc */ N);
       }
-    }
+    });

-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });

@@ -676,24 +671,16 @@ void fused_experts_kernel_impl(
   const int64_t stride_oc = IC;

   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     // we won't be using C1 for gemm2
     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;

-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = offsets[mb + 1] - offsets[mb];
       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);

-      const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
-
       // A ptr from ic1 of [M * topk, N] in sorted order
       // so as to avoid copy A to tmp buffer again
       const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;

@@ -736,9 +723,9 @@ void fused_experts_kernel_impl(
         float weight = topk_weights[index];
         copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
       }
-    }
+    });

-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });

@@ -776,36 +763,27 @@ void shared_expert_kernel_impl(
   TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
   const int64_t stride_n = K;

+  const bool use_brgemm = can_use_brgemm<scalar_t>(M);
+
   // here we only parallel on half of 2N to fuse silu_and_mul with gemm
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
     float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;

-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
-      // nb0 from top half and nb1 from bottom half
-      int64_t nb0 = nb, nb1 = nb + NB;
-      int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
+    loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
+      // nb_upper from top half and nb_lower from bottom half
+      int64_t nb_upper = nb, nb_lower = nb + NB;
+      int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
       // int64_t mb_start = mb * BLOCK_M;
       // int64_t mb_size = std::min(M - mb_start, BLOCK_M);

       // A shape [m_size, K]
       const scalar_t* A = input + mb * BLOCK_M * K;

       // B shape [K, n_size] in vnni format
-      const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
-      const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
-
-      const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
+      const scalar_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
+      const scalar_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;

       if (use_brgemm) {
         // 1.b gemm: C0 = A @ B0

@@ -850,9 +828,9 @@ void shared_expert_kernel_impl(
             /* ldb */ n_size,
             /* ldc */ N);
       }
-    }
+    });

-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });

@@ -866,24 +844,16 @@ void shared_expert_kernel_impl(
   const int64_t stride_oc = IC;

   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     // we won't be using C1 for gemm2
     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;

-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);

-      const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
-
       // A shape [m_size, IC]
       const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N;

@@ -922,9 +892,9 @@ void shared_expert_kernel_impl(
       for (int64_t m = 0; m < m_size; ++m) {
         add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
       }
-    }
+    });

-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });

@@ -1086,7 +1056,7 @@ at::Tensor fused_experts_cpu(
   //
   // for fp8 w8a16:
   //   7. intermediate_cache0 : [M * topk, 2N]
-  //   8. B_tmp : [T, BLOCK_N, std::max(K, N)]
+  //   8. B_tmp : [T, MAX_CACHE_BLOCK_SIZE, BLOCK_N, std::max(K, N)]
   //
   int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 + num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +

@@ -1096,7 +1066,7 @@ at::Tensor fused_experts_cpu(
     buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
   }
   if (use_fp8_w8a16) {
-    buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2;
+    buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N) * 2;
   }
   auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));

@@ -1268,7 +1238,7 @@ at::Tensor shared_expert_cpu(
   //
   // for fp8 w8a16:
   //   5. intermediate_cache0 : [M, 2N]
-  //   6. B_tmp: [T, BLOCK_M, max(K, N)]
+  //   6. B_tmp: [T, MAX_CACHE_BLOCK_SIZE, BLOCK_M, max(K, N)]
   //
   int num_threads = at::get_num_threads();
   int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);

@@ -1277,7 +1247,7 @@ at::Tensor shared_expert_cpu(
     buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
   }
   if (use_fp8_w8a16) {
-    buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2;
+    buffer_size_nbytes += M * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_M * std::max(K, N) * 2;
   }
   auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
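The comment "here we only parallel on half of 2N to fuse silu_and_mul with gemm" refers to the w13 projection producing 2N columns, a gate half and an up half: each task takes column block nb of the gate half together with block nb + NB of the up half (nb_upper / nb_lower after this change), computes both tiles, and combines them as silu(gate) * up immediately, so the full 2N-wide intermediate never needs a second pass. A scalar sketch of that epilogue, not part of the commit, using the same silu formula x / (1 + exp(-x)) as the vectorized code:

// Scalar sketch (not from the commit) of the fused silu_and_mul epilogue over one
// pair of BLOCK_N-wide tiles: C0 from the gate half (block nb) and C1 from the up
// half (block nb + NB), already converted to float.
#include <cmath>
#include <cstdint>

inline float silu(float x) { return x / (1.f + std::exp(-x)); }

void silu_and_mul_tile(float* out, const float* C0, const float* C1,
                       int64_t m_size, int64_t BLOCK_N, int64_t ld_out) {
  for (int64_t m = 0; m < m_size; ++m) {
    for (int64_t n = 0; n < BLOCK_N; ++n) {
      out[m * ld_out + n] = silu(C0[m * BLOCK_N + n]) * C1[m * BLOCK_N + n];
    }
  }
}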
sgl-kernel/csrc/cpu/moe_fp8.cpp

@@ -174,18 +174,18 @@ void fused_experts_fp8_kernel_impl(
   const int64_t stride_e = 2 * N * K;
   const int64_t stride_n = K;

   int64_t avg_M = std::max(int64_t(1), M * topk / E);
   const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(avg_M);
+  int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);

   // here we only parallel on half of 2N to fuse silu_and_mul with gemm
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;

-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
+    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);

       // B shape [K, n_size] in vnni format

@@ -194,13 +194,14 @@ void fused_experts_fp8_kernel_impl(
       const float* __restrict__ Bs =
           w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;

+      // do unpacking for the first row or a new expert
+      int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
+      bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
+
       // 1.a load A
       const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
       int64_t m_size = offsets[mb + 1] - offsets[mb];

-      const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
-
       for (int64_t m = 0; m < m_size; ++m) {
         int32_t index = A_ids[m] / topk;
         copy_stub(A + m * K, input + index * K, K);

@@ -211,7 +212,7 @@ void fused_experts_fp8_kernel_impl(
           /* A     */ A,
           /* B     */ B,
           /* C     */ ic0 + offset * 2 * N + nb * BLOCK_N,
-          /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+          /* Btmp  */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
           /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
           /* scale */ Bs,
           /* M     */ m_size,

@@ -221,10 +222,11 @@ void fused_experts_fp8_kernel_impl(
           /* ldb   */ n_size,
           /* ldc   */ 2 * N,
           /* brg   */ use_brgemm,
-          /* block_size_K */ block_size_K);
-    }
+          /* block_size_K */ block_size_K,
+          /* do_unpack */ do_unpack);
+    });

-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });

@@ -248,22 +250,14 @@ void fused_experts_fp8_kernel_impl(
   const int64_t stride_oc = IC;

   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
-    int tid = at::get_thread_num();
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
+    int tid = get_thread_num();
     alignas(64) scalar_t C[BLOCK_M * BLOCK_K];

-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = offsets[mb + 1] - offsets[mb];
       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);

-      const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
-
       // A ptr from ic1 of [M * topk, N] in sorted order
       // so as to avoid copy A to tmp buffer again
       const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;

@@ -275,11 +269,15 @@ void fused_experts_fp8_kernel_impl(
       const float* __restrict__ Bs =
           w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;

+      // do unpacking for the first row or a new expert
+      int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
+      bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
+
       tinygemm_kernel<scalar_t>(
           /* A     */ A,
           /* B     */ B,
           /* C     */ C,
-          /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+          /* Btmp  */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
           /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
           /* scale */ Bs,
           /* M     */ m_size,

@@ -289,7 +287,8 @@ void fused_experts_fp8_kernel_impl(
           /* ldb   */ n_size,
           /* ldc   */ BLOCK_N,
           /* brg   */ use_brgemm,
-          /* block_size_K */ block_size_K);
+          /* block_size_K */ block_size_K,
+          /* do_unpack */ do_unpack);

       // 2.b copy from C to ic2 in original order
       // and also mul topk_weights in float32

@@ -298,9 +297,9 @@ void fused_experts_fp8_kernel_impl(
         float weight = topk_weights[index];
         copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
       }
-    }
+    });

-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });

@@ -374,20 +373,23 @@ void shared_expert_fp8_kernel_impl(
   const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
+  int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);

-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-    int tid = at::get_thread_num();
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
+    int tid = get_thread_num();
+    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
       int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);

+      // do unpacking for the first row
+      bool do_unpack = (mb == mb0);
+
       tinygemm_kernel<scalar_t>(
           /* A     */ input + mb * BLOCK_M * K,
           /* B     */ packed_w1 + nb * BLOCK_N * K,
           /* C     */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
-          /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+          /* Btmp  */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
           /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
           /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
           /* M     */ m_size,

@@ -397,8 +399,9 @@ void shared_expert_fp8_kernel_impl(
           /* ldb   */ n_size,
           /* ldc   */ 2 * N,
           /* brg   */ use_brgemm,
-          /* block_size_K */ block_size_K);
-    }
+          /* block_size_K */ block_size_K,
+          /* do_unpack */ do_unpack);
+    });
   });

   if (use_brgemm) {
     at::native::cpublas::brgemm_release();

@@ -421,22 +424,23 @@ void shared_expert_fp8_kernel_impl(
   scale_size_K = div_up(N, block_size_K);

   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
-    int tid = at::get_thread_num();
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
+    int tid = get_thread_num();
     alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);

+      // do unpacking for the first row
+      bool do_unpack = (mb == mb0);
+
       // 2.a gemm: C = A @ B
       tinygemm_kernel<scalar_t>(
           /* A     */ ic1 + mb * BLOCK_M * N,
           /* B     */ packed_w2 + nb * BLOCK_N * N,
           /* C     */ C,
-          /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+          /* Btmp  */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
           /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
           /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
           /* M     */ m_size,

@@ -446,7 +450,8 @@ void shared_expert_fp8_kernel_impl(
           /* ldb   */ n_size,
           /* ldc   */ BLOCK_N,
           /* brg   */ use_brgemm,
-          /* block_size_K */ block_size_K);
+          /* block_size_K */ block_size_K,
+          /* do_unpack */ do_unpack);

       // 2.b copy from C to output and add fused_experts_out
       scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;

@@ -454,7 +459,7 @@ void shared_expert_fp8_kernel_impl(
       for (int64_t m = 0; m < m_size; ++m) {
         add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
       }
-    }
+    });
   });

   if (use_brgemm) {
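In the MoE fp8 path the cached Btmp tile is only valid while consecutive row blocks map to the same expert, so unpacking is skipped unless the block is the first one of the thread's tile or the expert id changed, which is exactly the pre_expert_id / do_unpack logic added above. A small sketch of that condition, not part of the commit:

// Sketch (not from the commit) of the unpack-skipping condition: mb0 is the first
// row block of the thread's tile (from parallel_2d / loop_2d).
#include <cstdint>

inline bool need_unpack(int64_t mb, int64_t mb0, const int32_t* expert_ids) {
  int32_t expert_id = expert_ids[mb];
  int32_t pre_expert_id = (mb == 0) ? -1 : expert_ids[mb - 1];
  return (mb == mb0) || (expert_id != pre_expert_id);
}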
sgl-kernel/csrc/cpu/moe_int8.cpp

@@ -109,6 +109,120 @@ inline void add_mul_stub(
   }
 }

+template <typename scalar_t, int BLOCK_N>
+inline void silu_and_mul(
+    scalar_t* __restrict__ C,
+    const int32_t* __restrict__ C0,  // x: x0, x1
+    const int32_t* __restrict__ C1,  // y: y0, y1
+    const float* __restrict__ As,
+    const float* __restrict__ Bs0,
+    const float* __restrict__ Bs1,
+    const int32_t* __restrict__ Bcomp0,
+    const int32_t* __restrict__ Bcomp1,
+    int64_t m_size,
+    int64_t N) {
+#if defined(CPU_CAPABILITY_AVX512)
+  constexpr int COLS = BLOCK_N / 16;
+  static_assert(COLS % 2 == 0);
+
+  __m512 vc0[COLS];
+  __m512 vc1[COLS];
+  __m512i vcomp0[COLS];
+  __m512i vcomp1[COLS];
+  __m512 vas;
+  __m512 vbs0[COLS];
+  __m512 vbs1[COLS];
+
+  auto load_scale_and_comp = [&](auto col) {
+    vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
+    vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
+    vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
+    vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
+  };
+  Unroll<COLS>{}(load_scale_and_comp);
+
+  auto scalec = [&](auto col, int64_t m) {
+    // update As
+    vas = _mm512_set1_ps(As[m]);
+    // C = As * (C - Bcomp) * Bs
+    __m512i vc32_0 = _mm512_loadu_si512(C0 + m * BLOCK_N + col * 16);
+    __m512i vc32_1 = _mm512_loadu_si512(C1 + m * BLOCK_N + col * 16);
+    vc0[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_0, vcomp0[col]));
+    vc1[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_1, vcomp1[col]));
+    vc0[col] = _mm512_mul_ps(_mm512_mul_ps(vc0[col], vas), vbs0[col]);
+    vc1[col] = _mm512_mul_ps(_mm512_mul_ps(vc1[col], vas), vbs1[col]);
+  };
+
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+  const fVec one = fVec(1.f);
+
+  auto silu_and_mul = [&](auto col) {
+    fVec x = fVec(vc0[col]);
+    fVec y = fVec(vc1[col]);
+    x = x / (one + x.neg().exp_u20());
+    vc0[col] = x * y;
+  };
+
+  auto storec = [&](auto col, int64_t m) {
+    if constexpr (col % 2 == 0) {
+      fVec x0 = fVec(vc0[col + 0]);
+      fVec x1 = fVec(vc0[col + 1]);
+      bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
+      out_vec.store(C + m * N + col * 16);
+    }
+  };
+
+  for (int64_t m = 0; m < m_size; ++m) {
+    Unroll<COLS>{}(scalec, m);
+    Unroll<COLS>{}(silu_and_mul);
+    Unroll<COLS>{}(storec, m);
+  }
+#else
+  TORCH_CHECK(false, "silu_and_mul: scalar path not implemented!");
+#endif
+}
+
+template <int BLOCK_N>
+inline void scale_C(
+    float* __restrict__ C,
+    const int32_t* __restrict__ Ctmp,
+    const float* __restrict__ As,
+    const float* __restrict__ Bs,
+    const int32_t* __restrict__ Bcomp,
+    int64_t m_size) {
+#if defined(CPU_CAPABILITY_AVX512)
+  constexpr int COLS = BLOCK_N / 16;
+  static_assert(COLS % 2 == 0);
+
+  __m512 vc[COLS];
+  __m512i vcomp[COLS];
+  __m512 vas;
+  __m512 vbs[COLS];
+
+  auto load_scale_and_comp = [&](auto col) {
+    vcomp[col] = _mm512_loadu_si512(Bcomp + col * 16);
+    vbs[col] = _mm512_loadu_ps(Bs + col * 16);
+  };
+  Unroll<COLS>{}(load_scale_and_comp);
+
+  auto scalec = [&](auto col, int64_t m) {
+    // update As
+    vas = _mm512_set1_ps(As[m]);
+    // C = As * (C - Bcomp) * Bs
+    __m512i vc32 = _mm512_loadu_si512(Ctmp + m * BLOCK_N + col * 16);
+    vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp[col]));
+    vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vas), vbs[col]);
+    _mm512_storeu_ps(C + m * BLOCK_N + col * 16, vc[col]);
+  };
+
+  for (int64_t m = 0; m < m_size; ++m) {
+    Unroll<COLS>{}(scalec, m);
+  }
+#else
+  TORCH_CHECK(false, "scale_C: scalar path not implemented!");
+#endif
+}
+
 /// gemm for w13
 template <typename scalar_t, int BLOCK_M, int BLOCK_N>
 struct tinygemm_kernel_vnni {

@@ -515,28 +629,31 @@ void fused_experts_int8_kernel_impl(
   const int64_t stride_e = 2 * N * packed_K;
   const int64_t stride_n = packed_K;

+  int64_t avg_M = std::max(int64_t(1), M * topk / E);
+  const bool use_brgemm = can_use_brgemm<int8_t>(avg_M);
+
   // here we only parallel on half of 2N to fuse silu_and_mul with gemm
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
     int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
     int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
     alignas(64) float As[BLOCK_M];

-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
-      // nb0 from top half and nb1 from bottom half
-      int64_t nb0 = nb, nb1 = nb + NB;
-      int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
+    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
+      // nb_upper from top half and nb_lower from bottom half
+      int64_t nb_upper = nb, nb_lower = nb + NB;
+      int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);

       // B shape [K, n_size] in vnni format
       int32_t expert_id = expert_ids[mb];
-      const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
-      const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
-      const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
-      const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
+      const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
+      const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
+      const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb_upper * BLOCK_N;
+      const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb_lower * BLOCK_N;

       // 1.a load A
       const int32_t* A_ids = sorted_ids + mb * BLOCK_M;

@@ -548,7 +665,42 @@ void fused_experts_int8_kernel_impl(
         As[m] = As_tmp[index];
       }

-      // fused 1.b: silu_and_mul(A @ B0, A @ B1)
+      if (use_brgemm) {
+        // 1.b gemm: C0 = A @ B0
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ K,
+            /* lda   */ K,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B0,
+            /* C     */ C0);
+
+        // 1.c gemm: C1 = A @ B1
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ K,
+            /* lda   */ K,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B1,
+            /* C     */ C1);
+
+        const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
+        const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
+
+        // 1.d silu and mul
+        const int64_t offset = offsets[mb];
+        silu_and_mul<scalar_t, BLOCK_N>(
+            ic1 + offset * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
+      } else {
+        // fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
         const int64_t offset = offsets[mb];
         tinygemm_kernel(
             /* A */ A,

@@ -567,6 +719,11 @@ void fused_experts_int8_kernel_impl(
       }
-    }
+      }
+    });
+
+    if (use_brgemm) {
+      at::native::cpublas::brgemm_release();
+    }
   });

   // stage 1.5: quantize ic1 to uint8, [M * topk, N]
   at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
     for (int64_t m = begin; m < end; ++m) {

@@ -584,16 +741,13 @@ void fused_experts_int8_kernel_impl(
   const int64_t stride_oc = packed_N;

   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
-    // we won't be using C1 for gemm2
+    int tid = get_thread_num();
     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
+    int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);

-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = offsets[mb + 1] - offsets[mb];
       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);

@@ -609,6 +763,23 @@ void fused_experts_int8_kernel_impl(
       const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;

       // 2.a gemm: C = A @ B
+      if (use_brgemm) {
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ IC,
+            /* lda   */ IC,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B,
+            /* C     */ C32);
+
+        // apply scales
+        const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
+        scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
+      } else {
         tinygemm_kernel<scalar_t>(
             /* A */ A,
             /* B */ B,

@@ -621,6 +792,7 @@ void fused_experts_int8_kernel_impl(
             /* lda */ IC,
             /* ldb */ n_size,
             /* ldc */ BLOCK_N);
+      }

       // 2.b copy from C to ic2 in original order
       // and also mul topk_weights in float32

@@ -629,6 +801,10 @@ void fused_experts_int8_kernel_impl(
         float weight = topk_weights[index];
         copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
       }
+    });
+
+    if (use_brgemm) {
+      at::native::cpublas::brgemm_release();
+    }
   });

@@ -708,15 +884,19 @@ void shared_expert_int8_kernel_impl(
   const int64_t packed_N = get_row_size<int8_t>(N);
   const int64_t stride_n = packed_K;

+  const bool use_brgemm = can_use_brgemm<int8_t>(M);
+
   // here we only parallel on half of 2N to fuse silu_and_mul with gemm
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
-      // nb0 from top half and nb1 from bottom half
-      int64_t nb0 = nb, nb1 = nb + NB;
-      int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
+    // get local pointers
+    int tid = get_thread_num();
+    int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
+    int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
+
+    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
+      // nb_upper from top half and nb_lower from bottom half
+      int64_t nb_upper = nb, nb_lower = nb + NB;
+      int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);

       // A shape [m_size, K]

@@ -724,12 +904,46 @@ void shared_expert_int8_kernel_impl(
       const float* As = As_tmp + mb * BLOCK_M;

       // B shape [K, n_size] in vnni format
-      const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
-      const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
-      const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
-      const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
+      const int8_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
+      const int8_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
+      const float* __restrict__ Bs0 = w1s + nb_upper * BLOCK_N;
+      const float* __restrict__ Bs1 = w1s + nb_lower * BLOCK_N;

-      // fused 1.b: silu_and_mul(A @ B0, A @ B1)
+      if (use_brgemm) {
+        // 1.b gemm: C0 = A @ B0
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ K,
+            /* lda   */ K,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B0,
+            /* C     */ C0);
+
+        // 1.c gemm: C1 = A @ B1
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ K,
+            /* lda   */ K,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B1,
+            /* C     */ C1);
+
+        const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
+        const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
+
+        // 1.d silu and mul
+        silu_and_mul<scalar_t, BLOCK_N>(
+            ic1 + mb * BLOCK_M * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
+      } else {
+        // fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
         tinygemm_kernel(
             /* A  */ A,
             /* B0 */ B0,

@@ -747,6 +961,11 @@ void shared_expert_int8_kernel_impl(
       }
-    }
+      }
+    });
+
+    if (use_brgemm) {
+      at::native::cpublas::brgemm_release();
+    }
   });

   // stage 1.5: quantize ic1 to uint8, [M * topk, N]
   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
     for (int64_t m = begin; m < end; ++m) {

@@ -763,16 +982,13 @@ void shared_expert_int8_kernel_impl(
   const int64_t stride_oc = packed_N;

   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
-    // we won't be using C1 for gemm2
+    int tid = get_thread_num();
     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
+    int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);

-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);

@@ -784,6 +1000,23 @@ void shared_expert_int8_kernel_impl(
       const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
       const float* __restrict__ Bs = w2s + nb * BLOCK_N;

+      if (use_brgemm) {
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ IC,
+            /* lda   */ IC,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B,
+            /* C     */ C32);
+
+        // apply scales
+        const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
+        scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
+      } else {
         // 2.a gemm: C = A @ B
         tinygemm_kernel<scalar_t>(
             /* A */ A,

@@ -797,6 +1030,7 @@ void shared_expert_int8_kernel_impl(
             /* lda */ IC,
             /* ldb */ n_size,
             /* ldc */ BLOCK_N);
+      }

       // 2.b copy from C to output and add fused_experts_out
       scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;

@@ -804,6 +1038,10 @@ void shared_expert_int8_kernel_impl(
       for (int64_t m = 0; m < m_size; ++m) {
         add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
       }
+    });
+
+    if (use_brgemm) {
+      at::native::cpublas::brgemm_release();
+    }
   });
 }
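A scalar reference for the fused int8 epilogue added in this file (a sketch, not code from the commit): both int32 accumulators are dequantized with the activation scale As, their compensation rows and per-column weight scales, then combined as silu(x) * y; the AVX-512 silu_and_mul above does the same work 16 columns per register before converting to the reduced-precision intermediate.

// Scalar reference (not from the commit) for the fused dequant + silu_and_mul step.
#include <cmath>
#include <cstdint>

void silu_and_mul_ref(float* C, const int32_t* C0, const int32_t* C1,
                      const float* As, const float* Bs0, const float* Bs1,
                      const int32_t* Bcomp0, const int32_t* Bcomp1,
                      int64_t m_size, int64_t BLOCK_N, int64_t N) {
  for (int64_t m = 0; m < m_size; ++m) {
    for (int64_t n = 0; n < BLOCK_N; ++n) {
      float x = As[m] * float(C0[m * BLOCK_N + n] - Bcomp0[n]) * Bs0[n];  // gate
      float y = As[m] * float(C1[m * BLOCK_N + n] - Bcomp1[n]) * Bs1[n];  // up
      C[m * N + n] = x / (1.f + std::exp(-x)) * y;                        // silu(x) * y
    }
  }
}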
sgl-kernel/csrc/cpu/qkv_proj.cpp

@@ -100,8 +100,7 @@ void segment_gemm_kernel_impl(
   const int64_t NB1 = div_up(N1, BLOCK_N);
   const int64_t NB = NB0 + NB1;

-  // TODO: brgemm u8s8 depends on PyTorch 2.7 release.
-  const bool use_brgemm = false;
+  const bool use_brgemm = can_use_brgemm<int8_t>(M);

   // K + 4 after compensation
   const int64_t packed_row_size = get_row_size<int8_t>(K);