Commit 4ba1eea8 (Unverified)
Authored May 23, 2025 by blzheng; committed via GitHub on May 23, 2025

Add fp8 qkv_proj_with_rope kernel for CPU in sgl-kernel and add UT (#6493)
parent 4685fbb8
Showing 5 changed files with 483 additions and 11 deletions (+483, -11)
  sgl-kernel/csrc/cpu/qkv_proj.cpp              +124   -2
  sgl-kernel/csrc/cpu/torch_extension_cpu.cpp     +9   -5
  test/srt/cpu/test_decode.py                     +2   -2
  test/srt/cpu/test_extend.py                     +2   -2
  test/srt/cpu/test_qkv_proj_with_rope.py       +346   -0
sgl-kernel/csrc/cpu/qkv_proj.cpp
@@ -152,6 +152,85 @@ void segment_gemm_kernel_impl(
  });
}

// [C0, C1] = A @ [B0, B1]
template <typename scalar_t>
void segment_gemm_kernel_impl(
    scalar_t* __restrict__ C0,
    scalar_t* __restrict__ C1,
    const scalar_t* __restrict__ A,
    const at::Float8_e4m3fn* __restrict__ B0,
    const at::Float8_e4m3fn* __restrict__ B1,
    const float* __restrict__ Bs0,
    const float* __restrict__ Bs1,
    int64_t M,
    int64_t N0,
    int64_t N1,
    int64_t K,
    int64_t block_size_N,
    int64_t block_size_K) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB0 = div_up(N0, BLOCK_N);
  const int64_t NB1 = div_up(N1, BLOCK_N);
  const int64_t NB = NB0 + NB1;
  const int64_t scale_size_K = div_up(K, block_size_K);
  const int64_t blocks_n_per_group = block_size_N / BLOCK_N;

  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);

  // parallel on [MB, NB0 + NB1]
  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
    int64_t mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);

    // for brgemm, use float32 for accumulate
    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
    // for brgemm when mat2 is float8_e4m3
    alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];

    for (int64_t i = begin; i < end; ++i) {
      UNUSED(i);
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = BLOCK_N;

      const at::Float8_e4m3fn* __restrict__ B = nb < NB0 ? B0 : B1;
      const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
      scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
      int64_t ldc = nb < NB0 ? N0 : N1;
      int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
      int64_t new_nb = nb < NB0 ? nb : nb - NB0;

      tinygemm_kernel<scalar_t>(
          /* A   */ A + mb_start * K,
          /* B   */ B + local_nb_start * K /* nb * BLOCK_N * K */,
          /* C   */ C + mb_start * ldc + local_nb_start,
          /* Btmp*/ Btmp,
          /* Ctmp*/ Ctmp,
          /* Bs  */ Bs + (new_nb / blocks_n_per_group) * scale_size_K,
          /* M   */ mb_size,
          /* N   */ nb_size,
          /* K   */ K,
          /* lda */ K,
          /* ldb */ nb_size,
          /* ldc */ ldc,
          /* brg */ use_brgemm,
          /* block_size_K */ block_size_K);

      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}

template <typename scalar_t>
inline float reduce(const scalar_t* __restrict__ x, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
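For reference, a minimal PyTorch sketch of the math this new overload computes, assuming (as the pointer arithmetic suggests) row-major [N, K] fp8 weights and [ceil(N/block_size_N), ceil(K/block_size_K)] float32 scale tensors. The helper names below are invented for illustration, the sketch ignores the kernel's tiling and brgemm details, and it needs a PyTorch build with float8_e4m3fn support.

import torch

def blockwise_dequant(w_fp8, scale, block_n, block_k):
    # w_fp8: [N, K] float8_e4m3fn; scale: [ceil(N/block_n), ceil(K/block_k)] float32
    n, k = w_fp8.shape
    s = scale.repeat_interleave(block_n, dim=0)[:n]
    s = s.repeat_interleave(block_k, dim=1)[:, :k]
    return w_fp8.to(torch.float32) * s

def segment_gemm_ref(a, b0, bs0, b1, bs1, block_n, block_k):
    # [C0, C1] = A @ [B0, B1]: one pass over the activations A produces both
    # outputs, each with its own blockwise-scaled fp8 weight.
    c0 = a.to(torch.float32) @ blockwise_dequant(b0, bs0, block_n, block_k).t()
    c1 = a.to(torch.float32) @ blockwise_dequant(b1, bs1, block_n, block_k).t()
    return c0.to(a.dtype), c1.to(a.dtype)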
@@ -321,6 +400,15 @@ extern at::Tensor int8_scaled_mm_with_quant(
extern void bmm_cpu(
    at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);

extern at::Tensor fp8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    std::vector<int64_t> block_size,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

// NB: shapes in DeepDeek R1
//
// hidden_states : [num_seqs, hidden_size] [1, 7168]
@@ -343,10 +431,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    std::optional<at::Tensor> q_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    std::optional<at::Tensor> kv_a_proj_scale,
-    bool is_vnni) {
+    bool is_vnni,
+    std::optional<std::vector<int64_t>> block_size) {
  RECORD_FUNCTION(
      "sgl-kernel::qkv_proj_with_rope",
      std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));
@@ -394,7 +484,13 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
TORCH_CHECK
(
q_b_proj_scale
.
has_value
(),
"missing q_b_proj_scale for int8 w8a8."
);
TORCH_CHECK
(
kv_a_proj_scale
.
has_value
(),
"missing kv_a_proj_scale for int8 w8a8."
);
}
if
(
use_fp8_w8a16
)
{
TORCH_CHECK
(
q_a_proj_scale
.
has_value
(),
"missing q_a_proj_scale for fp8 w8a16."
);
TORCH_CHECK
(
q_b_proj_scale
.
has_value
(),
"missing q_b_proj_scale for fp8 w8a16."
);
TORCH_CHECK
(
kv_a_proj_scale
.
has_value
(),
"missing kv_a_proj_scale for fp8 w8a16."
);
TORCH_CHECK
(
block_size
.
has_value
(),
"missing block_size for fp8 w8a16."
);
TORCH_CHECK
(
block_size
.
value
().
size
()
==
2
,
"block_size should be 2D for fp8 w8a16."
);
}
// outputs and temp buffer
const
auto
options
=
hidden_states
.
options
();
auto
q_input
=
at
::
empty
({
num_seqs
,
num_heads
,
kv_lora_rank
+
qk_rope_head_dim
},
options
);
...
...
@@ -436,6 +532,29 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
        q_lora_rank,
        kv_lora_rank + qk_rope_head_dim,
        hidden_size);
  } else if (use_fp8_w8a16) {
    int64_t block_size_N = block_size.value()[0];
    int64_t block_size_K = block_size.value()[1];
    auto q_a_proj_s = q_a_proj_scale.value();
    auto kv_a_proj_s = kv_a_proj_scale.value();
    CHECK_EQ(q_a_proj_s.size(0), div_up(q_lora_rank, block_size_N));
    CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K));
    CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N));
    CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K));
    segment_gemm_kernel_impl<scalar_t>(
        qa.data_ptr<scalar_t>(),
        k_input.data_ptr<scalar_t>(),
        hidden_states.data_ptr<scalar_t>(),
        q_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
        kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
        q_a_proj_s.data_ptr<float>(),
        kv_a_proj_s.data_ptr<float>(),
        num_seqs,
        q_lora_rank,
        kv_lora_rank + qk_rope_head_dim,
        hidden_size,
        block_size_N,
        block_size_K);
  } else {
    segment_gemm_kernel_impl<scalar_t>(
        qa.data_ptr<scalar_t>(),
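As a quick sanity check on the scale shapes the CHECK_EQ guards above expect, here is the arithmetic with the dimensions used by the new unit test (q_lora_rank=1536, hidden_size=7168, kv_lora_rank=512, qk_rope_head_dim=64, block_size=[128, 128]); plain Python, nothing from the kernel is required.

def div_up(a, b):
    return (a + b - 1) // b

block_n, block_k = 128, 128
q_lora_rank, hidden_size = 1536, 7168
kv_lora_rank, qk_rope_head_dim = 512, 64

# q_a_proj_scale must be [div_up(q_lora_rank, block_n), div_up(hidden_size, block_k)]
print(div_up(q_lora_rank, block_n), div_up(hidden_size, block_k))  # 12 56
# kv_a_proj_scale must be [div_up(kv_lora_rank + qk_rope_head_dim, block_n), div_up(hidden_size, block_k)]
print(div_up(kv_lora_rank + qk_rope_head_dim, block_n), div_up(hidden_size, block_k))  # 5 56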
@@ -469,6 +588,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
  std::optional<at::Tensor> bias;
  if (use_int8_w8a8) {
    qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni);
  } else if (use_fp8_w8a16) {
    qb = fp8_scaled_mm_cpu(
        qa, q_b_proj_weight, q_b_proj_scale.value(), block_size.value(), bias, at::kBFloat16, is_vnni);
  } else {
    qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni);
  }
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -165,10 +165,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    std::optional<at::Tensor> q_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    std::optional<at::Tensor> kv_a_proj_scale,
-    bool is_vnni);
+    bool is_vnni,
+    std::optional<std::vector<int64_t>> block_size);

// shared memory init
void initialize(int64_t size, int64_t rank);
@@ -209,8 +211,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // decode
  m.def(
-      "decode_attention_cpu(Tensor query, Tensor output, Tensor k_cache, Tensor v_cahce, Tensor attn_logits, Tensor "
-      "req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, float logit_cap) -> ()");
+      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor output, Tensor key, Tensor value, "
+      "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
+      "float logit_cap) -> ()");
  m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);

  // extend
@@ -265,8 +268,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.def(
      "qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
      "kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
-      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? "
-      "kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)");
+      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? q_a_proj_scale, Tensor? "
+      "q_b_proj_scale, Tensor? "
+      "kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
  m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);

  // shared expert
test/srt/cpu/test_decode.py
import unittest

+import sgl_kernel
import torch
-from sgl_kernel.common_ops import decode_attention_cpu as decode_attention
from torch.nn.functional import scaled_dot_product_attention

from sglang.test.test_utils import CustomTestCase
@@ -105,7 +105,7 @@ class TestDecodeAttention(CustomTestCase):
        v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
        key = key.transpose(0, 1).contiguous().transpose(0, 1)
        value = value.transpose(0, 1).contiguous().transpose(0, 1)
-        decode_attention(
+        torch.ops.sgl_kernel.decode_attention_cpu(
            q,
            k_buffer,
            v_buffer,
test/srt/cpu/test_extend.py
import unittest

+import sgl_kernel
import torch
-from sgl_kernel.common_ops import extend_attention_cpu as extend_attention
from torch.nn.functional import scaled_dot_product_attention

from sglang.test.test_utils import CustomTestCase
@@ -157,7 +157,7 @@ class TestExtendAttention(CustomTestCase):
        )
        o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)

-        extend_attention(
+        torch.ops.sgl_kernel.extend_attention_cpu(
            q_extend,
            k_extend,
            v_extend,
test/srt/cpu/test_qkv_proj_with_rope.py (new file, mode 100644)
import unittest

import sgl_kernel
import torch
from utils import (
    convert_weight,
    native_w8a8_per_token_matmul,
    per_token_quant_int8,
    precision,
)

from sglang.srt.layers.rotary_embedding import _apply_rotary_emb
from sglang.test.test_utils import CustomTestCase

convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope

torch.manual_seed(0)

# constants
kv_lora_rank = 512
qk_head_dim = 192
qk_nope_head_dim = 128
qk_rope_head_dim = 64
rotary_dim = qk_rope_head_dim
num_heads = 22
q_lora_rank = 1536
hidden_size = 7168
B = 1
eps = 1e-6

def layernorm(x, weight, variance_epsilon=1e-6, residual=None):
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + variance_epsilon)
    return (x * weight).to(orig_dtype)

def rotary_emb(q_pe, k_pe, pos, cos_sin_cache):
    orig_dtype = q_pe.dtype
    q_pe = q_pe.float()
    k_pe = k_pe.float()
    cos_sin_cache = cos_sin_cache.float()
    query_rot = q_pe[..., :rotary_dim]
    key_rot = k_pe[..., :rotary_dim]
    cos_sin = cos_sin_cache[pos]
    cos, sin = cos_sin.chunk(2, dim=-1)
    query_rot = _apply_rotary_emb(query_rot, cos, sin, False)
    key_rot = _apply_rotary_emb(key_rot, cos, sin, False)
    return query_rot.to(orig_dtype), key_rot.to(orig_dtype)

def native_torch(
    q_input,
    hidden_states,
    q_a_proj_weight,
    norm_weight1,
    q_b_proj_weight,
    w_kc,
    kv_a_proj_weight,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    q = torch.matmul(hidden_states, q_a_proj_weight.t())
    q = layernorm(q, norm_weight1)
    q = torch.matmul(q, q_b_proj_weight.t()).view(-1, num_heads, qk_head_dim)
    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)
    q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
    latent_cache = torch.matmul(hidden_states, kv_a_proj_weight.t())
    v_input = latent_cache[..., :kv_lora_rank]
    v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
    k_input = latent_cache.unsqueeze(1)
    k_input[..., :kv_lora_rank] = v_input
    k_pe = k_input[..., kv_lora_rank:]
    q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
    q_input[..., kv_lora_rank:] = q_pe
    k_input[..., kv_lora_rank:] = k_pe
    return q_input, k_input, v_input

def native_torch_int8(
    q_input,
    hidden_states,
    w1_q,
    w1_s,
    norm_weight1,
    w2_q,
    w2_s,
    w_kc,
    w3_q,
    w3_s,
    norm_weight2,
    pos,
    cos_sin_cache,
):
    a_q, a_s = per_token_quant_int8(hidden_states)
    q = native_w8a8_per_token_matmul(a_q, w1_q, a_s, w1_s, None, torch.bfloat16)
    q = layernorm(q, norm_weight1)
    a_q, a_s = per_token_quant_int8(q)
    q = native_w8a8_per_token_matmul(a_q, w2_q, a_s, w2_s, None, torch.bfloat16).view(
        -1, num_heads, qk_head_dim
    )
    q_nope, q_pe = q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1)
    q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc)
    q_input[..., :kv_lora_rank] = q_nope_out.transpose(0, 1)
    a_q, a_s = per_token_quant_int8(hidden_states)
    latent_cache = native_w8a8_per_token_matmul(
        a_q, w3_q, a_s, w3_s, None, torch.bfloat16
    )
    v_input = latent_cache[..., :kv_lora_rank]
    v_input = layernorm(v_input.contiguous(), norm_weight2).unsqueeze(1)
    k_input = latent_cache.unsqueeze(1)
    k_input[..., :kv_lora_rank] = v_input
    k_pe = k_input[..., kv_lora_rank:]
    q_pe, k_pe = rotary_emb(q_pe, k_pe, pos, cos_sin_cache)
    q_input[..., kv_lora_rank:] = q_pe
    k_input[..., kv_lora_rank:] = k_pe
    return q_input, k_input, v_input

class TestQKVProjWithROPE(CustomTestCase):
    def test_bf16_qkv_proj_with_rope(self):
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)

        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight,
            norm_weight1,
            q_b_proj_weight,
            w_kc.transpose(1, 2),
            kv_a_proj_weight,
            norm_weight2,
            pos,
            cos_sin_cache,
        )

        qa_packed = convert_weight_packed(q_a_proj_weight)
        qb_packed = convert_weight_packed(q_b_proj_weight)
        kva_packed = convert_weight_packed(kv_a_proj_weight)
        wkc_packed = convert_weight_packed(w_kc)

        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            qa_packed,
            qb_packed,
            kva_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            False,
            None,
            None,
            None,
            True,
            None,
        )

        atol = rtol = precision[q_ref.dtype]
        self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
        self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
        self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))

    def test_int8_qkv_proj_with_rope(self):
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)

        w1_q, w1_s = per_token_quant_int8(q_a_proj_weight)
        w2_q, w2_s = per_token_quant_int8(q_b_proj_weight)
        w3_q, w3_s = per_token_quant_int8(kv_a_proj_weight)

        q_ref, k_ref, v_ref = native_torch_int8(
            q_input,
            hidden_states,
            w1_q,
            w1_s,
            norm_weight1,
            w2_q,
            w2_s,
            w_kc.transpose(1, 2),
            w3_q,
            w3_s,
            norm_weight2,
            pos,
            cos_sin_cache,
        )

        w1_q_packed = convert_weight_packed(w1_q)
        w2_q_packed = convert_weight_packed(w2_q)
        w3_q_packed = convert_weight_packed(w3_q)
        wkc_packed = convert_weight_packed(w_kc)

        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            w1_q_packed,
            w2_q_packed,
            w3_q_packed,
            wkc_packed,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            True,
            False,
            w1_s,
            w2_s,
            w3_s,
            True,
            None,
        )

        atol = rtol = precision[q_ref.dtype]
        self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
        self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
        self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))

    def test_fp8_qkv_proj_with_rope(self):
        dtype = torch.bfloat16
        hidden_states = torch.randn(B, hidden_size, dtype=dtype) / hidden_size
        q_input = torch.empty(
            B, num_heads, kv_lora_rank + qk_rope_head_dim, dtype=dtype
        )
        q_a_proj_weight = torch.randn(q_lora_rank, hidden_size, dtype=dtype) * 0.1
        norm_weight1 = torch.randn(q_lora_rank, dtype=dtype)
        q_b_proj_weight = (
            torch.randn(num_heads * qk_head_dim, q_lora_rank, dtype=dtype) * 0.1
        )
        w_kc = torch.randn(num_heads, kv_lora_rank, qk_nope_head_dim, dtype=dtype) * 0.1
        kv_a_proj_weight = (
            torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
        )
        norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
        pos = torch.randint(10, 100, (B,))
        cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)

        scale_block_size_N = 128
        scale_block_size_K = 128
        fp8_q_a_proj_weight, q_a_proj_weight_scale_inv, q_a_proj_weight_dq = (
            convert_weight(
                q_a_proj_weight,
                [scale_block_size_N, scale_block_size_K],
                torch.bfloat16,
            )
        )
        fp8_q_b_proj_weight, q_b_proj_weight_scale_inv, q_b_proj_weight_dq = (
            convert_weight(
                q_b_proj_weight,
                [scale_block_size_N, scale_block_size_K],
                torch.bfloat16,
            )
        )
        (
            fp8_kv_a_proj_with_mqa_weight,
            kv_a_proj_with_mqa_weight_scale_inv,
            kv_a_proj_with_mqa_weight_dq,
        ) = convert_weight(
            kv_a_proj_weight, [scale_block_size_N, scale_block_size_K], torch.bfloat16
        )

        q_ref, k_ref, v_ref = native_torch(
            q_input,
            hidden_states,
            q_a_proj_weight_dq,
            norm_weight1,
            q_b_proj_weight_dq,
            w_kc.transpose(1, 2),
            kv_a_proj_with_mqa_weight_dq,
            norm_weight2,
            pos,
            cos_sin_cache,
        )

        fp8_q_a_proj_weight = convert_weight_packed(fp8_q_a_proj_weight)
        fp8_q_b_proj_weight = convert_weight_packed(fp8_q_b_proj_weight)
        fp8_kv_a_proj_with_mqa_weight = convert_weight_packed(
            fp8_kv_a_proj_with_mqa_weight
        )
        w_kc = convert_weight_packed(w_kc)

        q_out, k_out, v_out = qkv_proj_with_rope(
            hidden_states,
            fp8_q_a_proj_weight,
            fp8_q_b_proj_weight,
            fp8_kv_a_proj_with_mqa_weight,
            w_kc,
            norm_weight1,
            norm_weight2,
            pos,
            cos_sin_cache,
            eps,
            False,
            True,
            q_a_proj_weight_scale_inv.float(),
            q_b_proj_weight_scale_inv.float(),
            kv_a_proj_with_mqa_weight_scale_inv.float(),
            True,
            [scale_block_size_N, scale_block_size_K],
        )

        atol = rtol = precision[q_ref.dtype]
        self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
        self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
        self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))

if __name__ == "__main__":
    unittest.main()
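convert_weight is imported from the tests' local utils module, which is not part of this diff. Judging from how the fp8 test uses it, it quantizes a weight to blockwise-scaled float8_e4m3fn and returns the fp8 weight, the inverse scales, and a dequantized copy for the bf16 reference path. A rough, self-contained sketch of that behavior follows; it is an assumption for illustration, not the actual helper.

import torch

def convert_weight_sketch(weight, block_size, out_dtype=torch.bfloat16):
    # weight: [N, K]; block_size: [block_n, block_k]
    block_n, block_k = block_size
    n, k = weight.shape
    nb, kb = -(-n // block_n), -(-k // block_k)  # ceiling division
    w = weight.float()
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    w_fp8 = torch.empty(n, k, dtype=torch.float8_e4m3fn)
    scale_inv = torch.empty(nb, kb, dtype=torch.float32)
    w_dq = torch.empty_like(w)
    for i in range(nb):
        for j in range(kb):
            rows = slice(i * block_n, min((i + 1) * block_n, n))
            cols = slice(j * block_k, min((j + 1) * block_k, k))
            blk = w[rows, cols]
            # one scale per [block_n, block_k] tile, stored as a dequant multiplier
            s = blk.abs().amax().clamp(min=1e-12) / fp8_max
            w_fp8[rows, cols] = (blk / s).to(torch.float8_e4m3fn)
            scale_inv[i, j] = s
            w_dq[rows, cols] = w_fp8[rows, cols].float() * s
    return w_fp8, scale_inv, w_dq.to(out_dtype)

In the test above, the three return values feed convert_weight_packed, the op's *_scale arguments, and the native_torch reference, respectively.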