sglang / Commits / 13fb8b54

Unverified commit 13fb8b54, authored Oct 23, 2025 by blzheng, committed by GitHub Oct 22, 2025
Parent: 81fd2b0e

[CPU] Optimize FP16 decode_attention_cpu (#10652)
Showing 4 changed files with 181 additions and 9 deletions (+181 / -9):

    python/sglang/srt/layers/vocab_parallel_embedding.py    +4    -1
    sgl-kernel/csrc/cpu/decode.cpp                           +171  -2
    sgl-kernel/csrc/cpu/vec.h                                +1    -1
    test/srt/cpu/test_decode.py                              +5    -5
python/sglang/srt/layers/vocab_parallel_embedding.py

@@ -540,7 +540,10 @@ class ParallelLMHead(VocabParallelEmbedding):
         # We only support pack LMHead if it's not quantized.
         if _is_cpu and _is_cpu_amx_available:
-            if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+            if hasattr(self, "weight") and self.weight.dtype in [
+                torch.bfloat16,
+                torch.float16,
+            ]:
                 self.quant_method = PackWeightMethod(weight_names=["weight"])

         if bias:
sgl-kernel/csrc/cpu/decode.cpp

@@ -308,6 +308,93 @@ struct tinygemm_kernel_nt<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
 };
 #endif
 
+#if defined(CPU_CAPABILITY_AVX512)
+template <typename index_t, int BLOCK_M, int BLOCK_N>
+struct tinygemm_kernel_nt<at::Half, index_t, BLOCK_M, BLOCK_N> {
+  static inline void apply(
+      const at::Half* __restrict__ A,
+      const at::Half* __restrict__ B,
+      float* __restrict__ C,
+      const index_t* __restrict__ indices,
+      float scale,
+      int64_t lda,
+      int64_t ldb,
+      int64_t ldc,
+      int64_t K,
+      int64_t max_tokens) {
+    constexpr int ROWS = BLOCK_M;
+    constexpr int COLS = BLOCK_N;
+
+    __m512 va0, va1;
+    __m512 vb0[COLS], vb1[COLS];
+    __m512 vc[ROWS * COLS];
+    __m512 vscale = _mm512_set1_ps(scale);
+
+    auto loadc = [&](auto i) { vc[i] = _mm512_setzero_ps(); };
+    Unroll<ROWS * COLS>{}(loadc);
+
+    auto compute = [&](auto i, int64_t k) {
+      constexpr int row = i / COLS;
+      constexpr int col = i % COLS;
+
+      if constexpr (col == 0) {
+        __m512i a16 = _mm512_loadu_si512((__m512i const*)(A + row * lda + k));
+        va0 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
+        va1 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
+      }
+      if constexpr (row == 0) {
+        int64_t b_idx = indices[col];
+        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
+        __m512i b16 = _mm512_loadu_si512((__m512i const*)(B + b_idx * ldb + k));
+        vb0[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
+        vb1[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
+      }
+      vc[i] = _mm512_fmadd_ps(va0, vb0[col], _mm512_fmadd_ps(va1, vb1[col], vc[i]));
+    };
+
+    auto compute2 = [&](auto i, int64_t k, __mmask32 mask) {
+      constexpr int row = i / COLS;
+      constexpr int col = i % COLS;
+
+      if constexpr (col == 0) {
+        __m512i a16 = _mm512_maskz_loadu_epi16(mask, (const void*)(A + row * lda + k));
+        va0 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
+        va1 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
+      }
+      if constexpr (row == 0) {
+        int64_t b_idx = indices[col];
+        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
+        __m512i b16 = _mm512_maskz_loadu_epi16(mask, (const void*)(B + b_idx * ldb + k));
+        vb0[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
+        vb1[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
+      }
+      vc[i] = _mm512_fmadd_ps(va0, vb0[col], _mm512_fmadd_ps(va1, vb1[col], vc[i]));
+    };
+
+    int64_t k = 0;
+    for (; k <= K - 32; k += 32) {
+      Unroll<ROWS * COLS>{}(compute, k);
+    }
+    int64_t count = K - k;
+    if (count > 0) {
+      __mmask32 mask = (1ULL << count) - 1;
+      Unroll<ROWS * COLS>{}(compute2, k, mask);
+    }
+
+    auto storec = [&](auto i) {
+      constexpr int row = i / COLS;
+      constexpr int col = i % COLS;
+      C[row * ldc + col] = _mm512_reduce_add_ps(_mm512_mul_ps(vc[i], vscale));
+    };
+    Unroll<ROWS * COLS>{}(storec);
+  }
+};
+#endif
+
 #define LAUNCH_TINYGEMM_KERNEL_NT(MB_SIZE, NB_SIZE)               \
   tinygemm_kernel_nt<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
       A + mb_start * lda, B, C + mb_start * ldc + nb_start, indices + nb_start, scale, lda, ldb, ldc, K, max_tokens);
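In plain terms, the new at::Half NT specialization computes, for each output element C[m][n], the scaled FP32 dot product of row m of A with the row of B selected by indices[n]; every FP16 operand is widened with CVT_FP16_TO_FP32 before the fused multiply-add, the main loop consumes 32 FP16 elements per step, and the remainder is handled with a masked load whose mask is (1ULL << count) - 1. Below is a scalar reference of those semantics (an illustrative sketch, not part of the commit; the function name and layout comments are assumptions):

// Illustrative scalar reference for the FP16 tinygemm_kernel_nt path above
// (assumed semantics; no vectorization, unrolling, or tail masking).
#include <cstdint>

template <typename half_t, typename index_t>
void tinygemm_nt_reference(
    const half_t* A,        // [BLOCK_M, K], row-major with stride lda
    const half_t* B,        // token pool, row stride ldb; rows picked via indices
    float* C,               // [BLOCK_M, BLOCK_N], row-major with stride ldc
    const index_t* indices, // BLOCK_N row ids into B
    float scale,
    int64_t lda, int64_t ldb, int64_t ldc,
    int64_t K, int BLOCK_M, int BLOCK_N) {
  for (int m = 0; m < BLOCK_M; ++m) {
    for (int n = 0; n < BLOCK_N; ++n) {
      float acc = 0.f;
      const half_t* b_row = B + indices[n] * ldb;
      for (int64_t k = 0; k < K; ++k) {
        // FP16 inputs are widened to FP32 before the multiply-accumulate,
        // mirroring CVT_FP16_TO_FP32 in the vectorized kernel.
        acc += static_cast<float>(A[m * lda + k]) * static_cast<float>(b_row[k]);
      }
      // The scale is applied once at store time, as in the storec lambda.
      C[m * ldc + n] = acc * scale;
    }
  }
}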
@@ -443,6 +530,87 @@ struct tinygemm_kernel_nn<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
 };
 #endif
 
+#if defined(CPU_CAPABILITY_AVX512)
+template <typename index_t, int BLOCK_M, int BLOCK_N>
+struct tinygemm_kernel_nn<at::Half, index_t, BLOCK_M, BLOCK_N> {
+  static inline void apply(
+      const float* __restrict__ A,
+      const at::Half* __restrict__ B,
+      float* __restrict__ C,
+      const index_t* __restrict__ indices,
+      const float* __restrict__ scale,
+      int64_t lda,
+      int64_t ldb,
+      int64_t ldc,
+      int64_t K,
+      int64_t max_tokens) {
+    constexpr int ROWS = BLOCK_M;
+    constexpr int COLS = BLOCK_N / 16;
+
+    __m512 va;
+    __m512 vb[COLS];
+    __m512 vc[ROWS * COLS];
+    __m512 vscale;
+
+    auto loadc = [&](auto i) {
+      constexpr int row = i / COLS;
+      constexpr int col = i % COLS;
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Warray-bounds"
+      if constexpr (col == 0) {
+        vscale = _mm512_set1_ps(scale[row]);
+      }
+#pragma GCC diagnostic pop
+      vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
+      vc[i] = _mm512_mul_ps(vc[i], vscale);
+    };
+    Unroll<ROWS * COLS>{}(loadc);
+
+    auto compute = [&](auto i, int64_t k) {
+      constexpr int row = i / COLS;
+      constexpr int col = i % COLS;
+
+      if constexpr (col == 0) {
+        va = _mm512_set1_ps(A[row * lda + k]);
+      }
+      if constexpr (row == 0) {
+        if (k + 1 < K) {
+          int64_t b_idx_prefetch = indices[k + 1];
+          _mm_prefetch(B + b_idx_prefetch * ldb + col * 16, _MM_HINT_T0);
+        }
+        int64_t b_idx = indices[k];
+        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
+        // for COLS = 2, 4, 6, 8 use 512 bit load
+        // for COLS = 1, 3, 5, 7 use 256 bit load
+        if constexpr (COLS % 2 == 0) {
+          if constexpr (col % 2 == 0) {
+            __m512i b16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(B + b_idx * ldb + col * 16));
+            vb[col + 0] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
+            vb[col + 1] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
+          }
+        } else {
+          __m256i b16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(B + b_idx * ldb + col * 16));
+          vb[col] = CVT_FP16_TO_FP32(b16);
+        }
+      }
+      vc[i] = _mm512_fmadd_ps(va, vb[col], vc[i]);
+    };
+    for (int64_t k = 0; k < K; ++k) {
+      Unroll<ROWS * COLS>{}(compute, k);
+    }
+
+    auto storec = [&](auto i) {
+      constexpr int row = i / COLS;
+      constexpr int col = i % COLS;
+      _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
+    };
+    Unroll<ROWS * COLS>{}(storec);
+  }
+};
+#endif
+
 #define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)               \
   tinygemm_kernel_nn<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
       A + mb_start * lda,                                         \
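The new at::Half NN specialization accumulates into an existing FP32 block: each C[m][n] is first rescaled by the per-row factor scale[m] (the loadc lambda), then accumulates the sum over k of A[m][k] * B[indices[k]][n], with the indexed FP16 rows of B widened to FP32 and the next row prefetched one step ahead. A scalar reference of those semantics (an illustrative sketch, not from the commit; names are hypothetical):

// Illustrative scalar reference for the FP16 tinygemm_kernel_nn path above
// (assumed semantics; no prefetching, unrolling, or vectorization).
#include <cstdint>

template <typename half_t, typename index_t>
void tinygemm_nn_reference(
    const float* A,         // [BLOCK_M, K] FP32 weights, stride lda
    const half_t* B,        // FP16 rows, row stride ldb; row k picked via indices[k]
    float* C,               // [BLOCK_M, BLOCK_N] running FP32 accumulator, stride ldc
    const index_t* indices, // K row ids into B
    const float* scale,     // per-row rescale factor applied to the existing C
    int64_t lda, int64_t ldb, int64_t ldc,
    int64_t K, int BLOCK_M, int BLOCK_N) {
  for (int m = 0; m < BLOCK_M; ++m) {
    for (int n = 0; n < BLOCK_N; ++n) {
      // The accumulator already holds a partial result; it is rescaled first,
      // mirroring the loadc lambda (vc = C * vscale) in the vectorized kernel.
      float acc = C[m * ldc + n] * scale[m];
      for (int64_t k = 0; k < K; ++k) {
        acc += A[m * lda + k] * static_cast<float>(B[indices[k] * ldb + n]);
      }
      C[m * ldc + n] = acc;
    }
  }
}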
@@ -512,9 +680,10 @@ void index_gemm_kernel_nt(
     return;
   }
 
-  // pattern: 1-6-24
+  // default pattern: 1-6-24
+  // FP16 pattern: 2-8-16
   constexpr int64_t BLOCK_M = 4;
-  constexpr int64_t BLOCK_N = 6;
+  constexpr int64_t BLOCK_N = std::is_same_v<scalar_t, at::Half> ? 4 : 6;
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);
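The only behavioral change in this hunk is the tiling: BLOCK_M stays 4, while BLOCK_N is now selected at compile time, 4 when scalar_t is at::Half and 6 otherwise. A self-contained sketch of this selection style (illustrative only; the tag types below stand in for the real at::Half / at::BFloat16 dispatch):

// Minimal illustration (not from the commit) of compile-time block-size selection:
// scalar_t picks BLOCK_N, and div_up derives the block counts.
#include <cstdint>
#include <cstdio>
#include <type_traits>

struct half_tag {};   // stand-in for at::Half in this sketch
struct bf16_tag {};   // stand-in for at::BFloat16

constexpr int64_t div_up(int64_t a, int64_t b) { return (a + b - 1) / b; }

template <typename scalar_t>
void describe_blocking(int64_t M, int64_t N) {
  constexpr int64_t BLOCK_M = 4;
  constexpr int64_t BLOCK_N = std::is_same_v<scalar_t, half_tag> ? 4 : 6;
  printf("BLOCK_N=%lld  MB=%lld  NB=%lld\n",
         (long long)BLOCK_N,
         (long long)div_up(M, BLOCK_M),
         (long long)div_up(N, BLOCK_N));
}

int main() {
  describe_blocking<bf16_tag>(16, 1024);  // BLOCK_N=6, NB=171
  describe_blocking<half_tag>(16, 1024);  // BLOCK_N=4, NB=256
  return 0;
}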
sgl-kernel/csrc/cpu/vec.h

@@ -47,7 +47,7 @@ convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorize
 #define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
 
-#define CVT_FP16_TO_FP32(a) _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
+#define CVT_FP16_TO_FP32(a) _mm512_cvtph_ps(a)
 
 // this doesn't handle NaN.
 inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
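This one-line vec.h change fixes the conversion direction: _mm512_cvtph_ps widens 16 packed FP16 values held in a __m256i to 16 packed FP32 values in a __m512, which is what the FP16 kernels above need, whereas the previous _mm512_cvtps_ph converts FP32 to FP16, the opposite direction. A small stand-alone check of the corrected macro (an illustrative sketch, not part of the commit):

// Hypothetical stand-alone check of the fixed CVT_FP16_TO_FP32 macro.
// Requires an AVX-512F capable CPU and -mavx512f.
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

#define CVT_FP16_TO_FP32(a) _mm512_cvtph_ps(a)

int main() {
  // 1.0f in IEEE binary16 is 0x3C00; fill 16 lanes with it.
  alignas(32) uint16_t fp16_vals[16];
  for (int i = 0; i < 16; ++i) fp16_vals[i] = 0x3C00;

  __m256i packed = _mm256_load_si256(reinterpret_cast<const __m256i*>(fp16_vals));
  __m512 widened = CVT_FP16_TO_FP32(packed);  // FP16 -> FP32 widening

  alignas(64) float out[16];
  _mm512_store_ps(out, widened);
  printf("lane 0 = %f\n", out[0]);  // prints 1.000000
  return 0;
}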
test/srt/cpu/test_decode.py

@@ -59,8 +59,7 @@ class TestDecodeAttention(CustomTestCase):
         return output
 
-    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device):
-        dtype = torch.bfloat16
+    def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, dtype, device):
         # This represents the number of tokens already in the sequence
         seq_len = 1024
         total_tokens = B * seq_len

@@ -158,9 +157,10 @@ class TestDecodeAttention(CustomTestCase):
         ]
 
         for B, H_Q, H_KV, D, D_V in configs:
-            self._test_grouped_decode_attention_once(
-                B, H_Q, H_KV, D, D_V, device=device
-            )
+            for dtype in [torch.bfloat16, torch.float16]:
+                self._test_grouped_decode_attention_once(
+                    B, H_Q, H_KV, D, D_V, dtype=dtype, device=device
+                )
 
     def test_grouped_decode_attention(self):
         self._test_grouped_decode_attention("cpu")