sglang / Commits / 5ad296bd

Commit 5ad296bd (unverified)
Optimize prefill performance on cpu backend (#8750)
Authored Aug 29, 2025 by Ma Mingfei, committed via GitHub Aug 28, 2025. Parent: 9f81d741.

Showing 9 changed files with 681 additions and 274 deletions (+681, -274):

sgl-kernel/csrc/cpu/common.h      +118   -1
sgl-kernel/csrc/cpu/gemm.cpp       +30  -14
sgl-kernel/csrc/cpu/gemm.h          +4   -3
sgl-kernel/csrc/cpu/gemm_fp8.cpp   +34  -31
sgl-kernel/csrc/cpu/gemm_int8.cpp  +70  -12
sgl-kernel/csrc/cpu/moe.cpp        +39  -69
sgl-kernel/csrc/cpu/moe_fp8.cpp    +51  -46
sgl-kernel/csrc/cpu/moe_int8.cpp  +334  -96
sgl-kernel/csrc/cpu/qkv_proj.cpp    +1   -2
sgl-kernel/csrc/cpu/common.h (view file @ 5ad296bd)

...
@@ -105,7 +105,19 @@ namespace {
 #define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
 
-// parallel routines
+// [NB] Parallel Routines
+//
+// * at::parallel_for - applies for most of generic use cases, this will be compiled
+//   against openmp in default torch release.
+//
+// * parallel_for - same function as above, can choose payload partition scheme in
+//   balance211.
+//
+// * parallel_2d - parallel for 2 dimensions, used in GEMM, etc.
+//   this one will do payload balance across 2 dimensions.
+//
 
 // grain size for each thread
 constexpr int GRAIN_SIZE = 1024;
 
 template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
...
@@ -113,6 +125,17 @@ inline T div_up(T x, T y) {
   return (x + y - 1) / y;
 }
 
+// you can only use at::get_thread_num() with at::parallel_for()
+// as it is lazy initialized, otherwise it will always return 0.
+inline int get_thread_num() {
+#if defined(_OPENMP)
+  return omp_get_thread_num();
+#else
+  return 0;
+#endif
+}
+
+// balance payload across each thread
 template <typename T>
 inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
 #if 0
...
@@ -153,6 +176,100 @@ inline void parallel_for(int n, const func_t& f) {
 #endif
 }
 
+// for 1d parallel, use `actual_nth`
+// for 2d parallel, use even nths, e.g. 43->42
+int inline adjust_num_threads(int m) {
+  int actual_nth = at::get_num_threads();
+  if (m == 1) {
+    return actual_nth;
+  }
+  return std::max(1, (actual_nth >> 1) * 2);
+}
+
+template <typename func_t>
+inline void parallel_2d(int m, int n, const func_t& f) {
+  // make sure we have even num_threads
+  int nth = adjust_num_threads(m);
+
+  // [NOTE] thread blocking:
+  //
+  //   1) prefer square block per thread
+  //   2) use even number of CPU cores
+  //   3) use all `num_threads` cores
+  //
+  //   we have:
+  //     TM * TN = T
+  //     BM / TM = BN / TN
+  //   then:
+  //     TM = ((BM / BN) * T) ^ 0.5
+  //
+  float r = float(m) / n;
+  int nth_m = std::ceil(std::sqrt(r * nth));
+  int nth_n = 1;
+  for (; nth_m > 0; --nth_m) {
+    nth_n = nth / nth_m;
+    if (nth_m * nth_n == nth) {
+      break;
+    }
+  }
+
+#if defined(_OPENMP)
+#pragma omp parallel num_threads(nth)
+  {
+    int ith = omp_get_thread_num();
+    int ith_m = ith / nth_n;
+    int ith_n = ith % nth_n;
+
+    int thread_block_m = div_up(m, nth_m);
+    int thread_block_n = div_up(n, nth_n);
+
+    int begin_m = ith_m * thread_block_m;
+    int end_m = std::min(m, begin_m + thread_block_m);
+    int begin_n = ith_n * thread_block_n;
+    int end_n = std::min(n, begin_n + thread_block_n);
+
+    f(begin_m, end_m, begin_n, end_n);
+  }
+#else
+  f(0, m, 0, n);
+#endif
+}
+
+// limit max cache blocks
+// when we need to do pre-unpack for weights, e.g. fp8
+#define MAX_CACHE_BLOCK_SIZE 4
+
+template <typename T>
+inline int get_cache_blocks(int chunk_size) {
+  // L2 2MB and ratio of 50%
+  const int L2_size = 2048 * 1024 >> 1;
+  return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
+}
+
+template <>
+inline int get_cache_blocks<at::Float8_e4m3fn>(int chunk_size) {
+  // fp8 uses bf16 as accumulate type
+  int cache_block_size = get_cache_blocks<at::BFloat16>(chunk_size);
+  return std::min(MAX_CACHE_BLOCK_SIZE, cache_block_size);
+}
+
+// 2d sequential loop in range : [mb0, mb1), [nb0, nb1)
+template <typename T, typename func_t>
+inline void loop_2d(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1, int64_t chunk_size, const func_t& f) {
+  // get number of blocks for L2 in most inner loop
+  int64_t cache_blocks_nb = get_cache_blocks<T>(chunk_size);
+
+  // loop order: [NB / cache_blocks_nb, MB, cache_blocks_nb]
+  // TODO: implement reverse order of [MB / cache_blocks_mb, NB, cache_blocks_mb]
+  for (int64_t nbb = nb0; nbb < nb1; nbb += cache_blocks_nb) {
+    for (int64_t mb = mb0; mb < mb1; ++mb) {
+      for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, nb1); ++nb) {
+        f(mb, nb, nb - nbb);
+      }
+    }
+  }
+}
+
 // data indexing for dimension collapse
 template <typename T>
 inline T data_index_init(T offset) {
...
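The parallel_2d / loop_2d pair added above is the core of this change: parallel_2d splits the [MB, NB] block grid into roughly square per-thread tiles (TM * TN = T with TM ~ sqrt((BM / BN) * T)), and loop_2d walks each tile in stripes of L2-sized column blocks. Below is a minimal standalone sketch of the same traversal; the function names, thread count, and block sizes are illustrative stand-ins, not the commit's code.

// Editorial sketch, not part of the commit: simplified parallel_2d + loop_2d
// showing the square-ish thread decomposition and the L2-striped inner loop.
#include <algorithm>
#include <cmath>
#include <cstdio>

template <typename F>
void parallel_2d_sketch(int threads, int MB, int NB, const F& f) {
  // pick TM x TN so that TM * TN == threads and TM/TN is close to MB/NB
  int tm = std::max(1, (int)std::ceil(std::sqrt((float)MB / NB * threads)));
  int tn = 1;
  for (; tm > 0; --tm) {
    tn = threads / tm;
    if (tm * tn == threads) break;
  }
  for (int t = 0; t < tm * tn; ++t) {  // serial stand-in for the omp parallel region
    int tm_id = t / tn, tn_id = t % tn;
    int bm = (MB + tm - 1) / tm, bn = (NB + tn - 1) / tn;
    f(tm_id * bm, std::min(MB, (tm_id + 1) * bm), tn_id * bn, std::min(NB, (tn_id + 1) * bn));
  }
}

template <typename F>
void loop_2d_sketch(int mb0, int mb1, int nb0, int nb1, int cache_blocks_nb, const F& f) {
  // visit [mb0, mb1) x [nb0, nb1) in stripes of cache_blocks_nb columns so the
  // unpacked B tiles of one stripe stay resident in L2 across all row blocks
  for (int nbb = nb0; nbb < nb1; nbb += cache_blocks_nb) {
    for (int mb = mb0; mb < mb1; ++mb) {
      for (int nb = nbb; nb < std::min(nbb + cache_blocks_nb, nb1); ++nb) {
        f(mb, nb, nb - nbb);  // nb - nbb indexes the per-thread tile cache
      }
    }
  }
}

int main() {
  int visits = 0;
  parallel_2d_sketch(8, 16, 12, [&](int mb0, int mb1, int nb0, int nb1) {
    loop_2d_sketch(mb0, mb1, nb0, nb1, /*cache_blocks_nb=*/4,
                   [&](int mb, int nb, int nb_offset) { ++visits; (void)mb; (void)nb; (void)nb_offset; });
  });
  std::printf("visited %d blocks (expect 16 * 12 = 192)\n", visits);
  return 0;
}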
sgl-kernel/csrc/cpu/gemm.cpp (view file @ 5ad296bd)

...
@@ -254,7 +254,7 @@ void tinygemm_kernel(
     return;
   }
 
-  // pattern: 1-4-16
+  // pattern: 1-4-16, N = 16, 32, 48, 64
   constexpr int64_t BLOCK_M = 4;
   constexpr int64_t BLOCK_N = 64;
   const int64_t MB = div_up(M, BLOCK_M);
...
@@ -268,35 +268,59 @@ void tinygemm_kernel(
   switch (mb_size << 4 | nb_size >> 4) {
     // mb_size = 1
+    case 0x11:
+      LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
+      break;
     case 0x12:
       LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
       break;
+    case 0x13:
+      LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
+      break;
     case 0x14:
       LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
       break;
     // mb_size = 2
+    case 0x21:
+      LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
+      break;
     case 0x22:
       LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
       break;
+    case 0x23:
+      LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
+      break;
     case 0x24:
       LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
       break;
     // mb_size = 3
+    case 0x31:
+      LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
+      break;
     case 0x32:
       LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
       break;
+    case 0x33:
+      LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
+      break;
     case 0x34:
       LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
       break;
     // mb_size = 4
+    case 0x41:
+      LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
+      break;
     case 0x42:
       LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
       break;
+    case 0x43:
+      LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
+      break;
     case 0x44:
       LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
       break;
     default:
-      TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
+      TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
   }
 }
...
@@ -318,20 +342,15 @@ void weight_packed_linear_kernel_impl(
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);
 
-  // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx c) N is small
-  const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>) || (N < 64);
+  const bool use_brgemm = can_use_brgemm<scalar_t>(M);
 
   // parallel on [MB, NB]
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
-    at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-      int64_t mb{0}, nb{0};
-      data_index_init(begin, mb, MB, nb, NB);
-
+    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
       // for brgemm, use float32 for accumulate
       alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
 
-      for (int64_t i = begin; i < end; ++i) {
-        UNUSED(i);
+      loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
         int64_t mb_start = mb * BLOCK_M;
         int64_t mb_size = std::min(M - mb_start, BLOCK_M);
         int64_t nb_start = nb * BLOCK_N;
...
@@ -350,10 +369,7 @@ void weight_packed_linear_kernel_impl(
             /* ldb */ nb_size,
             /* ldc */ out_strideM,
             /* brg */ use_brgemm);
-        // move to the next index
-        data_index_step(mb, MB, nb, NB);
-      }
+      });
 
       if (use_brgemm) {
         at::native::cpublas::brgemm_release();
...
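For the tail-block dispatch above, the switch key packs mb_size (1 to 4) into the high nibble and nb_size / 16 (1 to 4) into the low nibble, so each (mb_size, nb_size) pair maps to a unique case label, e.g. 0x13 for a 1 x 48 tail. A small illustrative check of that encoding, with assumed value ranges:

// Editorial sketch: the (mb_size, nb_size) -> switch-key encoding used above,
// assuming mb_size in [1, 4] and nb_size in {16, 32, 48, 64}.
#include <cassert>
#include <cstdio>

int main() {
  for (int mb_size = 1; mb_size <= 4; ++mb_size) {
    for (int nb_size = 16; nb_size <= 64; nb_size += 16) {
      int key = mb_size << 4 | nb_size >> 4;  // e.g. (3, 48) -> 0x33
      assert((key >> 4) == mb_size && (key & 0xF) == nb_size / 16);
      std::printf("mb_size=%d nb_size=%d -> key=0x%02x\n", mb_size, nb_size, key);
    }
  }
  return 0;
}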
sgl-kernel/csrc/cpu/gemm.h (view file @ 5ad296bd)

...
@@ -27,10 +27,10 @@ template <>
 inline bool can_use_brgemm<at::Half>(int M) {
   return true;
 }
 
 // TODO: add u8s8 brgemm, this requires PyTorch 2.7 or above
 template <>
 inline bool can_use_brgemm<int8_t>(int M) {
-  return false;
+  return M > 4;
 }
 
 template <>
...
@@ -198,4 +198,5 @@ void tinygemm_kernel(
     int64_t ldb,
     int64_t ldc,
     bool brg,
-    int64_t block_size_K);
+    int64_t block_size_K,
+    bool do_unpack = true);
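The new trailing parameter is defaulted, so existing callers of the declaration above keep compiling while the fp8 path can pass do_unpack explicitly. A tiny illustrative sketch of that pattern (hypothetical function, not the real signature):

// Editorial sketch: a defaulted trailing parameter keeps old call sites valid.
#include <cstdio>

void kernel_sketch(int M, int N, bool do_unpack = true) {
  std::printf("M=%d N=%d do_unpack=%d\n", M, N, do_unpack);
}

int main() {
  kernel_sketch(4, 64);         // old-style call, unpack happens as before
  kernel_sketch(4, 64, false);  // new call site reusing an already-unpacked tile
  return 0;
}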
sgl-kernel/csrc/cpu/gemm_fp8.cpp (view file @ 5ad296bd)

...
@@ -2,9 +2,6 @@
 #include "gemm.h"
 #include "vec.h"
 
-// we use 4x32 for BLOCK_M
-#define BLOCK_SIZE_M_SCALE 4
-
 namespace {
 
 template <typename scalar_t>
...
@@ -250,7 +247,8 @@ struct brgemm {
       int K,
       int lda,
       int ldb,
-      int ldc) {
+      int ldc,
+      bool do_unpack = true) {
     TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
   }
 };
...
@@ -270,17 +268,20 @@ struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
       int K,
       int lda,
       int ldb,
-      int ldc) {
+      int ldc,
+      bool do_unpack = true) {
     constexpr int BLOCK_N = block_size_n();
 
     // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
     const int ldb_tmp = BLOCK_N;
 
-    for (int k = 0; k < K; k += BLOCK_K) {
-      int kb_size = std::min(BLOCK_K, K - k);
-      int idx = k >> 7;  // k / BLOCK_K where BLOCK_K = 128
-      unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
+    if (do_unpack) {
+      for (int k = 0; k < K; k += BLOCK_K) {
+        int kb_size = std::min(BLOCK_K, K - k);
+        int idx = k >> 7;  // k / BLOCK_K where BLOCK_K = 128
+        unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
+      }
     }
 
     at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
...
@@ -312,9 +313,11 @@ void tinygemm_kernel(
     int64_t ldb,
     int64_t ldc,
     bool brg,
-    int64_t block_size_K) {
+    int64_t block_size_K,
+    bool do_unpack = true) {
   if (brg) {
-    brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
+    brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack);
     return;
   }
...
@@ -366,7 +369,7 @@ void fp8_scaled_mm_kernel_impl(
     int64_t block_size_N,
     int64_t block_size_K,
     int64_t buffer_size_per_thread) {
-  constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
+  constexpr int64_t BLOCK_M = block_size_m();
   constexpr int64_t BLOCK_N = block_size_n();
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);
...
@@ -378,16 +381,12 @@ void fp8_scaled_mm_kernel_impl(
   // parallel on [MB, NB]
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
-    at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-      int64_t mb{0}, nb{0};
-      data_index_init(begin, mb, MB, nb, NB);
-      int tid = at::get_thread_num();
+    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
+      int tid = get_thread_num();
 
       scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
-      float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));
+      float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K));
 
-      for (int64_t i = begin; i < end; ++i) {
-        UNUSED(i);
+      loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
         const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
         int64_t mb_start = mb * BLOCK_M;
...
@@ -395,11 +394,14 @@ void fp8_scaled_mm_kernel_impl(
         int64_t nb_start = nb * BLOCK_N;
         int64_t nb_size = std::min(N - nb_start, BLOCK_N);
 
+        // only do unpacking for the first row
+        bool do_unpack = (mb == mb0);
+
         tinygemm_kernel<scalar_t, has_bias>(
             /* A     */ mat1 + mb_start * mat1_strideM,
             /* B     */ mat2 + nb_start * K,  // nb * BLOCK_N * K
             /* C     */ out + mb_start * out_strideM + nb_start,
-            /* Btmp  */ Btmp,
+            /* Btmp  */ Btmp + nb_offset * BLOCK_N * K,
             /* Ctmp  */ Ctmp,
             /* scale */ scale_ptr,
             /* bias  */ bias + nb_start,
...
@@ -410,11 +412,9 @@ void fp8_scaled_mm_kernel_impl(
             /* ldb */ nb_size,
             /* ldc */ out_strideM,
             /* brg */ use_brgemm,
-            /* block_size_K */ block_size_K);
-        // move to the next index
-        data_index_step(mb, MB, nb, NB);
-      }
+            /* block_size_K */ block_size_K,
+            /* do_unpack */ do_unpack);
+      });
 
       if (use_brgemm) {
         at::native::cpublas::brgemm_release();
...
@@ -441,8 +441,10 @@ void tinygemm_kernel(
     int64_t ldb,
     int64_t ldc,
     bool brg,
-    int64_t block_size_K) {
-  tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
+    int64_t block_size_K,
+    bool do_unpack) {
+  tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack);
 }
 
 #define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
...
@@ -460,7 +462,8 @@ void tinygemm_kernel(
       int64_t ldb,          \
       int64_t ldc,          \
       bool brg,             \
-      int64_t block_size_K)
+      int64_t block_size_K, \
+      bool do_unpack)
 
 INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
 INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
...
@@ -495,7 +498,7 @@ at::Tensor fp8_scaled_mm_cpu(
   int64_t block_size_N = block_size[0];
   int64_t block_size_K = block_size[1];
 
-  constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
+  constexpr int64_t BLOCK_M = block_size_m();
   constexpr int64_t BLOCK_N = block_size_n();
   TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
   TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
...
@@ -523,7 +526,7 @@ at::Tensor fp8_scaled_mm_cpu(
   // Btmp : [T, BLOCK_N * K]
   // Ctmp : [T, BLOCK_M * BLOCK_N]
   int num_threads = at::get_num_threads();
-  int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
+  int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
   auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
 
   AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
...
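The fp8 changes above enlarge the per-thread Btmp buffer to MAX_CACHE_BLOCK_SIZE column blocks and set do_unpack only on the first row block of each stripe, so a dequantized B tile is unpacked once and then reused by every row block that touches it. The following standalone sketch mimics that reuse pattern with made-up buffer names and sizes; the dequantization is a plain copy stand-in:

// Editorial sketch of the do_unpack / cached-tile pattern (illustrative only).
#include <algorithm>
#include <cstdio>
#include <vector>

constexpr int MAX_CACHE_BLOCKS = 4;

int main() {
  const int MB = 8, NB = 6, tile_elems = 16;
  std::vector<float> packed(NB * tile_elems, 1.0f);              // pretend fp8 weights
  std::vector<float> btmp(MAX_CACHE_BLOCKS * tile_elems, 0.0f);  // per-thread unpack buffer
  int unpacks = 0, uses = 0;

  for (int nbb = 0; nbb < NB; nbb += MAX_CACHE_BLOCKS) {         // loop_2d stripe order
    for (int mb = 0; mb < MB; ++mb) {
      for (int nb = nbb; nb < std::min(nbb + MAX_CACHE_BLOCKS, NB); ++nb) {
        int nb_offset = nb - nbb;
        bool do_unpack = (mb == 0);                              // first row of the stripe
        if (do_unpack) {
          // stand-in for fp8 -> bf16 dequantization into the cached tile
          std::copy_n(packed.data() + nb * tile_elems, tile_elems,
                      btmp.data() + nb_offset * tile_elems);
          ++unpacks;
        }
        ++uses;  // the GEMM always reads btmp + nb_offset * tile_elems
      }
    }
  }
  std::printf("tiles unpacked %d times, used %d times (reuse factor %d)\n",
              unpacks, uses, uses / unpacks);
  return 0;
}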
sgl-kernel/csrc/cpu/gemm_int8.cpp (view file @ 5ad296bd)

...
@@ -4,6 +4,61 @@
 namespace {
 
+template <typename scalar_t, bool has_bias, int BLOCK_N>
+struct scale_C {
+  static inline void apply(
+      scalar_t* __restrict__ C,
+      const int32_t* __restrict__ Ctmp,
+      const int32_t* __restrict__ Bcomp,
+      const float* __restrict__ bias,
+      float As,
+      const float* __restrict__ Bs) {
+    TORCH_CHECK(false, "scale_C: scalar path not implemented!");
+  }
+};
+
+#if defined(CPU_CAPABILITY_AVX512)
+template <bool has_bias, int BLOCK_N>
+struct scale_C<at::BFloat16, has_bias, BLOCK_N> {
+  static inline void apply(
+      at::BFloat16* __restrict__ C,
+      const int32_t* __restrict__ Ctmp,
+      const int32_t* __restrict__ Bcomp,
+      const float* __restrict__ bias,
+      float As,
+      const float* __restrict__ Bs) {
+    constexpr int COLS = BLOCK_N / 16;
+    static_assert(COLS % 2 == 0);
+
+    __m512 vc[COLS];
+    __m512 vd0 = _mm512_set1_ps(As);
+
+    auto compute = [&](auto col) {
+      __m512 vd1 = _mm512_loadu_ps(Bs + col * 16);
+      __m512i vcomp = _mm512_loadu_si512(Bcomp + col * 16);
+      __m512i vc32 = _mm512_loadu_si512(Ctmp + col * 16);
+      vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp));
+      if constexpr (has_bias) {
+        __m512 vbias = _mm512_loadu_ps(bias + col * 16);
+        vc[col] = _mm512_fmadd_ps(_mm512_mul_ps(vc[col], vd0), vd1, vbias);
+      } else {
+        vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vd0), vd1);
+      }
+    };
+    Unroll<COLS>{}(compute);
+
+    auto storec = [&](auto col) {
+      // for COLS = 2, 4 use 512bit store
+      if constexpr (col % 2 == 0) {
+        _mm512_storeu_si512(
+            reinterpret_cast<__m512i*>((C + col * 16)),
+            (__m512i)(_mm512_cvtne2ps_pbh(vc[col + 1], vc[col + 0])));
+      }
+    };
+    Unroll<COLS>{}(storec);
+  }
+};
+#endif
+
 template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
 struct tinygemm_kernel_nn {
   static inline void apply(
...
@@ -169,6 +224,17 @@ void tinygemm_kernel(
   // B compensation
   const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
 
+  if (brg) {
+    constexpr int BLOCK_N = block_size_n();
+    at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
+
+    // apply compensation and scale
+    for (int64_t m = 0; m < M; ++m) {
+      scale_C<scalar_t, has_bias, BLOCK_N>::apply(C + m * ldc, Ctmp + m * BLOCK_N, Bcomp, bias, As[m], Bs);
+    }
+    return;
+  }
+
   // pattern: 1-4-16
   constexpr int64_t BLOCK_M = 4;
   constexpr int64_t BLOCK_N = 64;
...
@@ -233,22 +299,17 @@ void int8_scaled_mm_kernel_impl(
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);
 
-  // TODO: brgemm u8s8 depends on PyTorch 2.7 release.
-  const bool use_brgemm = false;
+  const bool use_brgemm = can_use_brgemm<int8_t>(M);
 
   // K + 4 after compensation
   const int64_t packed_row_size = get_row_size<int8_t>(K);
 
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
-    at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-      int64_t mb{0}, nb{0};
-      data_index_init(begin, mb, MB, nb, NB);
-
+    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
       // for brgemm, use int32_t for accumulate
       alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
 
-      for (int i = begin; i < end; ++i) {
-        UNUSED(i);
+      loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
         int mb_start = mb * BLOCK_M;
         int mb_size = std::min(M - mb_start, BLOCK_M);
         int nb_start = nb * BLOCK_N;
...
@@ -269,10 +330,7 @@ void int8_scaled_mm_kernel_impl(
             /* ldb */ nb_size,
             /* ldc */ N,
             /* brg */ use_brgemm);
-        // move to the next index
-        data_index_step(mb, MB, nb, NB);
-      }
+      });
 
       if (use_brgemm) {
         at::native::cpublas::brgemm_release();
...
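The new scale_C epilogue converts the int32 brgemm accumulators to the output dtype as As[m] * (acc - Bcomp[n]) * Bs[n], optionally adding bias. A scalar reference of that formula (illustrative shapes and values, not the AVX-512 code) is sketched below:

// Editorial sketch: scalar reference for the epilogue that scale_C vectorizes.
#include <cstdint>
#include <cstdio>
#include <vector>

void scale_c_ref(float* C, const int32_t* Ctmp, const int32_t* Bcomp, const float* bias,
                 const float* As, const float* Bs, int M, int N, bool has_bias) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      // out = per-row activation scale * (accumulator - per-column compensation) * per-column weight scale
      float v = As[m] * static_cast<float>(Ctmp[m * N + n] - Bcomp[n]) * Bs[n];
      C[m * N + n] = has_bias ? v + bias[n] : v;
    }
  }
}

int main() {
  const int M = 2, N = 4;
  std::vector<int32_t> acc = {100, 200, 300, 400, 500, 600, 700, 800};
  std::vector<int32_t> bcomp = {10, 10, 10, 10};
  std::vector<float> as = {0.5f, 0.25f}, bs = {1.f, 2.f, 3.f, 4.f}, bias = {1.f, 1.f, 1.f, 1.f};
  std::vector<float> out(M * N);
  scale_c_ref(out.data(), acc.data(), bcomp.data(), bias.data(), as.data(), bs.data(), M, N, true);
  std::printf("out[0][0]=%.1f out[1][3]=%.1f\n", out[0], out[M * N - 1]);  // 46.0 and 791.0
  return 0;
}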
sgl-kernel/csrc/cpu/moe.cpp (view file @ 5ad296bd)

...
@@ -579,36 +579,31 @@ void fused_experts_kernel_impl(
   const int64_t stride_e = 2 * N * K;
   const int64_t stride_n = K;
 
+  int64_t avg_M = std::max(int64_t(1), M * topk / E);
+  const bool use_brgemm = can_use_brgemm<scalar_t>(avg_M);
+
   // here we only parallel on half of 2N to fuse silu_and_mul with gemm
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
     float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
     float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
 
-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
-      // nb0 from top half and nb1 from bottom half
-      int64_t nb0 = nb, nb1 = nb + NB;
-      int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
+    loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
+      // nb_upper from top half and nb_lower from bottom half
+      int64_t nb_upper = nb, nb_lower = nb + NB;
+      int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
       // B shape [K, n_size] in vnni format
       int32_t expert_id = expert_ids[mb];
-      const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
-      const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
+      const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
+      const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
 
       // 1.a load A
       const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
       int64_t m_size = offsets[mb + 1] - offsets[mb];
 
-      const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
-
       for (int64_t m = 0; m < m_size; ++m) {
         int32_t index = A_ids[m] / topk;
         copy_stub(A + m * K, input + index * K, K);
...
@@ -659,9 +654,9 @@ void fused_experts_kernel_impl(
             /* ldb */ n_size,
             /* ldc */ N);
       }
-    }
+    });
 
-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });
...
@@ -676,24 +671,16 @@ void fused_experts_kernel_impl(
   const int64_t stride_oc = IC;
 
   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     // we won't be using C1 for gemm2
     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
 
-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = offsets[mb + 1] - offsets[mb];
       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
 
-      const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
-
       // A ptr from ic1 of [M * topk, N] in sorted order
       // so as to avoid copy A to tmp buffer again
       const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
...
@@ -736,9 +723,9 @@ void fused_experts_kernel_impl(
         float weight = topk_weights[index];
         copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
       }
-    }
+    });
 
-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });
...
@@ -776,36 +763,27 @@ void shared_expert_kernel_impl(
   TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
   const int64_t stride_n = K;
 
+  const bool use_brgemm = can_use_brgemm<scalar_t>(M);
+
   // here we only parallel on half of 2N to fuse silu_and_mul with gemm
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
     float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
 
-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
-      // nb0 from top half and nb1 from bottom half
-      int64_t nb0 = nb, nb1 = nb + NB;
-      int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
+    loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
+      // nb_upper from top half and nb_lower from bottom half
+      int64_t nb_upper = nb, nb_lower = nb + NB;
+      int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
-      // int64_t mb_start = mb * BLOCK_M;
-      // int64_t mb_size = std::min(M - mb_start, BLOCK_M);
 
       // A shape [m_size, K]
       const scalar_t* A = input + mb * BLOCK_M * K;
       // B shape [K, n_size] in vnni format
-      const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
-      const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
-
-      const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
+      const scalar_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
+      const scalar_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
 
       if (use_brgemm) {
         // 1.b gemm: C0 = A @ B0
...
@@ -850,9 +828,9 @@ void shared_expert_kernel_impl(
             /* ldb */ n_size,
             /* ldc */ N);
       }
-    }
+    });
 
-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });
...
@@ -866,24 +844,16 @@ void shared_expert_kernel_impl(
   const int64_t stride_oc = IC;
 
   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     // we won't be using C1 for gemm2
     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
 
-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
      int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
      int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
 
-      const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
-
       // A shape [m_size, IC]
       const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N;
...
@@ -922,9 +892,9 @@ void shared_expert_kernel_impl(
       for (int64_t m = 0; m < m_size; ++m) {
         add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
       }
-    }
+    });
 
-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });
...
@@ -1086,7 +1056,7 @@ at::Tensor fused_experts_cpu(
   //
   // for fp8 w8a16:
   //   7. intermediate_cache0 : [M * topk, 2N]
-  //   8. B_tmp : [T, BLOCK_N, std::max(K, N)]
+  //   8. B_tmp : [T, MAX_CACHE_BLOCK_SIZE, BLOCK_N, std::max(K, N)]
   //
   int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
       num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
...
@@ -1096,7 +1066,7 @@ at::Tensor fused_experts_cpu(
     buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
   }
   if (use_fp8_w8a16) {
-    buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2;
+    buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N) * 2;
   }
   auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
...
@@ -1268,7 +1238,7 @@ at::Tensor shared_expert_cpu(
   //
   // for fp8 w8a16:
   //   5. intermediate_cache0 : [M, 2N]
-  //   6. B_tmp: [T, BLOCK_M, max(K, N)]
+  //   6. B_tmp: [T, MAX_CACHE_BLOCK_SIZE, BLOCK_M, max(K, N)]
   //
   int num_threads = at::get_num_threads();
   int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
...
@@ -1277,7 +1247,7 @@ at::Tensor shared_expert_cpu(
     buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
   }
   if (use_fp8_w8a16) {
-    buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2;
+    buffer_size_nbytes += M * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_M * std::max(K, N) * 2;
   }
   auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
...
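In the MoE kernels the brgemm decision is now made once per launch from avg_M = max(1, M * topk / E), the average number of rows each expert sees, instead of per sorted block. A small sketch of that heuristic with an assumed threshold:

// Editorial sketch of the avg_M heuristic (threshold is an assumption).
#include <algorithm>
#include <cstdint>
#include <cstdio>

bool can_use_brgemm_sketch(int64_t m) { return m > 4; }  // stand-in threshold

int main() {
  struct Case { int64_t M, topk, E; };
  Case cases[] = {{1, 8, 256}, {512, 8, 256}, {8192, 8, 256}};
  for (const Case& c : cases) {
    int64_t avg_M = std::max(int64_t(1), c.M * c.topk / c.E);
    std::printf("M=%lld topk=%lld E=%lld -> avg_M=%lld use_brgemm=%d\n",
                (long long)c.M, (long long)c.topk, (long long)c.E,
                (long long)avg_M, can_use_brgemm_sketch(avg_M));
  }
  return 0;
}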
sgl-kernel/csrc/cpu/moe_fp8.cpp (view file @ 5ad296bd)

...
@@ -174,18 +174,18 @@ void fused_experts_fp8_kernel_impl(
   const int64_t stride_e = 2 * N * K;
   const int64_t stride_n = K;
 
+  int64_t avg_M = std::max(int64_t(1), M * topk / E);
+  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(avg_M);
+  int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);
+
   // here we only parallel on half of 2N to fuse silu_and_mul with gemm
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
 
-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
+    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
       // B shape [K, n_size] in vnni format
...
@@ -194,13 +194,14 @@ void fused_experts_fp8_kernel_impl(
       const float* __restrict__ Bs =
           w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
 
+      // do unpacking for the first row or a new expert
+      int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
+      bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
+
       // 1.a load A
       const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
       int64_t m_size = offsets[mb + 1] - offsets[mb];
 
-      const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
-
       for (int64_t m = 0; m < m_size; ++m) {
         int32_t index = A_ids[m] / topk;
         copy_stub(A + m * K, input + index * K, K);
...
@@ -211,7 +212,7 @@ void fused_experts_fp8_kernel_impl(
           /* A     */ A,
           /* B     */ B,
           /* C     */ ic0 + offset * 2 * N + nb * BLOCK_N,
-          /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+          /* Btmp  */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
           /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
           /* scale */ Bs,
           /* M     */ m_size,
...
@@ -221,10 +222,11 @@ void fused_experts_fp8_kernel_impl(
           /* ldb */ n_size,
           /* ldc */ 2 * N,
           /* brg */ use_brgemm,
-          /* block_size_K */ block_size_K);
-    }
+          /* block_size_K */ block_size_K,
+          /* do_unpack */ do_unpack);
+    });
 
-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });
...
@@ -248,22 +250,14 @@ void fused_experts_fp8_kernel_impl(
   const int64_t stride_oc = IC;
 
   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
-    int tid = at::get_thread_num();
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
+    int tid = get_thread_num();
     alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
 
-    bool is_brgemm_used = false;
-
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = offsets[mb + 1] - offsets[mb];
       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
 
-      const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
-      is_brgemm_used = is_brgemm_used || use_brgemm;
-
       // A ptr from ic1 of [M * topk, N] in sorted order
       // so as to avoid copy A to tmp buffer again
       const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
...
@@ -275,11 +269,15 @@ void fused_experts_fp8_kernel_impl(
       const float* __restrict__ Bs =
           w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
 
+      // do unpacking for the first row or a new expert
+      int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
+      bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
+
       tinygemm_kernel<scalar_t>(
           /* A     */ A,
           /* B     */ B,
           /* C     */ C,
-          /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+          /* Btmp  */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
           /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
           /* scale */ Bs,
           /* M     */ m_size,
...
@@ -289,7 +287,8 @@ void fused_experts_fp8_kernel_impl(
           /* ldb */ n_size,
           /* ldc */ BLOCK_N,
           /* brg */ use_brgemm,
-          /* block_size_K */ block_size_K);
+          /* block_size_K */ block_size_K,
+          /* do_unpack */ do_unpack);
 
       // 2.b copy from C to ic2 in original order
       // and also mul topk_weights in float32
...
@@ -298,9 +297,9 @@ void fused_experts_fp8_kernel_impl(
         float weight = topk_weights[index];
         copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
       }
-    }
+    });
 
-    if (is_brgemm_used) {
+    if (use_brgemm) {
       at::native::cpublas::brgemm_release();
     }
   });
...
@@ -374,20 +373,23 @@ void shared_expert_fp8_kernel_impl(
   const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
+  int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);
 
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-    int tid = at::get_thread_num();
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
+    int tid = get_thread_num();
+    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
       int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
 
+      // do unpacking for the first row
+      bool do_unpack = (mb == mb0);
+
       tinygemm_kernel<scalar_t>(
           /* A     */ input + mb * BLOCK_M * K,
           /* B     */ packed_w1 + nb * BLOCK_N * K,
           /* C     */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
-          /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+          /* Btmp  */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
          /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
          /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
          /* M     */ m_size,
...
@@ -397,8 +399,9 @@ void shared_expert_fp8_kernel_impl(
           /* ldb */ n_size,
           /* ldc */ 2 * N,
           /* brg */ use_brgemm,
-          /* block_size_K */ block_size_K);
-    }
+          /* block_size_K */ block_size_K,
+          /* do_unpack */ do_unpack);
+    });
   });
 
   if (use_brgemm) {
     at::native::cpublas::brgemm_release();
...
@@ -421,22 +424,23 @@ void shared_expert_fp8_kernel_impl(
   scale_size_K = div_up(N, block_size_K);
 
   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
-    int tid = at::get_thread_num();
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
+    int tid = get_thread_num();
     alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
 
+      // do unpacking for the first row
+      bool do_unpack = (mb == mb0);
+
       // 2.a gemm: C = A @ B
       tinygemm_kernel<scalar_t>(
           /* A     */ ic1 + mb * BLOCK_M * N,
           /* B     */ packed_w2 + nb * BLOCK_N * N,
           /* C     */ C,
-          /* Btmp  */ B_tmp + tid * BLOCK_N * std::max(K, N),
+          /* Btmp  */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
           /* Ctmp  */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
           /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
           /* M     */ m_size,
...
@@ -446,7 +450,8 @@ void shared_expert_fp8_kernel_impl(
           /* ldb */ n_size,
           /* ldc */ BLOCK_N,
           /* brg */ use_brgemm,
-          /* block_size_K */ block_size_K);
+          /* block_size_K */ block_size_K,
+          /* do_unpack */ do_unpack);
 
       // 2.b copy from C to output and add fused_experts_out
       scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
...
@@ -454,7 +459,7 @@ void shared_expert_fp8_kernel_impl(
       for (int64_t m = 0; m < m_size; ++m) {
         add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
       }
-    }
+    });
   });
 
   if (use_brgemm) {
...
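For the fp8 MoE path, a cached B tile only needs re-dequantizing when the thread starts its row range or when the sorted block switches to a different expert, which is what the pre_expert_id / do_unpack logic above computes. A standalone sketch with a made-up expert_ids layout:

// Editorial sketch of the per-expert do_unpack condition (expert_ids are invented).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // expert id per sorted row block, e.g. as produced by the block alignment step
  std::vector<int32_t> expert_ids = {3, 3, 3, 7, 7, 9};
  int64_t mb0 = 1, mb1 = 6;  // this thread's row-block range from parallel_2d
  int unpacks = 0;
  for (int64_t mb = mb0; mb < mb1; ++mb) {
    int32_t expert_id = expert_ids[mb];
    int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
    bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
    unpacks += do_unpack;
    std::printf("mb=%lld expert=%d do_unpack=%d\n", (long long)mb, expert_id, do_unpack);
  }
  std::printf("unpacked %d times for %lld row blocks\n", unpacks, (long long)(mb1 - mb0));
  return 0;
}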
sgl-kernel/csrc/cpu/moe_int8.cpp
View file @
5ad296bd
...
@@ -109,6 +109,120 @@ inline void add_mul_stub(
...
@@ -109,6 +109,120 @@ inline void add_mul_stub(
}
}
}
}
template
<
typename
scalar_t
,
int
BLOCK_N
>
inline
void
silu_and_mul
(
scalar_t
*
__restrict__
C
,
const
int32_t
*
__restrict__
C0
,
// x: x0, x1
const
int32_t
*
__restrict__
C1
,
// y: y0, y1
const
float
*
__restrict__
As
,
const
float
*
__restrict__
Bs0
,
const
float
*
__restrict__
Bs1
,
const
int32_t
*
__restrict__
Bcomp0
,
const
int32_t
*
__restrict__
Bcomp1
,
int64_t
m_size
,
int64_t
N
)
{
#if defined(CPU_CAPABILITY_AVX512)
constexpr
int
COLS
=
BLOCK_N
/
16
;
static_assert
(
COLS
%
2
==
0
);
__m512
vc0
[
COLS
];
__m512
vc1
[
COLS
];
__m512i
vcomp0
[
COLS
];
__m512i
vcomp1
[
COLS
];
__m512
vas
;
__m512
vbs0
[
COLS
];
__m512
vbs1
[
COLS
];
auto
load_scale_and_comp
=
[
&
](
auto
col
)
{
vcomp0
[
col
]
=
_mm512_loadu_si512
(
Bcomp0
+
col
*
16
);
vcomp1
[
col
]
=
_mm512_loadu_si512
(
Bcomp1
+
col
*
16
);
vbs0
[
col
]
=
_mm512_loadu_ps
(
Bs0
+
col
*
16
);
vbs1
[
col
]
=
_mm512_loadu_ps
(
Bs1
+
col
*
16
);
};
Unroll
<
COLS
>
{}(
load_scale_and_comp
);
auto
scalec
=
[
&
](
auto
col
,
int64_t
m
)
{
// update As
vas
=
_mm512_set1_ps
(
As
[
m
]);
// C = As * (C - Bcomp) * Bs
__m512i
vc32_0
=
_mm512_loadu_si512
(
C0
+
m
*
BLOCK_N
+
col
*
16
);
__m512i
vc32_1
=
_mm512_loadu_si512
(
C1
+
m
*
BLOCK_N
+
col
*
16
);
vc0
[
col
]
=
_mm512_cvtepi32_ps
(
_mm512_sub_epi32
(
vc32_0
,
vcomp0
[
col
]));
vc1
[
col
]
=
_mm512_cvtepi32_ps
(
_mm512_sub_epi32
(
vc32_1
,
vcomp1
[
col
]));
vc0
[
col
]
=
_mm512_mul_ps
(
_mm512_mul_ps
(
vc0
[
col
],
vas
),
vbs0
[
col
]);
vc1
[
col
]
=
_mm512_mul_ps
(
_mm512_mul_ps
(
vc1
[
col
],
vas
),
vbs1
[
col
]);
};
using
bVec
=
at
::
vec
::
Vectorized
<
scalar_t
>
;
using
fVec
=
at
::
vec
::
Vectorized
<
float
>
;
const
fVec
one
=
fVec
(
1.
f
);
auto
silu_and_mul
=
[
&
](
auto
col
)
{
fVec
x
=
fVec
(
vc0
[
col
]);
fVec
y
=
fVec
(
vc1
[
col
]);
x
=
x
/
(
one
+
x
.
neg
().
exp_u20
());
vc0
[
col
]
=
x
*
y
;
};
auto
storec
=
[
&
](
auto
col
,
int64_t
m
)
{
if
constexpr
(
col
%
2
==
0
)
{
fVec
x0
=
fVec
(
vc0
[
col
+
0
]);
fVec
x1
=
fVec
(
vc0
[
col
+
1
]);
bVec
out_vec
=
convert_from_float_ext
<
scalar_t
>
(
x0
,
x1
);
out_vec
.
store
(
C
+
m
*
N
+
col
*
16
);
}
};
for
(
int64_t
m
=
0
;
m
<
m_size
;
++
m
)
{
Unroll
<
COLS
>
{}(
scalec
,
m
);
Unroll
<
COLS
>
{}(
silu_and_mul
);
Unroll
<
COLS
>
{}(
storec
,
m
);
}
#else
TORCH_CHECK
(
false
,
"silu_and_mul: scalar path not implemented!"
);
#endif
}
template
<
int
BLOCK_N
>
inline
void
scale_C
(
float
*
__restrict__
C
,
const
int32_t
*
__restrict__
Ctmp
,
const
float
*
__restrict__
As
,
const
float
*
__restrict__
Bs
,
const
int32_t
*
__restrict__
Bcomp
,
int64_t
m_size
)
{
#if defined(CPU_CAPABILITY_AVX512)
constexpr
int
COLS
=
BLOCK_N
/
16
;
static_assert
(
COLS
%
2
==
0
);
__m512
vc
[
COLS
];
__m512i
vcomp
[
COLS
];
__m512
vas
;
__m512
vbs
[
COLS
];
auto
load_scale_and_comp
=
[
&
](
auto
col
)
{
vcomp
[
col
]
=
_mm512_loadu_si512
(
Bcomp
+
col
*
16
);
vbs
[
col
]
=
_mm512_loadu_ps
(
Bs
+
col
*
16
);
};
Unroll
<
COLS
>
{}(
load_scale_and_comp
);
auto
scalec
=
[
&
](
auto
col
,
int64_t
m
)
{
// update As
vas
=
_mm512_set1_ps
(
As
[
m
]);
// C = As * (C - Bcomp) * Bs
__m512i
vc32
=
_mm512_loadu_si512
(
Ctmp
+
m
*
BLOCK_N
+
col
*
16
);
vc
[
col
]
=
_mm512_cvtepi32_ps
(
_mm512_sub_epi32
(
vc32
,
vcomp
[
col
]));
vc
[
col
]
=
_mm512_mul_ps
(
_mm512_mul_ps
(
vc
[
col
],
vas
),
vbs
[
col
]);
_mm512_storeu_ps
(
C
+
m
*
BLOCK_N
+
col
*
16
,
vc
[
col
]);
};
for
(
int64_t
m
=
0
;
m
<
m_size
;
++
m
)
{
Unroll
<
COLS
>
{}(
scalec
,
m
);
}
#else
TORCH_CHECK
(
false
,
"scale_C: scalar path not implemented!"
);
#endif
}
/// gemm for w13
/// gemm for w13
template
<
typename
scalar_t
,
int
BLOCK_M
,
int
BLOCK_N
>
template
<
typename
scalar_t
,
int
BLOCK_M
,
int
BLOCK_N
>
struct
tinygemm_kernel_vnni
{
struct
tinygemm_kernel_vnni
{
...
@@ -515,28 +629,31 @@ void fused_experts_int8_kernel_impl(
...
@@ -515,28 +629,31 @@ void fused_experts_int8_kernel_impl(
   const int64_t stride_e = 2 * N * packed_K;
   const int64_t stride_n = packed_K;
+  int64_t avg_M = std::max(int64_t(1), M * topk / E);
+  const bool use_brgemm = can_use_brgemm<int8_t>(avg_M);
   // here we only parallel on half of 2N to fuse silu_and_mul with gemm
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
     uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
+    int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
+    int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
     alignas(64) float As[BLOCK_M];
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
-      // nb0 from top half and nb1 from bottom half
-      int64_t nb0 = nb, nb1 = nb + NB;
-      int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
+    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
+      // nb_upper from top half and nb_lower from bottom half
+      int64_t nb_upper = nb, nb_lower = nb + NB;
+      int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
       // B shape [K, n_size] in vnni format
       int32_t expert_id = expert_ids[mb];
-      const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
-      const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
-      const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
-      const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
+      const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
+      const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
+      const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb_upper * BLOCK_N;
+      const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb_lower * BLOCK_N;
       // 1.a load A
       const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
...
@@ -548,22 +665,62 @@ void fused_experts_int8_kernel_impl(
         As[m] = As_tmp[index];
       }
-      // fused 1.b: silu_and_mul(A @ B0, A @ B1)
-      const int64_t offset = offsets[mb];
-      tinygemm_kernel(
-          /* A   */ A,
-          /* B0  */ B0,
-          /* B1  */ B1,
-          /* C   */ ic1 + offset * N + nb * BLOCK_N,
-          /* As  */ As,
-          /* Bs0 */ Bs0,
-          /* Bs1 */ Bs1,
-          /* M   */ m_size,
-          /* N   */ n_size,
-          /* K   */ K,
-          /* lda */ K,
-          /* ldb */ n_size,
-          /* ldc */ N);
-    }
+      if (use_brgemm) {
+        // 1.b gemm: C0 = A @ B0
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ K,
+            /* lda   */ K,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B0,
+            /* C     */ C0);
+        // 1.c gemm: C1 = A @ B1
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ K,
+            /* lda   */ K,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B1,
+            /* C     */ C1);
+        const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
+        const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
+        // 1.d silu and mul
+        const int64_t offset = offsets[mb];
+        silu_and_mul<scalar_t, BLOCK_N>(
+            ic1 + offset * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
+      } else {
+        // fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
+        const int64_t offset = offsets[mb];
+        tinygemm_kernel(
+            /* A   */ A,
+            /* B0  */ B0,
+            /* B1  */ B1,
+            /* C   */ ic1 + offset * N + nb * BLOCK_N,
+            /* As  */ As,
+            /* Bs0 */ Bs0,
+            /* Bs1 */ Bs1,
+            /* M   */ m_size,
+            /* N   */ n_size,
+            /* K   */ K,
+            /* lda */ K,
+            /* ldb */ n_size,
+            /* ldc */ N);
+      }
+    });
+    if (use_brgemm) {
+      at::native::cpublas::brgemm_release();
+    }
   });
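The core change in the hunk above is replacing the flat at::parallel_for over MB * NB block indices with parallel_2d / loop_2d, which give each thread a two-dimensional range of blocks instead of a one-dimensional strip. Below is a minimal, self-contained sketch of that partitioning idea under OpenMP; it is not the parallel_2d used by these kernels, and parallel_2d_sketch is a hypothetical name used only for illustration.

// Self-contained sketch (not the sgl-kernel implementation) of 2-D block
// partitioning: split an MB x NB grid of blocks across a Tm x Tn thread grid
// so each thread owns a contiguous tile of blocks in both dimensions.
#include <omp.h>
#include <cstdint>
#include <cstdio>

template <typename F>
void parallel_2d_sketch(int64_t MB, int64_t NB, const F& f) {
#pragma omp parallel
  {
    const int nth = omp_get_num_threads();
    const int tid = omp_get_thread_num();
    // naive factorization: largest Tm <= sqrt(nth) that divides nth
    int Tm = 1;
    for (int t = 1; t * t <= nth; ++t) {
      if (nth % t == 0) Tm = t;
    }
    const int Tn = nth / Tm;
    const int tm = tid / Tn;
    const int tn = tid % Tn;
    const int64_t mb0 = MB * tm / Tm, mb1 = MB * (tm + 1) / Tm;
    const int64_t nb0 = NB * tn / Tn, nb1 = NB * (tn + 1) / Tn;
    f(mb0, mb1, nb0, nb1);  // same callback shape as the hunk above
  }
}

int main() {
  parallel_2d_sketch(8, 8, [](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
    for (int64_t mb = mb0; mb < mb1; ++mb) {
      for (int64_t nb = nb0; nb < nb1; ++nb) {
        std::printf("thread %d handles block (%lld, %lld)\n",
                    omp_get_thread_num(), (long long)mb, (long long)nb);
      }
    }
  });
  return 0;
}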
...
@@ -584,16 +741,13 @@ void fused_experts_int8_kernel_impl(
   const int64_t stride_oc = packed_N;
   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
+    // we won't be using C1 for gemm2
     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
+    int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = offsets[mb + 1] - offsets[mb];
       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
...
@@ -609,18 +763,36 @@ void fused_experts_int8_kernel_impl(
       const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
       // 2.a gemm: C = A @ B
-      tinygemm_kernel<scalar_t>(
-          /* A   */ A,
-          /* B   */ B,
-          /* C   */ C,
-          /* As  */ As,
-          /* Bs  */ Bs,
-          /* M   */ m_size,
-          /* N   */ n_size,
-          /* K   */ IC,
-          /* lda */ IC,
-          /* ldb */ n_size,
-          /* ldc */ BLOCK_N);
+      if (use_brgemm) {
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ IC,
+            /* lda   */ IC,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B,
+            /* C     */ C32);
+        // apply scales
+        const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
+        scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
+      } else {
+        tinygemm_kernel<scalar_t>(
+            /* A   */ A,
+            /* B   */ B,
+            /* C   */ C,
+            /* As  */ As,
+            /* Bs  */ Bs,
+            /* M   */ m_size,
+            /* N   */ n_size,
+            /* K   */ IC,
+            /* lda */ IC,
+            /* ldb */ n_size,
+            /* ldc */ BLOCK_N);
+      }
       // 2.b copy from C to ic2 in original order
       // and also mul topk_weights in float32
...
@@ -629,6 +801,10 @@ void fused_experts_int8_kernel_impl(
         float weight = topk_weights[index];
         copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
       }
-    }
+    });
+    if (use_brgemm) {
+      at::native::cpublas::brgemm_release();
+    }
   });
...
@@ -708,15 +884,19 @@ void shared_expert_int8_kernel_impl(
   const int64_t packed_N = get_row_size<int8_t>(N);
   const int64_t stride_n = packed_K;
+  const bool use_brgemm = can_use_brgemm<int8_t>(M);
   // here we only parallel on half of 2N to fuse silu_and_mul with gemm
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
-      // nb0 from top half and nb1 from bottom half
-      int64_t nb0 = nb, nb1 = nb + NB;
-      int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
+  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
+    // get local pointers
+    int tid = get_thread_num();
+    int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
+    int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
+    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
+      // nb_upper from top half and nb_lower from bottom half
+      int64_t nb_upper = nb, nb_lower = nb + NB;
+      int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
       // A shape [m_size, K]
...
@@ -724,26 +904,65 @@ void shared_expert_int8_kernel_impl(
       const float* As = As_tmp + mb * BLOCK_M;
       // B shape [K, n_size] in vnni format
-      const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
-      const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
-      const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
-      const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
-      // fused 1.b: silu_and_mul(A @ B0, A @ B1)
-      tinygemm_kernel(
-          /* A   */ A,
-          /* B0  */ B0,
-          /* B1  */ B1,
-          /* C   */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
-          /* As  */ As,
-          /* Bs0 */ Bs0,
-          /* Bs1 */ Bs1,
-          /* M   */ m_size,
-          /* N   */ n_size,
-          /* K   */ K,
-          /* lda */ K,
-          /* ldb */ n_size,
-          /* ldc */ N);
-    }
+      const int8_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
+      const int8_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
+      const float* __restrict__ Bs0 = w1s + nb_upper * BLOCK_N;
+      const float* __restrict__ Bs1 = w1s + nb_lower * BLOCK_N;
+      if (use_brgemm) {
+        // 1.b gemm: C0 = A @ B0
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ K,
+            /* lda   */ K,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B0,
+            /* C     */ C0);
+        // 1.c gemm: C1 = A @ B1
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ K,
+            /* lda   */ K,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B1,
+            /* C     */ C1);
+        const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
+        const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
+        // 1.d silu and mul
+        silu_and_mul<scalar_t, BLOCK_N>(
+            ic1 + mb * BLOCK_M * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
+      } else {
+        // fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
+        tinygemm_kernel(
+            /* A   */ A,
+            /* B0  */ B0,
+            /* B1  */ B1,
+            /* C   */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
+            /* As  */ As,
+            /* Bs0 */ Bs0,
+            /* Bs1 */ Bs1,
+            /* M   */ m_size,
+            /* N   */ n_size,
+            /* K   */ K,
+            /* lda */ K,
+            /* ldb */ n_size,
+            /* ldc */ N);
+      }
+    });
+    if (use_brgemm) {
+      at::native::cpublas::brgemm_release();
+    }
   });
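In the brgemm branch above, the two raw int32 tiles C0 and C1 are handed to silu_and_mul together with the activation scale As, the weight-scale slices Bs0/Bs1, and the compensation rows Bcomp0/Bcomp1. A hedged scalar sketch of that epilogue is below; it assumes the usual gate-first convention (SiLU on the C0 half, multiply by the C1 half) and omits the compensation terms, whose exact handling is not visible in this hunk. silu_and_mul_sketch is a hypothetical name, not the kernel in the patch.

// Hedged scalar sketch of the fused dequant + SiLU + mul epilogue.
#include <cstdint>
#include <cmath>

template <typename scalar_t>
void silu_and_mul_sketch(
    scalar_t* out, const int32_t* C0, const int32_t* C1, const float* As,
    const float* Bs0, const float* Bs1, int64_t m_size, int64_t n_size,
    int64_t ld_c, int64_t ld_out) {
  for (int64_t m = 0; m < m_size; ++m) {
    for (int64_t n = 0; n < n_size; ++n) {
      const float gate = static_cast<float>(C0[m * ld_c + n]) * As[m] * Bs0[n];
      const float up   = static_cast<float>(C1[m * ld_c + n]) * As[m] * Bs1[n];
      const float silu = gate / (1.0f + std::exp(-gate));  // x * sigmoid(x)
      out[m * ld_out + n] = static_cast<scalar_t>(silu * up);
    }
  }
}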
...
@@ -763,16 +982,13 @@ void shared_expert_int8_kernel_impl(
   const int64_t stride_oc = packed_N;
   // parallel on [MB2, NB2]
-  at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
+  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
     // get local pointers
-    int tid = at::get_thread_num();
+    int tid = get_thread_num();
+    // we won't be using C1 for gemm2
     float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
+    int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);
-    for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB2;
-      int64_t nb = i % NB2;
+    loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
       int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
...
@@ -784,19 +1000,37 @@ void shared_expert_int8_kernel_impl(
       const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
       const float* __restrict__ Bs = w2s + nb * BLOCK_N;
-      // 2.a gemm: C = A @ B
-      tinygemm_kernel<scalar_t>(
-          /* A   */ A,
-          /* B   */ B,
-          /* C   */ C,
-          /* As  */ As,
-          /* Bs  */ Bs,
-          /* M   */ m_size,
-          /* N   */ n_size,
-          /* K   */ IC,
-          /* lda */ IC,
-          /* ldb */ n_size,
-          /* ldc */ BLOCK_N);
+      if (use_brgemm) {
+        at::native::cpublas::brgemm(
+            /* M     */ m_size,
+            /* N     */ n_size,
+            /* K     */ IC,
+            /* lda   */ IC,
+            /* ldb   */ n_size,
+            /* ldc   */ BLOCK_N,
+            /* add_C */ false,
+            /* A     */ A,
+            /* B     */ B,
+            /* C     */ C32);
+        // apply scales
+        const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
+        scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
+      } else {
+        // 2.a gemm: C = A @ B
+        tinygemm_kernel<scalar_t>(
+            /* A   */ A,
+            /* B   */ B,
+            /* C   */ C,
+            /* As  */ As,
+            /* Bs  */ Bs,
+            /* M   */ m_size,
+            /* N   */ n_size,
+            /* K   */ IC,
+            /* lda */ IC,
+            /* ldb */ n_size,
+            /* ldc */ BLOCK_N);
+      }
       // 2.b copy from C to output and add fused_experts_out
       scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
...
@@ -804,6 +1038,10 @@ void shared_expert_int8_kernel_impl(
       for (int64_t m = 0; m < m_size; ++m) {
         add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
       }
-    }
+    });
+    if (use_brgemm) {
+      at::native::cpublas::brgemm_release();
+    }
   });
 }
...
sgl-kernel/csrc/cpu/qkv_proj.cpp
View file @ 5ad296bd
...
@@ -100,8 +100,7 @@ void segment_gemm_kernel_impl(
   const int64_t NB1 = div_up(N1, BLOCK_N);
   const int64_t NB = NB0 + NB1;
-  // TODO: brgemm u8s8 depends on PyTorch 2.7 release.
-  const bool use_brgemm = false;
+  const bool use_brgemm = can_use_brgemm<int8_t>(M);
   // K + 4 after compensation
   const int64_t packed_row_size = get_row_size<int8_t>(K);
...