zhaoyu6 / sglang · commit 9aea2555

Fuse writing KV buffer into rope kernel (part 1: sgl-kernel) (#9077)

Authored Aug 12, 2025 by fzyzcjy; committed by GitHub, Aug 12, 2025 (signature unverified).
Parent: fcc11e5e

Showing 11 changed files with 1151 additions and 193 deletions.
  python/sglang/srt/bench_utils.py                            +137  -0
  sgl-kernel/benchmark/bench_rotary_embedding.py               +96  -0
  sgl-kernel/csrc/common_extension.cc                           +2  -1
  sgl-kernel/csrc/elementwise/pos_enc.cuh                     +431  -0
  sgl-kernel/csrc/elementwise/rope.cu                          +99  -27
  sgl-kernel/include/sgl_kernel_ops.h                           +5  -1
  sgl-kernel/python/sgl_kernel/__init__.py                      +1  -0
  sgl-kernel/python/sgl_kernel/elementwise.py                  +63  -5
  sgl-kernel/python/sgl_kernel/testing/__init__.py              +0  -0
  sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py    +217  -0
  sgl-kernel/tests/test_rotary_embedding.py                   +100  -159
python/sglang/srt/bench_utils.py  (new file, mode 100644)

```python
import os
import sys
from contextlib import nullcontext

import torch


# NOTE copied and modified from DeepGEMM
class suppress_stdout_stderr:
    def __enter__(self):
        self.outnull_file = open(os.devnull, "w")
        self.errnull_file = open(os.devnull, "w")

        self.old_stdout_fileno_undup = sys.stdout.fileno()
        self.old_stderr_fileno_undup = sys.stderr.fileno()

        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
        self.old_stderr_fileno = os.dup(sys.stderr.fileno())

        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)

        sys.stdout = self.outnull_file
        sys.stderr = self.errnull_file
        return self

    def __exit__(self, *_):
        sys.stdout = self.old_stdout
        sys.stderr = self.old_stderr

        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)

        os.close(self.old_stdout_fileno)
        os.close(self.old_stderr_fileno)

        self.outnull_file.close()
        self.errnull_file.close()


# NOTE copied and modified from DeepGEMM
def bench_kineto(
    fn,
    kernel_names,
    num_tests: int = 30,
    suppress_kineto_output: bool = False,
    trace_path: str = None,
    flush_l2: bool = True,
    with_multiple_kernels: bool = False,
):
    # Conflict with Nsight Systems
    using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0))

    # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile
    suppress = (
        suppress_stdout_stderr
        if suppress_kineto_output and not using_nsys
        else nullcontext
    )
    with suppress():
        schedule = (
            torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
            if not using_nsys
            else None
        )
        profiler = (
            torch.profiler.profile(
                activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
            )
            if not using_nsys
            else nullcontext()
        )
        with profiler:
            for i in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(
                            flush_l2_size, dtype=torch.int, device="cuda"
                        ).zero_()
                    fn()

                if not using_nsys:
                    profiler.step()

    # Return 1 if using Nsight Systems
    if using_nsys:
        return 1

    # Parse the profiling table
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)
    prof_lines = (
        profiler.key_averages()
        .table(sort_by="cuda_time_total", max_name_column_width=100)
        .split("\n")
    )
    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
    assert all([isinstance(name, str) for name in kernel_names])
    if not with_multiple_kernels:
        for name in kernel_names:
            assert (
                sum([name in line for line in prof_lines]) == 1
            ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times
    units = {"ms": 1e3, "us": 1e6}
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name in line:
                time_str = line.split()[-2]
                num_str = line.split()[-1]
                for unit, scale in units.items():
                    if unit in time_str:
                        total_time += (
                            float(time_str.replace(unit, "")) / scale * int(num_str)
                        )
                        total_num += int(num_str)
                        break
        kernel_times.append(total_time / total_num)

    return tuple(kernel_times) if is_tuple else kernel_times[0]
```
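For orientation, here is a minimal usage sketch of `bench_kineto` (not part of the commit). The `"gemm"` substring is an assumption about how the underlying cuBLAS matmul kernel is named in the profiler table; in practice you would run `torch.profiler` once and copy the exact kernel name from its output:

```python
# Hypothetical usage sketch for bench_kineto; assumes a CUDA device.
# "gemm" is an assumed substring of the cuBLAS kernel name as it appears
# in torch.profiler's table -- replace it with the real name if it differs.
import torch

from sglang.srt.bench_utils import bench_kineto

x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# fn is called once up front (to trigger any autotuning prints), then
# num_tests times per profiler step, with an 8 GB memset flushing L2
# between calls; the return value is the mean kernel time in seconds.
avg_s = bench_kineto(lambda: x @ x, kernel_names="gemm", num_tests=10)
print(f"avg kernel time: {avg_s * 1e6:.2f} us")
```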
sgl-kernel/benchmark/bench_rotary_embedding.py  (new file, mode 100644)

```python
import itertools

import torch
import triton
from sgl_kernel import FusedSetKVBufferArg
from sgl_kernel.testing.rotary_embedding import (
    FlashInferRotaryEmbedding,
    MHATokenToKVPool,
    RotaryEmbedding,
    create_inputs,
)
from sglang.srt.bench_utils import bench_kineto

configs = [
    (batch_size, seq_len, save_kv_cache)
    for batch_size, seq_len in (
        (1, 1),
        (32, 1),
        (128, 1),
        (512, 1),
        (2, 512),
        (4, 4096),
    )
    for save_kv_cache in (False, True)
]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "save_kv_cache"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sglang"],
        line_names=["SGL Kernel"],
        styles=[("green", "-")],
        ylabel="us",
        plot_name="bench_rotary_embedding",
        args={},
    )
)
def benchmark(batch_size, seq_len, save_kv_cache, provider):
    device = torch.device("cuda")
    num_q_heads = 32
    num_kv_heads = 8
    head_size = 64
    dtype = torch.bfloat16

    config = dict(
        head_size=head_size,
        rotary_dim=64,
        max_position_embeddings=4096,
        base=8000,
        is_neox_style=True,
        dtype=dtype,
    )

    rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
    pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)

    inputs = create_inputs(
        head_size=head_size,
        batch_size=batch_size,
        seq_len=seq_len,
        device=device,
        dtype=dtype,
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
    )
    query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()

    bench_fn = lambda: rope_flashinfer.forward_cuda(
        inputs["pos_ids"],
        query_flashinfer,
        key_flashinfer,
        fused_set_kv_buffer_arg=(
            FusedSetKVBufferArg(
                value=inputs["value"],
                k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
                v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
                k_scale=None,
                v_scale=None,
                cache_loc=inputs["out_cache_loc"],
            )
            if save_kv_cache
            else None
        ),
    )

    time_s = bench_kineto(bench_fn, kernel_names="BatchQKApplyRotaryPosIds")
    return time_s * 1e6


if __name__ == "__main__":
    benchmark.run(print_data=True)
```
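Note that `kernel_names="BatchQKApplyRotaryPosIds"` is a substring match against the profiler table, so it times both the pre-existing `BatchQKApplyRotaryPosIdsCosSinCache` kernel and the new `...Enhanced` variants introduced below. `bench_kineto` returns seconds, which the function converts to the microseconds promised by `ylabel`; the sweep covers each (batch_size, seq_len) point once with and once without the fused KV-buffer write, so the cost of fusing can be read off the printed table directly by running the file with `python bench_rotary_embedding.py`.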
sgl-kernel/csrc/common_extension.cc

```diff
@@ -89,7 +89,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
+      "Tensor pos_ids, bool interleave, int cuda_stream, "
+      "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
   /*
```
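In the operator schema above, `Tensor!` marks an argument that the kernel mutates in place, and a trailing `?` marks it as optional, so `Tensor!?` reads as an optional tensor that is written to when present. That is exactly the contract for `k_buffer` and `v_buffer`: they are only touched when the fused KV-cache write is requested.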
sgl-kernel/csrc/elementwise/pos_enc.cuh  (new file, mode 100644)

```cuda
/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef SGL_POS_ENC_CUH_
#define SGL_POS_ENC_CUH_

#include <flashinfer/pos_enc.cuh>  // upstream

namespace flashinfer {

namespace kv_buffer_saver {

// Read the (untouched) V vector for this token/head and resolve where the
// token lands in the KV cache, before K is rotated.
template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void prepare(
    vec_t<float, vec_size>& v_vec,
    IdType& kv_cache_offset,
    DType* v,
    IdType* kv_cache_loc,
    uint32_t idx,
    uint32_t tx,
    uint32_t kv_head_idx,
    size_t v_stride_n,
    size_t v_stride_h) {
  kv_cache_offset = kv_cache_loc[idx];
  DType* v_ptr = v + get_elem_offset_impl(idx, kv_head_idx, 0, v_stride_n, v_stride_h);
  v_vec.cast_load(v_ptr + tx * vec_size);
}

// Write the rotated K vector and the V vector into the KV buffers at the
// cache slot resolved by `prepare`.
template <typename DType, typename IdType, uint32_t vec_size>
__device__ __forceinline__ void save(
    IdType& kv_cache_offset,
    vec_t<float, vec_size>& k_vec,
    vec_t<float, vec_size>& v_vec,
    DType* k_buffer,
    DType* v_buffer,
    uint32_t idx,
    uint32_t tx,
    uint32_t kv_head_idx,
    size_t k_buffer_stride_n,
    size_t k_buffer_stride_h,
    size_t v_buffer_stride_n,
    size_t v_buffer_stride_h) {
  DType* k_buffer_ptr = k_buffer + get_elem_offset_impl(
      kv_cache_offset, kv_head_idx, 0, k_buffer_stride_n, k_buffer_stride_h);
  DType* v_buffer_ptr = v_buffer + get_elem_offset_impl(
      kv_cache_offset, kv_head_idx, 0, v_buffer_stride_n, v_buffer_stride_h);
  k_vec.cast_store(k_buffer_ptr + tx * vec_size);
  v_vec.cast_store(v_buffer_ptr + tx * vec_size);
}

}  // namespace kv_buffer_saver

// Head-parallel variant: blockIdx.y selects one Q head (by < num_qo_heads)
// or one KV head (by >= num_qo_heads).
template <bool save_kv_cache, bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx,
          typename DType, typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel(
    DType* q,
    DType* k,
    DType* v,
    DType* q_rope,
    DType* k_rope,
    DType* k_buffer,
    DType* v_buffer,
    float* __restrict__ cos_sin_cache,
    IdType* __restrict__ pos_ids,
    uint32_t nnz,
    uint32_t num_qo_heads,
    uint32_t num_kv_heads,
    uint32_t rotary_dim,
    size_t q_stride_n, size_t q_stride_h,
    size_t k_stride_n, size_t k_stride_h,
    size_t v_stride_n, size_t v_stride_h,
    size_t q_rope_stride_n, size_t q_rope_stride_h,
    size_t k_rope_stride_n, size_t k_rope_stride_h,
    size_t k_buffer_stride_n, size_t k_buffer_stride_h,
    size_t v_buffer_stride_n, size_t v_buffer_stride_h,
    IdType* __restrict__ kv_cache_loc) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  uint32_t by = blockIdx.y;
  const uint32_t bdy = blockDim.y;
  vec_t<float, vec_size> cos, sin;
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];
    const int half_rotary_dim = rotary_dim / 2;

    // 1. if interleave:
    //    - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
    //    - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
    // 2. if not interleave:
    //    - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
    //    - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
    if (tx * vec_size < rotary_dim) {
      int sin_offset = rotary_dim / 2;
      int vec_idx;
      if constexpr (interleave) {
        vec_idx = (tx * vec_size) / 2;  // force integer division
      } else {
        vec_idx = (tx * vec_size) % half_rotary_dim;  // use half_rotary_dim
      }
      cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
      sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
    }

    if (by < num_qo_heads) {
      uint32_t qo_head_idx = by;
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr =
          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
            q_ptr, cos, sin, rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    } else {
      uint32_t kv_head_idx = by - num_qo_heads;
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr =
          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);

      vec_t<float, vec_size> v_vec;
      IdType kv_cache_offset;
      if constexpr (save_kv_cache) {
        kv_buffer_saver::prepare<DType, IdType, vec_size>(
            v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
      }

      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
            k_ptr, cos, sin, rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);

      if constexpr (save_kv_cache) {
        kv_buffer_saver::save<DType, IdType, vec_size>(
            kv_cache_offset, k_vec, v_vec, k_buffer, v_buffer, idx, tx, kv_head_idx,
            k_buffer_stride_n, k_buffer_stride_h, v_buffer_stride_n, v_buffer_stride_h);
      }
    }
  }
}

// Token-parallel variant: each (block, ty) handles one token and loops over
// all Q heads and KV heads.
template <bool save_kv_cache, bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx,
          typename DType, typename IdType>
__global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
    DType* q,
    DType* k,
    DType* v,
    DType* q_rope,
    DType* k_rope,
    DType* k_buffer,
    DType* v_buffer,
    float* __restrict__ cos_sin_cache,
    IdType* __restrict__ pos_ids,
    uint32_t nnz,
    uint32_t num_qo_heads,
    uint32_t num_kv_heads,
    uint32_t rotary_dim,
    size_t q_stride_n, size_t q_stride_h,
    size_t k_stride_n, size_t k_stride_h,
    size_t v_stride_n, size_t v_stride_h,
    size_t q_rope_stride_n, size_t q_rope_stride_h,
    size_t k_rope_stride_n, size_t k_rope_stride_h,
    size_t k_buffer_stride_n, size_t k_buffer_stride_h,
    size_t v_buffer_stride_n, size_t v_buffer_stride_h,
    IdType* __restrict__ kv_cache_loc) {
  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
  const uint32_t bdy = blockDim.y;
  vec_t<float, vec_size> cos, sin;
  if (bx * bdy + ty < nnz) {
    const uint32_t idx = bx * bdy + ty;
    const IdType pos = pos_ids[idx];
    const int half_rotary_dim = rotary_dim / 2;

    // 1. if interleave:
    //    - cos = cos_sin_cache[pos_id][tx * vec_size // 2]
    //    - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2]
    // 2. if not interleave:
    //    - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)]
    //    - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % (rot_dim // 2)]
    if (tx * vec_size < rotary_dim) {
      int sin_offset = rotary_dim / 2;
      int vec_idx;
      if constexpr (interleave) {
        vec_idx = (tx * vec_size) / 2;  // force integer division
      } else {
        vec_idx = (tx * vec_size) % half_rotary_dim;  // use half_rotary_dim
      }
      cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx);
      sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx));
    }

    // Do not unroll this loop: the number of heads can be large, and
    // unrolling may hurt performance.
#pragma unroll 1
    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
      DType* q_rope_ptr =
          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
      vec_t<float, vec_size> q_vec;
      if constexpr (interleave) {
        q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
            q_ptr, cos, sin, rotary_dim);
      } else {
        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin, rotary_dim);
      }
      q_vec.cast_store(q_rope_ptr + tx * vec_size);
    }

#pragma unroll 1
    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
      DType* k_rope_ptr =
          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);

      vec_t<float, vec_size> v_vec;
      IdType kv_cache_offset;
      if constexpr (save_kv_cache) {
        kv_buffer_saver::prepare<DType, IdType, vec_size>(
            v_vec, kv_cache_offset, v, kv_cache_loc, idx, tx, kv_head_idx, v_stride_n, v_stride_h);
      }

      vec_t<float, vec_size> k_vec;
      if constexpr (interleave) {
        k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half<vec_size, bdx>(
            k_ptr, cos, sin, rotary_dim);
      } else {
        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin, rotary_dim);
      }
      k_vec.cast_store(k_rope_ptr + tx * vec_size);

      if constexpr (save_kv_cache) {
        kv_buffer_saver::save<DType, IdType, vec_size>(
            kv_cache_offset, k_vec, v_vec, k_buffer, v_buffer, idx, tx, kv_head_idx,
            k_buffer_stride_n, k_buffer_stride_h, v_buffer_stride_n, v_buffer_stride_h);
      }
    }
  }
}

#define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \
  if (save_kv_cache) {                                            \
    const bool SAVE_KV_CACHE = true;                              \
    __VA_ARGS__                                                   \
  } else {                                                        \
    const bool SAVE_KV_CACHE = false;                             \
    __VA_ARGS__                                                   \
  }

template <typename DType, typename IdType>
cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
    DType* q,
    DType* k,
    DType* v,
    DType* q_rope,
    DType* k_rope,
    DType* k_buffer,
    DType* v_buffer,
    float* cos_sin_cache,
    IdType* pos_ids,
    uint32_t nnz,
    uint32_t num_qo_heads,
    uint32_t num_kv_heads,
    uint32_t rotary_dim,
    uint32_t head_dim,
    size_t q_stride_n, size_t q_stride_h,
    size_t k_stride_n, size_t k_stride_h,
    size_t v_stride_n, size_t v_stride_h,
    size_t q_rope_stride_n, size_t q_rope_stride_h,
    size_t k_rope_stride_n, size_t k_rope_stride_h,
    size_t k_buffer_stride_n, size_t k_buffer_stride_h,
    size_t v_buffer_stride_n, size_t v_buffer_stride_h,
    IdType* kv_cache_loc,
    bool interleave,
    bool save_kv_cache,
    cudaStream_t stream = nullptr) {
  int dev_id = 0;
  int num_sms = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));

  DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, {
    DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
      DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
        // operate on 16 bytes at a time
        constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
        // how many threads are needed per head_dim
        constexpr uint32_t bdx = HEAD_DIM / vec_size;
        // how many threads are needed per block
        uint32_t num_threads = std::max(128U, bdx);
        // how many tokens can we process in a block
        uint32_t bdy = num_threads / bdx;
        // how many blocks are needed to process all tokens
        uint32_t nblks_x = (nnz + bdy - 1) / bdy;

        void* args[] = {(void*)&q,
                        (void*)&k,
                        (void*)&v,
                        (void*)&q_rope,
                        (void*)&k_rope,
                        (void*)&k_buffer,
                        (void*)&v_buffer,
                        (void*)&cos_sin_cache,
                        (void*)&pos_ids,
                        (void*)&nnz,
                        (void*)&num_qo_heads,
                        (void*)&num_kv_heads,
                        (void*)&rotary_dim,
                        (void*)&q_stride_n,
                        (void*)&q_stride_h,
                        (void*)&k_stride_n,
                        (void*)&k_stride_h,
                        (void*)&v_stride_n,
                        (void*)&v_stride_h,
                        (void*)&q_rope_stride_n,
                        (void*)&q_rope_stride_h,
                        (void*)&k_rope_stride_n,
                        (void*)&k_rope_stride_h,
                        (void*)&k_buffer_stride_n,
                        (void*)&k_buffer_stride_h,
                        (void*)&v_buffer_stride_n,
                        (void*)&v_buffer_stride_h,
                        (void*)&kv_cache_loc};

        auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel<
            SAVE_KV_CACHE, INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
        int num_blocks_per_sm_0 = 0;
        FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0));
        uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms;

        if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
          dim3 nblks(nblks_x);
          dim3 nthrs(bdx, bdy);
          FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
        } else {
          dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
          dim3 nthrs(bdx, bdy);
          auto kernel_1 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel<
              SAVE_KV_CACHE, INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
          FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
        }
      });
    });
  });
  return cudaSuccess;
}

}  // namespace flashinfer

#endif  // SGL_POS_ENC_CUH_
```
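The launch-shape arithmetic in `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` is compact enough to restate; here is the same computation as a plain-Python sketch (my paraphrase, assuming a 2-byte dtype such as bf16/fp16):

```python
# Plain-Python mirror of the launch-shape math in
# BatchQKApplyRotaryPosIdsCosSinCacheEnhanced (a sketch, 2-byte dtypes).
def launch_shape(head_dim: int, nnz: int, dtype_bytes: int = 2):
    vec_size = max(16 // dtype_bytes, head_dim // 32)  # 16-byte vector accesses
    bdx = head_dim // vec_size        # threads covering one head
    num_threads = max(128, bdx)       # target block size
    bdy = num_threads // bdx          # tokens processed per block
    nblks_x = (nnz + bdy - 1) // bdy  # blocks needed to cover all tokens
    return vec_size, bdx, bdy, nblks_x

print(launch_shape(64, 4096))  # -> (8, 8, 16, 256)
```

The kernel choice then follows occupancy: if `nblks_x` alone already reaches the number of resident CTAs the device can hold, the token-parallel kernel (which loops over heads) is launched on a 1-D grid; otherwise the head-parallel variant spreads `num_qo_heads + num_kv_heads` across `blockIdx.y` to generate enough blocks.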
sgl-kernel/csrc/elementwise/rope.cu

```diff
@@ -13,8 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <flashinfer/pos_enc.cuh>
+#include "pos_enc.cuh"
 #include "pytorch_extension_utils.h"

 using namespace flashinfer;
@@ -27,9 +27,37 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor cos_sin_cache,
     at::Tensor pos_ids,
     bool interleave,
-    int64_t cuda_stream) {
+    int64_t cuda_stream,
+    const std::optional<at::Tensor>& v,
+    const std::optional<at::Tensor>& k_buffer,
+    const std::optional<at::Tensor>& v_buffer,
+    const std::optional<at::Tensor>& kv_cache_loc) {
   CHECK_LAST_DIM_CONTIGUOUS(q);
   CHECK_LAST_DIM_CONTIGUOUS(k);
+
+  const bool save_kv_cache = v.has_value();
+  if (save_kv_cache) {
+    TORCH_CHECK(v.has_value());
+    TORCH_CHECK(k_buffer.has_value());
+    TORCH_CHECK(v_buffer.has_value());
+    TORCH_CHECK(kv_cache_loc.has_value());
+    CHECK_LAST_DIM_CONTIGUOUS(v.value());
+    CHECK_LAST_DIM_CONTIGUOUS(k_buffer.value());
+    CHECK_LAST_DIM_CONTIGUOUS(v_buffer.value());
+    CHECK_DIM(3, k_buffer.value());      // k_buffer: (nnz, H_K, D)
+    CHECK_DIM(3, v_buffer.value());      // v_buffer: (nnz, H_V, D)
+    CHECK_DIM(3, v.value());             // v: (nnz, H_V, D)
+    CHECK_DIM(1, kv_cache_loc.value());  // kv_cache_loc: (n)
+    CHECK_INPUT(kv_cache_loc.value());
+  }
+
+  size_t k_buffer_stride_n = save_kv_cache ? k_buffer->stride(0) : 0;
+  size_t k_buffer_stride_h = save_kv_cache ? k_buffer->stride(1) : 0;
+  size_t v_buffer_stride_n = save_kv_cache ? v_buffer->stride(0) : 0;
+  size_t v_buffer_stride_h = save_kv_cache ? v_buffer->stride(1) : 0;
+  size_t v_stride_n = save_kv_cache ? v->stride(0) : 0;
+  size_t v_stride_h = save_kv_cache ? v->stride(1) : 0;
+  auto kv_cache_loc_ptr =
+      save_kv_cache ? static_cast<int64_t*>(kv_cache_loc->data_ptr()) : nullptr;
+
   CHECK_INPUT(cos_sin_cache);
   CHECK_INPUT(pos_ids);
   auto device = q.device();
@@ -38,6 +66,7 @@ void apply_rope_pos_ids_cos_sin_cache(
   CHECK_EQ(pos_ids.device(), device);
   CHECK_DIM(3, q);  // q: (nnz, H_Q, D)
   CHECK_DIM(3, k);  // k: (nnz, H_K, D)
+  // cos_sin_cache: (max_seq_len, R); the first half of R is cos, the second half is sin
   CHECK_DIM(2, cos_sin_cache);
@@ -52,6 +81,7 @@ void apply_rope_pos_ids_cos_sin_cache(
   size_t q_stride_h = q.stride(1);
   size_t k_stride_n = k.stride(0);
   size_t k_stride_h = k.stride(1);
   size_t q_rope_stride_n = q_rope.stride(0);
   size_t q_rope_stride_h = q_rope.stride(1);
   size_t k_rope_stride_n = k_rope.stride(0);
@@ -59,31 +89,73 @@ void apply_rope_pos_ids_cos_sin_cache(
   cudaStream_t stream = reinterpret_cast<cudaStream_t>(cuda_stream);
   DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
-    cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
-        static_cast<c_type*>(q.data_ptr()),
-        static_cast<c_type*>(k.data_ptr()),
-        static_cast<c_type*>(q_rope.data_ptr()),
-        static_cast<c_type*>(k_rope.data_ptr()),
-        static_cast<float*>(cos_sin_cache.data_ptr()),
-        static_cast<int64_t*>(pos_ids.data_ptr()),
-        nnz, num_qo_heads, num_kv_heads, rotary_dim, head_dim,
-        q_stride_n, q_stride_h, k_stride_n, k_stride_h,
-        q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h,
-        interleave, stream);
-    TORCH_CHECK(
-        status == cudaSuccess,
-        "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " +
-            std::string(cudaGetErrorString(status)));
+    // TODO temporarily use `BatchQKApplyRotaryPosIdsCosSinCacheEnhanced` only when
+    // save_kv_cache is set, to avoid changing the original code path; that branch is
+    // feature-complete, so we should switch to it unconditionally later.
+    if (save_kv_cache) {
+      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
+          static_cast<c_type*>(q.data_ptr()),
+          static_cast<c_type*>(k.data_ptr()),
+          save_kv_cache ? static_cast<c_type*>(v->data_ptr()) : nullptr,
+          static_cast<c_type*>(q_rope.data_ptr()),
+          static_cast<c_type*>(k_rope.data_ptr()),
+          save_kv_cache ? static_cast<c_type*>(k_buffer->data_ptr()) : nullptr,
+          save_kv_cache ? static_cast<c_type*>(v_buffer->data_ptr()) : nullptr,
+          static_cast<float*>(cos_sin_cache.data_ptr()),
+          static_cast<int64_t*>(pos_ids.data_ptr()),
+          nnz, num_qo_heads, num_kv_heads, rotary_dim, head_dim,
+          q_stride_n, q_stride_h, k_stride_n, k_stride_h, v_stride_n, v_stride_h,
+          q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h,
+          k_buffer_stride_n, k_buffer_stride_h, v_buffer_stride_n, v_buffer_stride_h,
+          kv_cache_loc_ptr, interleave, save_kv_cache, stream);
+      TORCH_CHECK(
+          status == cudaSuccess,
+          "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
+              std::string(cudaGetErrorString(status)));
+    } else {
+      cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
+          static_cast<c_type*>(q.data_ptr()),
+          static_cast<c_type*>(k.data_ptr()),
+          static_cast<c_type*>(q_rope.data_ptr()),
+          static_cast<c_type*>(k_rope.data_ptr()),
+          static_cast<float*>(cos_sin_cache.data_ptr()),
+          static_cast<int64_t*>(pos_ids.data_ptr()),
+          nnz, num_qo_heads, num_kv_heads, rotary_dim, head_dim,
+          q_stride_n, q_stride_h, k_stride_n, k_stride_h,
+          q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h,
+          interleave, stream);
+      TORCH_CHECK(
+          status == cudaSuccess,
+          "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " +
+              std::string(cudaGetErrorString(status)));
+    }
     return true;
   });
 }
```
sgl-kernel/include/sgl_kernel_ops.h

```diff
@@ -150,7 +150,11 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor cos_sin_cache,
     at::Tensor pos_ids,
     bool interleave,
-    int64_t cuda_stream);
+    int64_t cuda_stream,
+    const std::optional<at::Tensor>& v,
+    const std::optional<at::Tensor>& k_buffer,
+    const std::optional<at::Tensor>& v_buffer,
+    const std::optional<at::Tensor>& kv_cache_loc);

 #ifdef USE_ROCM
 void gelu_quick(at::Tensor& out, const at::Tensor& input);
```
sgl-kernel/python/sgl_kernel/__init__.py

```diff
@@ -21,6 +21,7 @@ from sgl_kernel.attention import (
 )
 from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_data
 from sgl_kernel.elementwise import (
+    FusedSetKVBufferArg,
     apply_rope_with_cos_sin_cache_inplace,
     fused_add_rmsnorm,
     gelu_and_mul,
```
sgl-kernel/python/sgl_kernel/elementwise.py

```diff
-from typing import Optional
+from dataclasses import dataclass
+from typing import Any, Optional

 import torch
 from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
@@ -237,6 +238,31 @@ if torch.version.hip is not None:
         return out


+@dataclass
+class FusedSetKVBufferArg:
+    """
+    value : Optional[torch.Tensor]
+        Value tensor, shape: ``(nnz, num_v_heads * head_size)``.
+    k_buffer : Optional[torch.Tensor]
+        Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.
+    v_buffer : Optional[torch.Tensor]
+        Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.
+    k_scale : Optional[float]
+        Scale factor for keys.
+    v_scale : Optional[float]
+        Scale factor for values.
+    cache_loc : Optional[torch.Tensor]
+        Cache location tensor, used for indexing the KV cache.
+    """
+
+    value: torch.Tensor
+    k_buffer: torch.Tensor
+    v_buffer: torch.Tensor
+    k_scale: Optional[float]
+    v_scale: Optional[float]
+    cache_loc: torch.Tensor
+
+
 def apply_rope_with_cos_sin_cache_inplace(
     positions: torch.Tensor,
     query: torch.Tensor,
@@ -244,6 +270,7 @@ def apply_rope_with_cos_sin_cache_inplace(
     head_size: int,
     cos_sin_cache: torch.Tensor,
     is_neox: bool = True,
+    fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
 ) -> None:
     r"""
     Apply rotary embedding to keys and queries with precomputed cos/sin values.
@@ -270,6 +297,9 @@ def apply_rope_with_cos_sin_cache_inplace(
         * If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
           we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
+    fused_set_kv_buffer_arg : FusedSetKVBufferArg
+        Fuse the set-kv-buffer operation into this kernel.

     Note
     ----
     The rotary dimension is determined by the cosine cache and sine cache.
@@ -277,13 +307,41 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")

+    if (a := fused_set_kv_buffer_arg) is not None:
+        assert a.k_scale is None, "k_scale is not yet supported"
+        assert a.v_scale is None, "v_scale is not yet supported"
+        assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"
+
+    def _view_3d(x):
+        return x.view(x.shape[0], -1, head_size)
+
     torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
-        query.view(query.shape[0], -1, head_size),
-        key.view(key.shape[0], -1, head_size),
-        query.view(query.shape[0], -1, head_size),
-        key.view(key.shape[0], -1, head_size),
+        _view_3d(query),
+        _view_3d(key),
+        _view_3d(query),
+        _view_3d(key),
         cos_sin_cache,
         positions.long(),
         (not is_neox),
         get_cuda_stream(),
+        (
+            _view_3d(fused_set_kv_buffer_arg.value)
+            if fused_set_kv_buffer_arg is not None
+            else None
+        ),
+        (
+            _view_3d(fused_set_kv_buffer_arg.k_buffer)
+            if fused_set_kv_buffer_arg is not None
+            else None
+        ),
+        (
+            _view_3d(fused_set_kv_buffer_arg.v_buffer)
+            if fused_set_kv_buffer_arg is not None
+            else None
+        ),
+        (
+            fused_set_kv_buffer_arg.cache_loc
+            if fused_set_kv_buffer_arg is not None
+            else None
+        ),
    )
```
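To make the Python-side contract concrete, here is a minimal sketch of calling the fused path directly (not part of the commit; shapes follow the docstring above, and the sizes are illustrative, not requirements):

```python
# Minimal usage sketch (illustrative shapes; assumes a CUDA device).
import torch
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace

nnz, num_q_heads, num_kv_heads, head_size, rotary_dim = 8, 32, 8, 64, 64
pool_size = 1024

query = torch.randn(nnz, num_q_heads * head_size, device="cuda", dtype=torch.bfloat16)
key = torch.randn(nnz, num_kv_heads * head_size, device="cuda", dtype=torch.bfloat16)
value = torch.randn(nnz, num_kv_heads * head_size, device="cuda", dtype=torch.bfloat16)

# cos_sin_cache rows are [cos | sin] and must be float32.
t = torch.arange(pool_size, dtype=torch.float32)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim))
freqs = torch.outer(t, inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).cuda()

k_buffer = torch.zeros(pool_size, num_kv_heads * head_size, device="cuda", dtype=torch.bfloat16)
v_buffer = torch.zeros_like(k_buffer)
cache_loc = torch.arange(nnz, dtype=torch.int64, device="cuda")  # where each token lands

apply_rope_with_cos_sin_cache_inplace(
    positions=torch.zeros(nnz, dtype=torch.int64, device="cuda"),
    query=query,
    key=key,
    head_size=head_size,
    cos_sin_cache=cos_sin_cache,
    is_neox=True,
    fused_set_kv_buffer_arg=FusedSetKVBufferArg(
        value=value, k_buffer=k_buffer, v_buffer=v_buffer,
        k_scale=None, v_scale=None, cache_loc=cache_loc,
    ),
)
# query/key now hold the rotated values in place, and the rows of
# k_buffer/v_buffer at cache_loc hold the rotated keys and untouched values.
```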
sgl-kernel/python/sgl_kernel/testing/__init__.py  (new file, mode 100644; empty)
sgl-kernel/python/sgl_kernel/testing/rotary_embedding.py  (new file, mode 100644)

```python
from typing import Any, Dict, List, Optional, Tuple, Union

import pytest
import torch
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace


# vLLM torch native
def _apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)


class RotaryEmbedding(torch.nn.Module):
    # Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/rotary_embedding.py
    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: int,
        is_neox_style: bool,
        dtype: torch.dtype,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.is_neox_style = is_neox_style
        self.dtype = dtype

        cache = self._compute_cos_sin_cache()
        self.cos_sin_cache: torch.Tensor
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
        inv_freq = 1.0 / (
            base
            ** (torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Compute the cos and sin cache."""
        inv_freq = self._compute_inv_freq(self.base)
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """A PyTorch-native implementation of forward()."""
        if offsets is not None:
            positions = positions + offsets
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache.index_select(0, positions)

        # Modification: float32 is required for the rotary embedding to work correctly
        query = query.to(torch.float32)
        key = key.to(torch.float32)

        cos, sin = cos_sin.chunk(2, dim=-1)

        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
        query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)

        # Modification: convert to the correct dtype
        query = query.to(self.dtype)
        key = key.to(self.dtype)
        return query, key


class FlashInferRotaryEmbedding(RotaryEmbedding):
    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        apply_rope_with_cos_sin_cache_inplace(
            positions=positions,
            query=query,
            key=key,
            fused_set_kv_buffer_arg=fused_set_kv_buffer_arg,
            head_size=self.head_size,
            cos_sin_cache=self.cos_sin_cache,
            is_neox=self.is_neox_style,
        )
        return query, key


class MHATokenToKVPool:
    KV_POOL_SIZE = 16384

    def __init__(
        self,
        head_num: int,
        head_dim: int,
    ):
        self.head_num = head_num
        self.head_dim = head_dim
        self.size = MHATokenToKVPool.KV_POOL_SIZE
        self.page_size = 1
        self.store_dtype = torch.bfloat16
        self.device = "cuda"
        self.layer_num = 1
        self.start_layer = 0
        self._create_buffers()

    def _create_buffers(self):
        self.k_buffer = [
            torch.zeros(
                (self.size + self.page_size, self.head_num, self.head_dim),
                dtype=self.store_dtype,
                device=self.device,
            )
            for _ in range(self.layer_num)
        ]
        self.v_buffer = [
            torch.zeros(
                (self.size + self.page_size, self.head_num, self.head_dim),
                dtype=self.store_dtype,
                device=self.device,
            )
            for _ in range(self.layer_num)
        ]

    def set_kv_buffer(
        self,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        layer_id = 0
        self.k_buffer[layer_id - self.start_layer][loc] = cache_k
        self.v_buffer[layer_id - self.start_layer][loc] = cache_v


def create_inputs(
    head_size: int,
    batch_size: int,
    seq_len: int,
    device,
    dtype: torch.dtype,
    num_q_heads: int,
    num_kv_heads: int,
):
    pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
    query = torch.randn(
        batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
    )
    key = torch.randn(
        batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
    )
    value = torch.randn(
        batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
    )
    out_cache_loc = torch.randperm(
        MHATokenToKVPool.KV_POOL_SIZE, dtype=torch.int64, device=device
    )[: batch_size * seq_len].clone()
    return dict(
        pos_ids=pos_ids, query=query, key=key, value=value, out_cache_loc=out_cache_loc
    )
```
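As a quick sanity check of the reference helper (a sketch, not part of the commit): at position 0 every rotary frequency is zero, so cos = 1 and sin = 0, and `_apply_rotary_emb` must be the identity in both the Neox and GPT-J styles:

```python
import torch

from sgl_kernel.testing.rotary_embedding import _apply_rotary_emb

x = torch.randn(1, 4, 8)   # [num_tokens, num_heads, head_size]
cos = torch.ones(1, 4)     # [num_tokens, head_size // 2]
sin = torch.zeros(1, 4)
for neox in (True, False):
    # o1 = x1*1 - x2*0 = x1 and o2 = x2*1 + x1*0 = x2, so output == input
    assert torch.allclose(_apply_rotary_emb(x, cos, sin, neox), x)
```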
sgl-kernel/tests/test_rotary_embedding.py

The `_apply_rotary_emb` helper and the `RotaryEmbedding` / `FlashInferRotaryEmbedding` reference classes are deleted from this file verbatim; they now live in `sgl_kernel/python/sgl_kernel/testing/rotary_embedding.py` (shown above). The remaining changes:

```diff
@@ -2,153 +2,51 @@ from typing import Any, Dict, List, Optional, Tuple, Union

 import pytest
 import torch
-from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
+from sgl_kernel.testing.rotary_embedding import (
+    FlashInferRotaryEmbedding,
+    MHATokenToKVPool,
+    RotaryEmbedding,
+    create_inputs,
+)


 @pytest.mark.parametrize(
-    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads",
+    "head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device, batch_size, seq_len, num_q_heads, num_kv_heads, save_kv_cache",
     [
-        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1),
-        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2),
-        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2),
-        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8),
-        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4),
-        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2),
+        # GPT-OSS cases
+        *[
+            (64, 64, 4096, 8000, True, torch.bfloat16, "cuda",
+             batch_size, seq_len, 64, 8, save_kv_cache)
+            for batch_size, seq_len in ((1, 1), (32, 1), (128, 1), (512, 1), (2, 512), (4, 4096))
+            for save_kv_cache in (False, True)
+        ],
+        # Other cases
+        (64, 64, 32, 8000, True, torch.bfloat16, "cuda", 32, 32, 1, 1, False),
+        (256, 128, 4096, 10000, True, torch.bfloat16, "cuda", 2, 512, 4, 2, False),
+        (512, 128, 311, 10000, True, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
+        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 32, 8, False),
+        (128, 128, 2048, 10000, False, torch.bfloat16, "cuda", 2, 512, 16, 4, False),
+        (512, 128, 311, 10000, False, torch.bfloat16, "cuda", 3, 39, 4, 2, False),
     ],
 )
 def test_correctness(
@@ -163,34 +61,77 @@ def test_correctness(
     seq_len: int,
     num_q_heads: int,
     num_kv_heads: int,
+    save_kv_cache: bool,
 ):
-    rope_ref = RotaryEmbedding(
-        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
-    ).to(device)
-    rope_flashinfer = FlashInferRotaryEmbedding(
-        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
-    ).to(device)
-
-    pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
-    query = torch.randn(
-        batch_size * seq_len, num_q_heads * head_size, dtype=dtype, device=device
-    )
-    key = torch.randn(
-        batch_size * seq_len, num_kv_heads * head_size, dtype=dtype, device=device
-    )
+    config = dict(
+        head_size=head_size,
+        rotary_dim=rotary_dim,
+        max_position_embeddings=max_position_embeddings,
+        base=base,
+        is_neox_style=is_neox_style,
+        dtype=dtype,
+    )
+    rope_ref = RotaryEmbedding(**config).to(device)
+    rope_flashinfer = FlashInferRotaryEmbedding(**config).to(device)
+
+    inputs = create_inputs(
+        head_size=head_size,
+        batch_size=batch_size,
+        seq_len=seq_len,
+        device=device,
+        dtype=dtype,
+        num_q_heads=num_q_heads,
+        num_kv_heads=num_kv_heads,
+    )
+
+    if save_kv_cache:
+        pool_ref = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)
+        pool_flashinfer = MHATokenToKVPool(head_num=num_kv_heads, head_dim=head_size)

-    query_ref, key_ref = query.clone(), key.clone()
-    query_flashinfer, key_flashinfer = query.clone(), key.clone()
+    query_ref, key_ref = inputs["query"].clone(), inputs["key"].clone()
+    query_flashinfer, key_flashinfer = inputs["query"].clone(), inputs["key"].clone()

-    query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref)
+    query_ref_out, key_ref_out = rope_ref.forward_native(
+        inputs["pos_ids"], query_ref, key_ref
+    )
+    if save_kv_cache:
+        pool_ref.set_kv_buffer(
+            loc=inputs["out_cache_loc"],
+            cache_k=key_ref_out.view(-1, num_kv_heads, head_size),
+            cache_v=inputs["value"].view(-1, num_kv_heads, head_size),
+        )

     query_flashinfer_out, key_flashinfer_out = rope_flashinfer.forward_cuda(
-        pos_ids, query_flashinfer, key_flashinfer
+        inputs["pos_ids"],
+        query_flashinfer,
+        key_flashinfer,
+        fused_set_kv_buffer_arg=(
+            FusedSetKVBufferArg(
+                value=inputs["value"],
+                k_buffer=pool_flashinfer.k_buffer[0].view(-1, num_kv_heads * head_size),
+                v_buffer=pool_flashinfer.v_buffer[0].view(-1, num_kv_heads * head_size),
+                k_scale=None,
+                v_scale=None,
+                cache_loc=inputs["out_cache_loc"],
+            )
+            if save_kv_cache
+            else None
+        ),
     )

     torch.testing.assert_close(query_ref_out, query_flashinfer_out, atol=1e-2, rtol=1e-2)
     torch.testing.assert_close(key_ref_out, key_flashinfer_out, atol=1e-2, rtol=1e-2)
+
+    if save_kv_cache:
+        for field in ["k_buffer", "v_buffer"]:
+            x_ref = getattr(pool_ref, field)[0]
+            x_flashinfer = getattr(pool_flashinfer, field)[0]
+            torch.testing.assert_close(x_ref, x_flashinfer, atol=1e-2, rtol=1e-2)
+            nonzero_ref = x_ref != 0
+            nonzero_flashinfer = x_flashinfer != 0
+            assert torch.all(nonzero_ref == nonzero_flashinfer)


 if __name__ == "__main__":
```
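A note on the final buffer checks: `create_inputs` scatters tokens to random `out_cache_loc` rows of a zero-initialized pool, so the `assert_close` over the whole buffer verifies that the fused kernel wrote the right values, while the matching nonzero masks additionally guard against the kernel touching rows it should have left untouched. The test file also keeps a `__main__` block (truncated in this view) so it can be run directly as well as through pytest.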