sglang · commit 9aea2555 (unverified)

Fuse writing KV buffer into rope kernel (part 1: sgl-kernel) (#9077)

Authored by fzyzcjy on Aug 12, 2025; committed via GitHub on Aug 12, 2025.
Parent commit: fcc11e5e

Showing 11 changed files with 1151 additions and 193 deletions.
python/sglang/srt/bench_utils.py                            +137  -0
sgl-kernel/benchmark/bench_rotary_embedding.py               +96  -0
sgl-kernel/csrc/common_extension.cc                           +2  -1
sgl-kernel/csrc/elementwise/pos_enc.cuh                     +431  -0
sgl-kernel/csrc/elementwise/rope.cu                          +99  -27
sgl-kernel/include/sgl_kernel_ops.h                           +5  -1
sgl-kernel/python/sgl_kernel/__init__.py                      +1  -0
sgl-kernel/python/sgl_kernel/elementwise.py                  +63  -5
sgl-kernel/python/sgl_kernel/testing/__init__.py              +0  -0
sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py    +217  -0
sgl-kernel/tests/test_rotary_embedding.py                   +100  -159
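
Taken together, these changes let the Python rope wrapper optionally write the rotated keys and the untouched values into the KV-cache buffers in the same kernel launch, instead of a separate set-kv-buffer pass. Below is a minimal sketch of the intended call pattern; it assumes sgl-kernel built from this commit and a CUDA device, and the pool size and tensor sizes are illustrative only (the real entry points, FusedSetKVBufferArg and apply_rope_with_cos_sin_cache_inplace, are defined in sgl-kernel/python/sgl_kernel/elementwise.py further down):

import torch

from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace

nnz, num_q_heads, num_kv_heads, head_size = 8, 32, 8, 64
device, dtype = "cuda", torch.bfloat16

positions = torch.arange(nnz, device=device)
query = torch.randn(nnz, num_q_heads * head_size, dtype=dtype, device=device)
key = torch.randn(nnz, num_kv_heads * head_size, dtype=dtype, device=device)
value = torch.randn(nnz, num_kv_heads * head_size, dtype=dtype, device=device)

# cos_sin_cache must be float32: first half cos, second half sin (rotary_dim = head_size here).
inv_freq = 1.0 / (8000 ** (torch.arange(0, head_size, 2, dtype=torch.float) / head_size))
t = torch.arange(4096, dtype=torch.float)
freqs = torch.einsum("i,j->ij", t, inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).to(device)

pool_slots = 1024  # hypothetical KV-pool size
k_buffer = torch.zeros(pool_slots, num_kv_heads * head_size, dtype=dtype, device=device)
v_buffer = torch.zeros_like(k_buffer)
cache_loc = torch.randperm(pool_slots, device=device)[:nnz]  # int64 slot indices

apply_rope_with_cos_sin_cache_inplace(
    positions=positions,
    query=query,
    key=key,
    head_size=head_size,
    cos_sin_cache=cos_sin_cache,
    is_neox=True,
    fused_set_kv_buffer_arg=FusedSetKVBufferArg(
        value=value,
        k_buffer=k_buffer,
        v_buffer=v_buffer,
        k_scale=None,
        v_scale=None,
        cache_loc=cache_loc,
    ),
)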
python/sglang/srt/bench_utils.py  (new file, +137 -0)

import os
import sys
from contextlib import nullcontext

import torch


# NOTE copied and modified from DeepGEMM
class suppress_stdout_stderr:
    def __enter__(self):
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()


# NOTE copied and modified from DeepGEMM
def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: str = None,
    flush_l2: bool = True,
    with_multiple_kernels: bool = False,
):
    # Conflict with Nsight Systems
    using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0))

    # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile
    suppress = (
        suppress_stdout_stderr
        if suppress_kineto_output and not using_nsys
        else nullcontext
    )
    with suppress():
        schedule = (
            torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
            if not using_nsys
            else None
        )
        profiler = (
            torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
            )
            if not using_nsys
            else nullcontext()
        )
        with profiler:
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(
                            flush_l2_size, dtype=torch.int, device="cuda"
                        ).zero_()
                    fn()

                if not using_nsys:
                    profiler.step()

    # Return 1 if using Nsight Systems
    if using_nsys:
        return 1

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)
    prof_lines = (
        profiler.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    if not with_multiple_kernels:
        for name in kernel_names:
            assert (
                sum([name in line for line in prof_lines]) == 1
            ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += (
                            float(time_str.replace(unit, "")) / scale * int(num_str)
                        )
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
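
For orientation, here is a minimal, self-contained sketch of how bench_kineto might be driven outside of the benchmark script below. It assumes a CUDA device; the "elementwise_kernel" substring is only a guess at what torch's profiler prints for a vectorized add, so with_multiple_kernels=True is passed to relax the exactly-one-match assertion:

import torch

from sglang.srt.bench_utils import bench_kineto

x = torch.randn(1 << 20, device="cuda")
y = torch.randn(1 << 20, device="cuda")

# Average CUDA time of the add kernel, returned in seconds.
t = bench_kineto(
    lambda: x + y,
    kernel_names="elementwise_kernel",  # assumed profiler name substring
    num_tests=30,
    with_multiple_kernels=True,
)
print(f"avg kernel time: {t * 1e6:.2f} us")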
sgl-kernel/benchmark/bench_rotary_embedding.py  (new file, +96 -0)

import itertools

import torch
import triton
from sgl_kernel import FusedSetKVBufferArg
from sgl_kernel.testing.rotary_embedding import (
    FlashInferRotaryEmbedding,
    MHATokenToKVPool,
    RotaryEmbedding,
    create_inputs,
)
from sglang.srt.bench_utils import bench_kineto

configs = [
    (batch_size, seq_len, save_kv_cache)
    for batch_size, seq_len in (
        (1, 1),
        (32, 1),
        (128, 1),
        (512, 1),
        (2, 512),
        (4, 4096),
    )
    for save_kv_cache in (False, True)
]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "save_kv_cache"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sglang"],
        line_names=["SGL Kernel"],
        styles=[("green", "-")],
        ylabel="us",
        plot_name="bench_rotary_embedding",
        args={},
    )
)
def benchmark(batch_size, seq_len, save_kv_cache, provider):
    device = torch.device("cuda")
    num_q_heads = 32
    num_kv_heads = 8
    head_size = 64
    dtype = torch.bfloat16

    config = dict(
        head_size=head_size,
        rotary_dim=64,
        max_position_embeddings=4096,
        base=8000,
        is_neox_style=True,
        dtype=dtype,
    )

    rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
    pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)

    inputs = create_inputs(
        head_size=head_size,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        dtype=dtype,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
    )
    query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()

    bench_fn = lambda: rope_flashinfer.forward_cuda(
        inputs["pos_ids"],
        query_flashinfer,
        key_flashinfer,
        fused_set_kv_buffer_arg=(
            FusedSetKVBufferArg(
                value=inputs["value"],
                k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
                v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
                k_scale=None,
                v_scale=None,
                cache_loc=inputs["out_cache_loc"],
            )
            if save_kv_cache
            else None
        ),
    )

    time_s = bench_kineto(bench_fn, kernel_names="BatchQKApplyRotaryPosIds")
    return time_s * 1e6


if __name__ == "__main__":
    benchmark.run(print_data=True)
sgl-kernel/csrc/common_extension.cc  (+2 -1)

@@ -89,7 +89,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
+      "Tensor pos_ids, bool interleave, int cuda_stream, "
+      "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
   /*
sgl-kernel/csrc/elementwise/pos_enc.cuh  (new file, +431 -0)

/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef SGL_POS_ENC_CUH_
#define SGL_POS_ENC_CUH_

#include <flashinfer/pos_enc.cuh>  // upstream

namespace flashinfer {

namespace kv_buffer_saver {

template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void prepare(
    vec_t<float, vec_size>& v_vec, IdType& kv_cache_offset, DType* v, IdType* kv_cache_loc,
    uint32_t idx, uint32_t tx, uint32_t kv_head_idx, size_t v_stride_n, size_t v_stride_h) {
  kv_cache_offset = kv_cache_loc[idx];

  DType* v_ptr = v + get_elem_offset_impl(idx, kv_head_idx, 0, v_stride_n, v_stride_h);
  v_vec.cast_load(v_ptr + tx * vec_size);
}

template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void save(
    IdType& kv_cache_offset, vec_t<float, vec_size>& k_vec, vec_t<float, vec_size>& v_vec,
    DType* k_buffer, DType* v_buffer, uint32_t idx, uint32_t tx, uint32_t kv_head_idx,
    size_t k_buffer_stride_n, size_t k_buffer_stride_h,
    size_t v_buffer_stride_n, size_t v_buffer_stride_h) {
  DType* k_buffer_ptr =
      k_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, k_buffer_stride_n, k_buffer_stride_h);
  DType* v_buffer_ptr =
      v_buffer + get_elem_offset_impl(kv_cache_offset, kv_head_idx, 0, v_buffer_stride_n, v_buffer_stride_h);
  k_vec.cast_store(k_buffer_ptr + tx * vec_size);
  v_vec.cast_store(v_buffer_ptr + tx * vec_size);
}

}  // namespace kv_buffer_saver

template <bool save_kv_cache, bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx,
          typename DType, typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel(
    DType* q, DType* k, DType* v, DType* q_rope, DType* k_rope, DType* k_buffer, DType* v_buffer,
    float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids, uint32_t nnz,
    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim,
    size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
    size_t v_stride_n, size_t v_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h,
    size_t k_rope_stride_n, size_t k_rope_stride_h, size_t k_buffer_stride_n,
    size_t k_buffer_stride_h, size_t v_buffer_stride_n, size_t v_buffer_stride_h,
    IdType* __restrict__ kv_cache_loc) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  uint32_t by = blockIdx.y;
  const uint32_t bdy = blockDim.y;
  vec_t<float, vec_size> cos, sin;
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];
    const int half_rotary_dim = rotary_dim / 2;

    // 1. if interleave:
    //    - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
    //    - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
    // 2. if not interleave
    //    - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
    //    - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
    if (tx * vec_size < rotary_dim) {
      int sin_offset = rotary_dim / 2;
      int vec_idx;
      if constexpr (interleave) {
        vec_idx = (tx * vec_size) / 2;  // Force integer division
      } else {
        vec_idx = (tx * vec_size) % half_rotary_dim;  // Use half_rotary_dim
      }
      cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
      sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
    }

    if (by < num_qo_heads) {
      uint32_t qo_head_idx = by;
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    } else {
      uint32_t kv_head_idx = by - num_qo_heads;
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);

      vec_t<float, vec_size> v_vec;
      IdType kv_cache_offset;
      if constexpr (save_kv_cache) {
        kv_buffer_saver::prepare<DType, IdType, vec_size>(
            v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
      }

      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);

      if constexpr (save_kv_cache) {
        kv_buffer_saver::save<DType, IdType, vec_size>(
            kv_cache_offset, k_vec, v_vec, k_buffer, v_buffer, idx, tx, kv_head_idx,
            k_buffer_stride_n, k_buffer_stride_h, v_buffer_stride_n, v_buffer_stride_h);
      }
    }
  }
}

template <bool save_kv_cache, bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx,
          typename DType, typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
    DType* q, DType* k, DType* v, DType* q_rope, DType* k_rope, DType* k_buffer, DType* v_buffer,
    float* __restrict__ cos_sin_cache, IdType* __restrict__ pos_ids, uint32_t nnz,
    uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim,
    size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
    size_t v_stride_n, size_t v_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h,
    size_t k_rope_stride_n, size_t k_rope_stride_h, size_t k_buffer_stride_n,
    size_t k_buffer_stride_h, size_t v_buffer_stride_n, size_t v_buffer_stride_h,
    IdType* __restrict__ kv_cache_loc) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  const uint32_t bdy = blockDim.y;
  vec_t<float, vec_size> cos, sin;
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];
    const int half_rotary_dim = rotary_dim / 2;

    // 1. if interleave:
    //    - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
    //    - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
    // 2. if not interleave
    //    - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
    //    - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
    if (tx * vec_size < rotary_dim) {
      int sin_offset = rotary_dim / 2;
      int vec_idx;
      if constexpr (interleave) {
        vec_idx = (tx * vec_size) / 2;  // Force integer division
      } else {
        vec_idx = (tx * vec_size) % half_rotary_dim;  // Use half_rotary_dim
      }
      cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
      sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
    }

// not to unroll the loop, because num head might be large and might lead to worse performance
#pragma unroll 1
    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr = q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    }

#pragma unroll 1
    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr = k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);

      vec_t<float, vec_size> v_vec;
      IdType kv_cache_offset;
      if constexpr (save_kv_cache) {
        kv_buffer_saver::prepare<DType, IdType, vec_size>(
            v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
      }

      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);

      if constexpr (save_kv_cache) {
        kv_buffer_saver::save<DType, IdType, vec_size>(
            kv_cache_offset, k_vec, v_vec, k_buffer, v_buffer, idx, tx, kv_head_idx,
            k_buffer_stride_n, k_buffer_stride_h, v_buffer_stride_n, v_buffer_stride_h);
      }
    }
  }
}

#define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \
  if (save_kv_cache) {                                            \
    const bool SAVE_KV_CACHE = true;                              \
    __VA_ARGS__                                                   \
  } else {                                                        \
    const bool SAVE_KV_CACHE = false;                             \
    __VA_ARGS__                                                   \
  }

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
    DType* q, DType* k, DType* v, DType* q_rope, DType* k_rope, DType* k_buffer, DType* v_buffer,
    float* cos_sin_cache, IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads,
    uint32_t num_kv_heads, uint32_t rotary_dim, uint32_t head_dim,
    size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
    size_t v_stride_n, size_t v_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h,
    size_t k_rope_stride_n, size_t k_rope_stride_h, size_t k_buffer_stride_n,
    size_t k_buffer_stride_h, size_t v_buffer_stride_n, size_t v_buffer_stride_h,
    IdType* kv_cache_loc, bool interleave, bool save_kv_cache, cudaStream_t stream = nullptr) {
  int dev_id = 0;
  int num_sms = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));

  DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, {
    DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
      DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
        // operate on 16 Bytes at a time
        constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
        // how many threads needed per head_dim
        constexpr uint32_t bdx = HEAD_DIM / vec_size;
        // how many threads needed per block
        uint32_t num_threads = std::max(128U, bdx);
        // how many tokens can we process in a block
        uint32_t bdy = num_threads / bdx;
        // how many blocks needed to process all tokens
        uint32_t nblks_x = (nnz + bdy - 1) / bdy;

        void* args[] = {(void*)&q, (void*)&k, (void*)&v, (void*)&q_rope, (void*)&k_rope,
                        (void*)&k_buffer, (void*)&v_buffer, (void*)&cos_sin_cache, (void*)&pos_ids,
                        (void*)&nnz, (void*)&num_qo_heads, (void*)&num_kv_heads, (void*)&rotary_dim,
                        (void*)&q_stride_n, (void*)&q_stride_h, (void*)&k_stride_n, (void*)&k_stride_h,
                        (void*)&v_stride_n, (void*)&v_stride_h, (void*)&q_rope_stride_n, (void*)&q_rope_stride_h,
                        (void*)&k_rope_stride_n, (void*)&k_rope_stride_h, (void*)&k_buffer_stride_n,
                        (void*)&k_buffer_stride_h, (void*)&v_buffer_stride_n, (void*)&v_buffer_stride_h,
                        (void*)&kv_cache_loc};

        auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel<
            SAVE_KV_CACHE, INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;

        int num_blocks_per_sm_0 = 0;
        FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
        uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;

        if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
          dim3 nblks(nblks_x);
          dim3 nthrs(bdx, bdy);
          FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
        } else {
          dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
          dim3 nthrs(bdx, bdy);
          auto kernel_1 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel<
              SAVE_KV_CACHE, INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
          FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
        }
      });
    });
  });
  return cudaSuccess;
}

}  // namespace flashinfer

#endif  // SGL_POS_ENC_CUH_
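
To make the launch-geometry arithmetic in BatchQKApplyRotaryPosIdsCosSinCacheEnhanced concrete, the same computation is spelled out below in Python for one assumed configuration (bf16, HEAD_DIM = 128, nnz = 4096 tokens). The constants mirror the C++ above; the concrete numbers are only an illustration:

import math

dtype_size = 2    # sizeof(nv_bfloat16)
head_dim = 128    # HEAD_DIM chosen by DISPATCH_HEAD_DIM
nnz = 4096        # number of tokens in the batch

vec_size = max(16 // dtype_size, head_dim // 32)  # operate on 16 bytes at a time -> 8
bdx = head_dim // vec_size                        # threads covering one head      -> 16
num_threads = max(128, bdx)                       # threads per block              -> 128
bdy = num_threads // bdx                          # tokens processed per block     -> 8
nblks_x = math.ceil(nnz / bdy)                    # blocks over the token axis     -> 512

print(vec_size, bdx, bdy, nblks_x)  # 8 16 8 512

When nblks_x is at least the number of resident CTAs reported by the occupancy query, the token-parallel kernel is launched on a one-dimensional grid; otherwise the head-parallel variant adds a second grid dimension of num_qo_heads + num_kv_heads blocks to keep the GPU busy for small batches.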
sgl-kernel/csrc/elementwise/rope.cu  (+99 -27)

@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <flashinfer/pos_enc.cuh>
+#include "pos_enc.cuh"
 #include "pytorch_extension_utils.h"

 using namespace flashinfer;

@@ -27,9 +27,37 @@ void apply_rope_pos_ids_cos_sin_cache(
The function now takes four trailing optional tensors; after the change this part reads:

    at::Tensor cos_sin_cache,
    at::Tensor pos_ids,
    bool interleave,
    int64_t cuda_stream,
    const std::optional<at::Tensor>& v,
    const std::optional<at::Tensor>& k_buffer,
    const std::optional<at::Tensor>& v_buffer,
    const std::optional<at::Tensor>& kv_cache_loc) {
  CHECK_LAST_DIM_CONTIGUOUS(q);
  CHECK_LAST_DIM_CONTIGUOUS(k);

  const bool save_kv_cache = v.has_value();
  if (save_kv_cache) {
    TORCH_CHECK(v.has_value());
    TORCH_CHECK(k_buffer.has_value());
    TORCH_CHECK(v_buffer.has_value());
    TORCH_CHECK(kv_cache_loc.has_value());
    CHECK_LAST_DIM_CONTIGUOUS(v.value());
    CHECK_LAST_DIM_CONTIGUOUS(k_buffer.value());
    CHECK_LAST_DIM_CONTIGUOUS(v_buffer.value());
    CHECK_DIM(3, k_buffer.value());      // k_buffer: (nnz, H_K, D)
    CHECK_DIM(3, v_buffer.value());      // v_buffer: (nnz, H_V, D)
    CHECK_DIM(3, v.value());             // v: (nnz, H_V, D)
    CHECK_DIM(1, kv_cache_loc.value());  // kv_cache_loc: (n)
    CHECK_INPUT(kv_cache_loc.value());
  }
  size_t k_buffer_stride_n = save_kv_cache ? k_buffer->stride(0) : 0;
  size_t k_buffer_stride_h = save_kv_cache ? k_buffer->stride(1) : 0;
  size_t v_buffer_stride_n = save_kv_cache ? v_buffer->stride(0) : 0;
  size_t v_buffer_stride_h = save_kv_cache ? v_buffer->stride(1) : 0;
  size_t v_stride_n = save_kv_cache ? v->stride(0) : 0;
  size_t v_stride_h = save_kv_cache ? v->stride(1) : 0;
  auto kv_cache_loc_ptr =
      save_kv_cache ? static_cast<int64_t*>(kv_cache_loc->data_ptr()) : nullptr;

  CHECK_INPUT(cos_sin_cache);
  CHECK_INPUT(pos_ids);
  auto device = q.device();

@@ -38,6 +66,7 @@ void apply_rope_pos_ids_cos_sin_cache(
  CHECK_EQ(pos_ids.device(), device);
  CHECK_DIM(3, q);  // q: (nnz, H_Q, D)
  CHECK_DIM(3, k);  // k: (nnz, H_K, D)
  // cos_sin_cache: (max_seq_len, R)
  // First half of R is cos, second half is sin
  CHECK_DIM(2, cos_sin_cache);

@@ -52,6 +81,7 @@ void apply_rope_pos_ids_cos_sin_cache(
  size_t q_stride_h = q.stride(1);
  size_t k_stride_n = k.stride(0);
  size_t k_stride_h = k.stride(1);
  size_t q_rope_stride_n = q_rope.stride(0);
  size_t q_rope_stride_h = q_rope.stride(1);
  size_t k_rope_stride_n = k_rope.stride(0);

@@ -59,31 +89,73 @@ void apply_rope_pos_ids_cos_sin_cache(
The dispatch body, which previously called BatchQKApplyRotaryPosIdsCosSinCache unconditionally, now branches on save_kv_cache:

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
    // TODO temporarily only use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` when save_kv_cache
    // to avoid changing original code path; but this branch is feature-complete and should switch to this later
    if (save_kv_cache) {
      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
          static_cast<c_type*>(q.data_ptr()),
          static_cast<c_type*>(k.data_ptr()),
          save_kv_cache ? static_cast<c_type*>(v->data_ptr()) : nullptr,
          static_cast<c_type*>(q_rope.data_ptr()),
          static_cast<c_type*>(k_rope.data_ptr()),
          save_kv_cache ? static_cast<c_type*>(k_buffer->data_ptr()) : nullptr,
          save_kv_cache ? static_cast<c_type*>(v_buffer->data_ptr()) : nullptr,
          static_cast<float*>(cos_sin_cache.data_ptr()),
          static_cast<int64_t*>(pos_ids.data_ptr()),
          nnz,
          num_qo_heads,
          num_kv_heads,
          rotary_dim,
          head_dim,
          q_stride_n,
          q_stride_h,
          k_stride_n,
          k_stride_h,
          v_stride_n,
          v_stride_h,
          q_rope_stride_n,
          q_rope_stride_h,
          k_rope_stride_n,
          k_rope_stride_h,
          k_buffer_stride_n,
          k_buffer_stride_h,
          v_buffer_stride_n,
          v_buffer_stride_h,
          kv_cache_loc_ptr,
          interleave,
          save_kv_cache,
          stream);
      TORCH_CHECK(
          status == cudaSuccess,
          "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
              std::string(cudaGetErrorString(status)));
    } else {
      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
          static_cast<c_type*>(q.data_ptr()),
          static_cast<c_type*>(k.data_ptr()),
          static_cast<c_type*>(q_rope.data_ptr()),
          static_cast<c_type*>(k_rope.data_ptr()),
          static_cast<float*>(cos_sin_cache.data_ptr()),
          static_cast<int64_t*>(pos_ids.data_ptr()),
          nnz,
          num_qo_heads,
          num_kv_heads,
          rotary_dim,
          head_dim,
          q_stride_n,
          q_stride_h,
          k_stride_n,
          k_stride_h,
          q_rope_stride_n,
          q_rope_stride_h,
          k_rope_stride_n,
          k_rope_stride_h,
          interleave,
          stream);
      TORCH_CHECK(
          status == cudaSuccess,
          "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " +
              std::string(cudaGetErrorString(status)));
    }
    return true;
  });
}
sgl-kernel/include/sgl_kernel_ops.h  (+5 -1)

@@ -150,7 +150,11 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor cos_sin_cache,
     at::Tensor pos_ids,
     bool interleave,
-    int64_t cuda_stream);
+    int64_t cuda_stream,
+    const std::optional<at::Tensor>& v,
+    const std::optional<at::Tensor>& k_buffer,
+    const std::optional<at::Tensor>& v_buffer,
+    const std::optional<at::Tensor>& kv_cache_loc);

 #ifdef USE_ROCM
 void gelu_quick(at::Tensor& out, const at::Tensor& input);
sgl-kernel/python/sgl_kernel/__init__.py  (+1 -0)

@@ -21,6 +21,7 @@ from sgl_kernel.attention import (
 )
 from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
 from sgl_kernel.elementwise import (
+    FusedSetKVBufferArg,
     apply_rope_with_cos_sin_cache_inplace,
     fused_add_rmsnorm,
     gelu_and_mul,
sgl-kernel/python/sgl_kernel/elementwise.py  (+63 -5)

The import block changes:

-from typing import Optional
+from dataclasses import dataclass
+from typing import Any, Optional

 import torch
 from sgl_kernel.utils import get_cuda_stream, is_hopper_arch

@@ -237,6 +238,31 @@ if torch.version.hip is not None:
A FusedSetKVBufferArg dataclass is added just before apply_rope_with_cos_sin_cache_inplace:

        return out


@dataclass
class FusedSetKVBufferArg:
    """
    value : Optional[torch.Tensor]
        Value tensor, shape: ``(nnz, num_v_heads * head_size)``.
    k_buffer : Optional[torch.Tensor]
        Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.
    v_buffer : Optional[torch.Tensor]
        Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.
    k_scale : Optional[float]
        Scale factor for keys.
    v_scale : Optional[float]
        Scale factor for values.
    cache_loc : Optional[torch.Tensor]
        Cache location tensor, used for indexing kv cache.
    """

    value: torch.Tensor
    k_buffer: torch.Tensor
    v_buffer: torch.Tensor
    k_scale: Optional[float]
    v_scale: Optional[float]
    cache_loc: torch.Tensor


def apply_rope_with_cos_sin_cache_inplace(
    positions: torch.Tensor,
    query: torch.Tensor,

@@ -244,6 +270,7 @@ def apply_rope_with_cos_sin_cache_inplace(
The signature gains an optional fused argument:

    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool = True,
    fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> None:
    r"""
    Apply rotary embedding to keys and queries with precomputed cos/sin values.

@@ -270,6 +297,9 @@ def apply_rope_with_cos_sin_cache_inplace(
The docstring documents it:

    * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
      we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
    fused_set_kv_buffer_arg : FusedSetKVBufferArg
        Fuse the set-kv-buffer operation into this kernel

    Note
    ----
    The rotary dimension is determined by the cosine cache and sine cache.

@@ -277,13 +307,41 @@ def apply_rope_with_cos_sin_cache_inplace(
The body validates the fused argument and forwards the extra tensors to the op:

    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

    if (a := fused_set_kv_buffer_arg) is not None:
        assert a.k_scale is None, "k_scale is not yet supported"
        assert a.v_scale is None, "v_scale is not yet supported"
        assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"

    def _view_3d(x):
        return x.view(x.shape[0], -1, head_size)

    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
        _view_3d(query),
        _view_3d(key),
        _view_3d(query),
        _view_3d(key),
        cos_sin_cache,
        positions.long(),
        (not is_neox),
        get_cuda_stream(),
        (
            _view_3d(fused_set_kv_buffer_arg.value)
            if fused_set_kv_buffer_arg is not None
            else None
        ),
        (
            _view_3d(fused_set_kv_buffer_arg.k_buffer)
            if fused_set_kv_buffer_arg is not None
            else None
        ),
        (
            _view_3d(fused_set_kv_buffer_arg.v_buffer)
            if fused_set_kv_buffer_arg is not None
            else None
        ),
        (
            fused_set_kv_buffer_arg.cache_loc
            if fused_set_kv_buffer_arg is not None
            else None
        ),
    )
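
As a quick reference for the contract the wrapper above enforces: buffers are passed flattened to (slots, num_heads * head_size) and re-viewed to 3-D internally, quantization scales must be None for now, and cache_loc must be int64. A minimal sketch with illustrative sizes:

import torch

from sgl_kernel import FusedSetKVBufferArg

num_kv_heads, head_size, pool_slots, nnz = 8, 64, 1024, 16
device, dtype = "cuda", torch.bfloat16

# The wrapper re-views these as (slots, num_heads, head_size) before calling the op.
k_buffer = torch.zeros(pool_slots, num_kv_heads * head_size, dtype=dtype, device=device)
v_buffer = torch.zeros_like(k_buffer)
value = torch.randn(nnz, num_kv_heads * head_size, dtype=dtype, device=device)

arg = FusedSetKVBufferArg(
    value=value,
    k_buffer=k_buffer,
    v_buffer=v_buffer,
    k_scale=None,  # quantization scales are not supported yet (asserted in the wrapper)
    v_scale=None,
    cache_loc=torch.randperm(pool_slots, device=device)[:nnz],  # int64 slot indices
)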
sgl-kernel/python/sgl_kernel/testing/__init__.py  (new, empty file)
sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py  (new file, +217 -0)

from typing import Any, Dict, List, Optional, Tuple, Union

import pytest
import torch
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace


# vLLM torch native
def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)


class RotaryEmbedding(torch.nn.Module):
    # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        inv_freq = 1.0 / (
            base
            ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)

        # Modification: float32 is required for the rotary embedding to work correctly
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)

        # Modification: convert to the correct dtype
        query = query.to(self.dtype)
        key = key.to(self.dtype)

        return query, key


class FlashInferRotaryEmbedding(RotaryEmbedding):
    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        apply_rope_with_cos_sin_cache_inplace(
            positions=positions,
            query=query,
            key=key,
            fused_set_kv_buffer_arg=fused_set_kv_buffer_arg,
            head_size=self.head_size,
            cos_sin_cache=self.cos_sin_cache,
            is_neox=self.is_neox_style,
        )
        return query, key


class MHATokenToKVPool:
    KV_POOL_SIZE = 16384

    def __init__(
        self,
        head_num: int,
        head_dim: int,
    ):
        self.head_num = head_num
        self.head_dim = head_dim
        self.size = MHATokenToKVPool.KV_POOL_SIZE
        self.page_size = 1
        self.store_dtype = torch.bfloat16
        self.device = "cuda"
        self.layer_num = 1
        self.start_layer = 0
        self._create_buffers()

    def _create_buffers(self):
        self.k_buffer = [
            torch.zeros(
                (self.size + self.page_size, self.head_num, self.head_dim),
                dtype=self.store_dtype,
                device=self.device,
            )
            for _ in range(self.layer_num)
        ]
        self.v_buffer = [
            torch.zeros(
                (self.size + self.page_size, self.head_num, self.head_dim),
                dtype=self.store_dtype,
                device=self.device,
            )
            for _ in range(self.layer_num)
        ]

    def set_kv_buffer(
        self,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        layer_id = 0
        self.k_buffer[layer_id - self.start_layer][loc] = cache_k
        self.v_buffer[layer_id - self.start_layer][loc] = cache_v


def create_inputs(
    head_size: int,
    batch_size: int,
    seq_len: int,
    device,
    dtype: torch.dtype,
    num_q_heads: int,
    num_kv_heads: int,
):
    pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
    query = torch.randn(
        batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
    )
    key = torch.randn(
        batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
    )
    value = torch.randn(
        batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
    )
    out_cache_loc = torch.randperm(
        MHATokenToKVPool.KV_POOL_SIZE, dtype=torch.int64, device=device
    )[: batch_size * seq_len].clone()
    return dict(
        pos_ids=pos_ids, query=query, key=key, value=value, out_cache_loc=out_cache_loc
    )
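
These helpers are shared by the new benchmark and the rewritten test. As a small usage sketch, assuming a CUDA device and illustrative sizes, the PyTorch-native reference path and the CUDA path can be compared directly (this mirrors the save_kv_cache=False case of the test below):

import torch

from sgl_kernel.testing.rotary_embedding import (
    FlashInferRotaryEmbedding,
    RotaryEmbedding,
    create_inputs,
)

config = dict(
    head_size=64, rotary_dim=64, max_position_embeddings=4096,
    base=8000, is_neox_style=True, dtype=torch.bfloat16,
)
rope_ref = RotaryEmbedding(**config).to("cuda")
rope_cuda = FlashInferRotaryEmbedding(**config).to("cuda")

inputs = create_inputs(
    head_size=64, batch_size=2, seq_len=16, device="cuda",
    dtype=torch.bfloat16, num_q_heads=32, num_kv_heads=8,
)
q_ref, k_ref = rope_ref.forward_native(
    inputs["pos_ids"], inputs["query"].clone(), inputs["key"].clone()
)
q_out, k_out = rope_cuda.forward_cuda(
    inputs["pos_ids"], inputs["query"].clone(), inputs["key"].clone()
)
torch.testing.assert_close(q_ref, q_out, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(k_ref, k_out, atol=1e-2, rtol=1e-2)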
sgl-kernel/tests/test_rotary_embedding.py  (+100 -159)

@@ -2,153 +2,51 @@ from typing import Any, Dict, List, Optional, Tuple, Union
The imports change, and the local copies of _apply_rotary_emb, RotaryEmbedding, and FlashInferRotaryEmbedding are deleted; they now live in sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py, shown above:

 import pytest
 import torch
-from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
+from sgl_kernel.testing.rotary_embedding import (
+    FlashInferRotaryEmbedding,
+    MHATokenToKVPool,
+    RotaryEmbedding,
+    create_inputs,
+)

The parametrization gains a save_kv_cache dimension:

@pytest.mark.parametrize(
    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads, save_kv_cache",
    [
        # GPT-OSS cases
        *[
            (
                64,
                64,
                4096,
                8000,
                True,
                torch.bfloat16,
                "cuda",
                batch_size,
                seq_len,
                64,
                8,
                save_kv_cache,
            )
            for batch_size, seq_len in (
                (1, 1),
                (32, 1),
                (128, 1),
                (512, 1),
                (2, 512),
                (4, 4096),
            )
            for save_kv_cache in (False, True)
        ],
        # Other cases
        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1, False),
        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2, False),
        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8, False),
        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4, False),
        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
    ],
)
def test_correctness(

@@ -163,34 +61,77 @@ def test_correctness(
The test body now builds its inputs with the shared helpers and, when save_kv_cache is set, checks the fused KV write against MHATokenToKVPool.set_kv_buffer:

    seq_len: int,
    num_q_heads: int,
    num_kv_heads: int,
    save_kv_cache: bool,
):
    config = dict(
        head_size=head_size,
        rotary_dim=rotary_dim,
        max_position_embeddings=max_position_embeddings,
        base=base,
        is_neox_style=is_neox_style,
        dtype=dtype,
    )
    rope_ref = RotaryEmbedding(**config).to(device)
    rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)

    inputs = create_inputs(
        head_size=head_size,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        dtype=dtype,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
    )

    if save_kv_cache:
        pool_ref = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
        pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)

    query_ref, key_ref = inputs["query"].clone(), inputs["key"].clone()
    query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()

    query_ref_out, key_ref_out = rope_ref.forward_native(
        inputs["pos_ids"], query_ref, key_ref
    )
    if save_kv_cache:
        pool_ref.set_kv_buffer(
            loc=inputs["out_cache_loc"],
            cache_k=key_ref_out.view(-1, num_kv_heads, head_size),
            cache_v=inputs["value"].view(-1, num_kv_heads, head_size),
        )

    query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
        inputs["pos_ids"],
        query_flashinfer,
        key_flashinfer,
        fused_set_kv_buffer_arg=(
            FusedSetKVBufferArg(
                value=inputs["value"],
                k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
                v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
                k_scale=None,
                v_scale=None,
                cache_loc=inputs["out_cache_loc"],
            )
            if save_kv_cache
            else None
        ),
    )

    torch.testing.assert_close(
        query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)

    if save_kv_cache:
        for field in ["k_buffer", "v_buffer"]:
            x_ref = getattr(pool_ref, field)[0]
            x_flashinfer = getattr(pool_flashinfer, field)[0]
            torch.testing.assert_close(x_ref, x_flashinfer, atol=1e-2, rtol=1e-2)

            nonzero_ref = x_ref != 0
            nonzero_flashinfer = x_flashinfer != 0
            assert torch.all(nonzero_ref == nonzero_flashinfer)


if __name__ == "__main__":