Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
cf0ccd40
"src/kernels/git@developer.sourcefind.cn:Fzc7075/nunchaku.git" did not exist on "fcc551cbdf05ef976812a341ddeb349de2a1e464"
Unverified
Commit
cf0ccd40
authored
Mar 10, 2025
by
Lianmin Zheng
Committed by
GitHub
Mar 10, 2025
Browse files
Optimize rope in sgl kernel (#4267)
parent
3d56585a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
2 additions
and
3 deletions
+2
-3
sgl-kernel/csrc/elementwise/rope.cu
sgl-kernel/csrc/elementwise/rope.cu
+1
-1
sgl-kernel/python/sgl_kernel/elementwise.py
sgl-kernel/python/sgl_kernel/elementwise.py
+1
-2
No files found.
sgl-kernel/csrc/elementwise/rope.cu
View file @
cf0ccd40
...
@@ -65,7 +65,7 @@ void apply_rope_pos_ids_cos_sin_cache(
...
@@ -65,7 +65,7 @@ void apply_rope_pos_ids_cos_sin_cache(
static_cast
<
c_type
*>
(
q_rope
.
data_ptr
()),
static_cast
<
c_type
*>
(
q_rope
.
data_ptr
()),
static_cast
<
c_type
*>
(
k_rope
.
data_ptr
()),
static_cast
<
c_type
*>
(
k_rope
.
data_ptr
()),
static_cast
<
float
*>
(
cos_sin_cache
.
data_ptr
()),
static_cast
<
float
*>
(
cos_sin_cache
.
data_ptr
()),
static_cast
<
int
32
_t
*>
(
pos_ids
.
data_ptr
()),
static_cast
<
int
64
_t
*>
(
pos_ids
.
data_ptr
()),
nnz
,
nnz
,
num_qo_heads
,
num_qo_heads
,
num_kv_heads
,
num_kv_heads
,
...
...
sgl-kernel/python/sgl_kernel/elementwise.py
View file @
cf0ccd40
...
@@ -139,14 +139,13 @@ def apply_rope_with_cos_sin_cache_inplace(
...
@@ -139,14 +139,13 @@ def apply_rope_with_cos_sin_cache_inplace(
if
cos_sin_cache
.
dtype
!=
torch
.
float32
:
if
cos_sin_cache
.
dtype
!=
torch
.
float32
:
raise
ValueError
(
"cos_sin_cache should be float32"
)
raise
ValueError
(
"cos_sin_cache should be float32"
)
positions
=
positions
.
int
()
torch
.
ops
.
sgl_kernel
.
apply_rope_pos_ids_cos_sin_cache
(
torch
.
ops
.
sgl_kernel
.
apply_rope_pos_ids_cos_sin_cache
(
q
=
query
.
view
(
query
.
shape
[
0
],
-
1
,
head_size
),
q
=
query
.
view
(
query
.
shape
[
0
],
-
1
,
head_size
),
k
=
key
.
view
(
key
.
shape
[
0
],
-
1
,
head_size
),
k
=
key
.
view
(
key
.
shape
[
0
],
-
1
,
head_size
),
q_rope
=
query
.
view
(
query
.
shape
[
0
],
-
1
,
head_size
),
q_rope
=
query
.
view
(
query
.
shape
[
0
],
-
1
,
head_size
),
k_rope
=
key
.
view
(
key
.
shape
[
0
],
-
1
,
head_size
),
k_rope
=
key
.
view
(
key
.
shape
[
0
],
-
1
,
head_size
),
cos_sin_cache
=
cos_sin_cache
,
cos_sin_cache
=
cos_sin_cache
,
pos_ids
=
positions
,
pos_ids
=
positions
.
long
()
,
interleave
=
(
not
is_neox
),
interleave
=
(
not
is_neox
),
cuda_stream
=
get_cuda_stream
(),
cuda_stream
=
get_cuda_stream
(),
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment