change / sglang / Commits / ff00895c

Unverified commit ff00895c
Add CPU optimized kernels for topk and rope fusions (#6456)
Authored Jun 03, 2025 by jianan-gu; committed by GitHub Jun 02, 2025
Parent: ff914748

Showing 7 changed files with 833 additions and 102 deletions (+833, -102)
sgl-kernel/csrc/cpu/norm.cpp                 +77   -0
sgl-kernel/csrc/cpu/rope.cpp                 +310  -93
sgl-kernel/csrc/cpu/topk.cpp                 +221  -0
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp  +25   -4
test/srt/cpu/test_norm.py                    +14   -0
test/srt/cpu/test_rope.py                    +103  -5
test/srt/cpu/test_topk.py                    +83   -0
sgl-kernel/csrc/cpu/norm.cpp

@@ -4,6 +4,67 @@
namespace {

// NB: avoid using `at::vec::map<>` on bfloat16 or half

// Llama4TextL2Norm
template <typename scalar_t>
void l2norm_kernel_impl(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    int64_t batch_size,
    int64_t hidden_size,
    float eps = 1e-5) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();

  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      scalar_t* __restrict__ out_ptr = output + i * hidden_size;
      const scalar_t* __restrict__ input_ptr = input + i * hidden_size;

      fVec sum_fvec = fVec(float(0));
      float sum_val = float(0);

      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
        sum_fvec += x_fvec0 * x_fvec0;
        sum_fvec += x_fvec1 * x_fvec1;
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        sum_val += x_val * x_val;
      }
      sum_val += vec_reduce_sum(sum_fvec);

      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
      const fVec scale_fvec = fVec(rsqrt_var);

#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
        x_fvec0 = x_fvec0 * scale_fvec;
        x_fvec1 = x_fvec1 * scale_fvec;
        bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        out_bvec.store(out_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var);
      }
    }
  });
}

template <typename scalar_t>
void rmsnorm_kernel_impl(
    scalar_t* __restrict__ output,
...

@@ -160,6 +221,22 @@ void fused_add_rmsnorm_kernel_impl(
}

}  // anonymous namespace

// input : {batch_size, hidden_size}
at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
  RECORD_FUNCTION("sgl-kernel::l2norm_cpu", std::vector<c10::IValue>({input}));

  CHECK_INPUT(input);
  CHECK_DIM(2, input);

  int64_t batch_size = input.size(0);
  int64_t hidden_size = input.size(1);
  at::Tensor output = at::empty_like(input);

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "l2norm_kernel", [&] {
    l2norm_kernel_impl<scalar_t>(
        output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), batch_size, hidden_size, eps);
  });
  return output;
}

// input : {batch_size, hidden_size}
// weight: {hidden_size}
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
...
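For reference, a minimal Python sketch of what the new l2norm_cpu op computes, checked against an eager implementation of the same formula (each row is scaled by 1/sqrt(mean(x^2) + eps), with no learned weight, as in the kernel above). The eager helper and the tensor shapes are illustrative assumptions, not part of the commit; the op itself becomes available after importing sgl_kernel.

import torch
import sgl_kernel  # registers torch.ops.sgl_kernel.* on import

def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Same math as l2norm_kernel_impl: x * rsqrt(mean(x^2) + eps), accumulated in fp32.
    x_f = x.float()
    scale = torch.rsqrt(x_f.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_f * scale).to(x.dtype)

x = torch.randn(4, 128, dtype=torch.bfloat16)  # contiguous 2D input, as the op requires
out = torch.ops.sgl_kernel.l2norm_cpu(x, 1e-6)
torch.testing.assert_close(out, l2norm_ref(x, 1e-6), atol=2e-2, rtol=2e-2)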
sgl-kernel/csrc/cpu/rope.cpp

...
@@ -4,126 +4,343 @@
namespace {

template <typename scalar_t>
-void rope_kernel_impl(
-    scalar_t* __restrict__ q_pe_out,
-    scalar_t* __restrict__ k_pe_out,
-    int64_t* __restrict__ t_pos,
-    scalar_t* __restrict__ q_pe,
-    scalar_t* __restrict__ k_pe,
-    scalar_t* __restrict__ t_emb_pos,
-    int64_t seq_len,
-    int64_t num_head,
+void rotary_embedding_3D_kernel_impl(
+    scalar_t* __restrict__ query_out,
+    scalar_t* __restrict__ key_out,
+    int64_t* __restrict__ positions,
+    scalar_t* __restrict__ query,
+    scalar_t* __restrict__ key,
+    scalar_t* __restrict__ cos_sin_cache,
+    int64_t num_tokens,
+    int64_t num_heads,
+    int64_t num_kv_heads,
+    int64_t head_size,
     int64_t rotary_dim,
-    int64_t HR,
-    int64_t q_pe_stride_s,
-    int64_t out_stride_qs,
-    int64_t out_stride_ks,
-    int64_t HK,
-    int64_t k_pe_stride_s,
-    int64_t q_pe_stride_n,
-    int64_t out_stride_qn) {
+    int64_t query_stride_s,
+    int64_t query_out_stride_s,
+    int64_t key_out_stride_s,
+    int64_t key_stride_s,
+    int64_t query_stride_h,
+    int64_t query_out_stride_h) {
+  int64_t HR = rotary_dim;
+  int64_t HK = rotary_dim;
   int64_t COFF = HR / 2;
-  at::parallel_for(0, seq_len * num_head, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
+  at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
    int64_t seq{0}, head_id{0};
-    data_index_init(begin, seq, seq_len, head_id, num_head);
+    data_index_init(begin, seq, num_tokens, head_id, num_heads);
    for (int64_t i = begin; i < end; ++i) {
-      int64_t in_offset_q = seq * q_pe_stride_s + head_id * q_pe_stride_n;
-      int64_t out_offset_q = seq * out_stride_qs + head_id * out_stride_qn;
-      int64_t out_offset_k = seq * out_stride_ks;
+      int64_t in_offset_q = seq * query_stride_s + head_id * query_stride_h;
+      int64_t out_offset_q = seq * query_out_stride_s + head_id * query_out_stride_h;
+      int64_t out_offset_k = seq * key_out_stride_s;
      int64_t p = 0;
      scalar_t* sin_start = nullptr;
      scalar_t* cos_start = nullptr;
      // step 0) get the rotary position embedding for the current position
-      p = t_pos[seq];
-      sin_start = t_emb_pos + p * HR + COFF;
-      cos_start = t_emb_pos + p * HR;
+      p = positions[seq];
+      sin_start = cos_sin_cache + p * HR + COFF;
+      cos_start = cos_sin_cache + p * HR;
      // step 1) apply_rotary_pos_emb for the rotary_dim elements in every
      // head of query/key
      for (int64_t h = 0; h < rotary_dim; h += 2) {
        scalar_t cos = cos_start[h >> 1];
        scalar_t sin = sin_start[h >> 1];
-        scalar_t in1 = q_pe[in_offset_q + h];
-        scalar_t in2 = q_pe[in_offset_q + h + 1];
+        scalar_t in1 = query[in_offset_q + h];
+        scalar_t in2 = query[in_offset_q + h + 1];
        scalar_t out1 = in1 * cos - in2 * sin;
        scalar_t out2 = in2 * cos + in1 * sin;
-        q_pe_out[out_offset_q + h] = out1;
-        q_pe_out[out_offset_q + h + 1] = out2;
+        query_out[out_offset_q + h] = out1;
+        query_out[out_offset_q + h + 1] = out2;
      }
      for (int64_t h = 0; h < HK; h += 2) {
        scalar_t cos = cos_start[h >> 1];
        scalar_t sin = sin_start[h >> 1];
-        int64_t k_pe_offset = seq * k_pe_stride_s;
-        scalar_t in1_k = k_pe[k_pe_offset + h];
-        scalar_t in2_k = k_pe[k_pe_offset + h + 1];
+        int64_t k_pe_offset = seq * key_stride_s;
+        scalar_t in1_k = key[k_pe_offset + h];
+        scalar_t in2_k = key[k_pe_offset + h + 1];
        scalar_t out1_k = in1_k * cos - in2_k * sin;
        scalar_t out2_k = in2_k * cos + in1_k * sin;
-        k_pe_out[out_offset_k + h] = out1_k;
-        k_pe_out[out_offset_k + h + 1] = out2_k;
+        key_out[out_offset_k + h] = out1_k;
+        key_out[out_offset_k + h + 1] = out2_k;
      }
      // move to the next index
-      data_index_step(seq, seq_len, head_id, num_head);
+      data_index_step(seq, num_tokens, head_id, num_heads);
    }
  });
}

template <typename scalar_t>
void rotary_embedding_neox_2D_kernel_impl(
    int64_t* __restrict__ positions,
    scalar_t* __restrict__ query,
    scalar_t* __restrict__ key,
    scalar_t* __restrict__ cos_sin_cache,
    int64_t rotary_dim,
    int64_t query_stride_s,
    int64_t key_stride_s,
    int64_t num_heads,
    int64_t num_kv_heads,
    int64_t head_size,
    int64_t num_tokens) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int64_t bVecSize = bVec::size();

  int64_t embed_dim = rotary_dim / 2;
  bool flag = (embed_dim % bVecSize == 0);
  int64_t loop_upper = flag ? embed_dim : embed_dim - bVecSize;

  auto compute_loop = [&](int64_t token_head, scalar_t* cache_ptr, scalar_t* qk) {
    int64_t j = 0;
    for (; j < loop_upper; j += bVecSize) {
      int64_t rot_offset = j;
      int64_t x_index = rot_offset;
      int64_t y_index = embed_dim + rot_offset;

      int64_t out_x = token_head + x_index;
      int64_t out_y = token_head + y_index;

      bVec _cos = bVec::loadu(cache_ptr + x_index);
      bVec _sin = bVec::loadu(cache_ptr + y_index);
      bVec _q_x = bVec::loadu(qk + out_x);
      bVec _q_y = bVec::loadu(qk + out_y);

      fVec _cos_0, _cos_1;
      std::tie(_cos_0, _cos_1) = at::vec::convert_to_float(_cos);
      fVec _sin_0, _sin_1;
      std::tie(_sin_0, _sin_1) = at::vec::convert_to_float(_sin);
      fVec _q_x_0, _q_x_1;
      std::tie(_q_x_0, _q_x_1) = at::vec::convert_to_float(_q_x);
      fVec _q_y_0, _q_y_1;
      std::tie(_q_y_0, _q_y_1) = at::vec::convert_to_float(_q_y);

      auto out1_0 = _q_x_0 * _cos_0 - _q_y_0 * _sin_0;
      auto out1_1 = _q_x_1 * _cos_1 - _q_y_1 * _sin_1;
      auto out1 = convert_from_float_ext<scalar_t>(out1_0, out1_1);
      out1.store(qk + out_x);

      auto out2_0 = _q_y_0 * _cos_0 + _q_x_0 * _sin_0;
      auto out2_1 = _q_y_1 * _cos_1 + _q_x_1 * _sin_1;
      auto out2 = convert_from_float_ext<scalar_t>(out2_0, out2_1);
      out2.store(qk + out_y);
    }
    if (!flag) {
      for (; j < embed_dim; ++j) {
        int64_t x_index = j;
        int64_t y_index = embed_dim + j;

        int64_t out_x = token_head + x_index;
        int64_t out_y = token_head + y_index;

        float _cos = cache_ptr[x_index];
        float _sin = cache_ptr[y_index];
        float _q_x = qk[out_x];
        float _q_y = qk[out_y];

        qk[out_x] = _q_x * _cos - _q_y * _sin;
        qk[out_y] = _q_y * _cos + _q_x * _sin;
      }
    }
  };

#pragma omp parallel for
  for (int64_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    int64_t pos = positions[token_idx];
    scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;

    for (int64_t i = 0; i < num_heads; ++i) {
      int64_t head_idx = i;
      int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
      compute_loop(token_head, cache_ptr, query);
    }

    for (int64_t i = 0; i < num_kv_heads; ++i) {
      int64_t head_idx = i;
      int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
      compute_loop(token_head, cache_ptr, key);
    }
  }
}

template <typename scalar_t>
void rotary_embedding_2D_kernel_impl(
    int64_t* __restrict__ positions,
    scalar_t* __restrict__ query,
    scalar_t* __restrict__ key,
    scalar_t* __restrict__ cos_sin_cache,
    int64_t rotary_dim,
    int64_t query_stride_s,
    int64_t key_stride_s,
    int64_t num_heads,
    int64_t num_kv_heads,
    int64_t head_size,
    int64_t num_tokens) {
  int64_t embed_dim = rotary_dim / 2;

  at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
    int64_t token_idx = {0}, i = {0};
    data_index_init(begin, token_idx, num_tokens, i, num_heads);
    for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
      int64_t pos = positions[token_idx];
      scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
      scalar_t* cos_cache_ptr = cache_ptr;
      scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
      int64_t head_idx = i;
      int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
      scalar_t* head_query = token_head + query;
      for (int64_t j = 0; j < embed_dim; j += 1) {
        int64_t rot_offset = j;
        int64_t x_index = 2 * rot_offset;
        int64_t y_index = 2 * rot_offset + 1;

        float cos = cos_cache_ptr[rot_offset];
        float sin = sin_cache_ptr[rot_offset];

        float x = head_query[x_index];
        float y = head_query[y_index];

        head_query[x_index] = x * cos - y * sin;
        head_query[y_index] = y * cos + x * sin;
      }
      data_index_step(token_idx, num_tokens, i, num_heads);
    }
  });

  at::parallel_for(0, num_tokens * num_kv_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
    int64_t token_idx{0}, i = {0};
    data_index_init(begin, token_idx, num_tokens, i, num_kv_heads);
    for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
      int64_t pos = positions[token_idx];
      scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
      scalar_t* cos_cache_ptr = cache_ptr;
      scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
      int64_t head_idx = i;
      int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
      scalar_t* head_key = key + token_head;
      for (int64_t j = 0; j < embed_dim; j += 1) {
        int64_t rot_offset = j;
        int64_t x_index = 2 * rot_offset;
        int64_t y_index = 2 * rot_offset + 1;

        float cos = cos_cache_ptr[rot_offset];
        float sin = sin_cache_ptr[rot_offset];

        float x = head_key[x_index];
        float y = head_key[y_index];

        head_key[x_index] = x * cos - y * sin;
        head_key[y_index] = y * cos + x * sin;
      }
      data_index_step(token_idx, num_tokens, i, num_kv_heads);
    }
  });
}

}  // namespace

std::tuple<at::Tensor, at::Tensor> rotary_position_embedding_cpu(
    at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos) {
  RECORD_FUNCTION(
      "sgl-kernel::rotary_position_embedding_cpu", std::vector<c10::IValue>({t_pos, q_pe, k_pe, t_emb_pos}));

  CHECK_INPUT(t_pos);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_pe);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_pe);
  CHECK_INPUT(t_emb_pos);
  CHECK_DIM(1, t_pos);
  CHECK_DIM(3, q_pe);
  CHECK_DIM(3, k_pe);
  CHECK_DIM(2, t_emb_pos);

  int64_t seq_len = q_pe.size(0);
  int64_t num_head = q_pe.size(1);
  int64_t rotary_dim = q_pe.size(2);
  int64_t HK = k_pe.size(2);
  int64_t HR = t_emb_pos.size(1);
  CHECK_EQ(HR, rotary_dim);
  CHECK_EQ(k_pe.size(0), seq_len);
  CHECK_EQ(k_pe.size(1), 1);
  CHECK_EQ(t_pos.size(0), seq_len);
  CHECK_EQ(HK, rotary_dim);

  at::Tensor q_pe_out = at::empty_like(q_pe);
  at::Tensor k_pe_out = at::empty_like(k_pe);
  int64_t q_pe_stride_s = q_pe.stride(0);
  int64_t q_pe_stride_n = q_pe.stride(1);
  int64_t k_pe_stride_s = k_pe.stride(0);
  int64_t out_stride_qs = q_pe_out.stride(0);
  int64_t out_stride_qn = q_pe_out.stride(1);
  int64_t out_stride_ks = k_pe_out.stride(0);

  const auto input_dtype = q_pe.scalar_type();
  TORCH_CHECK(t_pos.scalar_type() == at::kLong, "expect positions to be int64, got ", t_pos.scalar_type());
  TORCH_CHECK(input_dtype == k_pe.scalar_type(), "q_pe and k_pe must have the same data type");
  TORCH_CHECK(input_dtype == t_emb_pos.scalar_type(), "q_pe and t_emb_pos must have the same data type");

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_position_embedding_cpu", [&] {
    rope_kernel_impl<scalar_t>(
        q_pe_out.data_ptr<scalar_t>(), k_pe_out.data_ptr<scalar_t>(), t_pos.data_ptr<int64_t>(),
        q_pe.data_ptr<scalar_t>(), k_pe.data_ptr<scalar_t>(), t_emb_pos.data_ptr<scalar_t>(),
        seq_len, num_head, rotary_dim, HR, q_pe_stride_s, out_stride_qs, out_stride_ks,
        HK, k_pe_stride_s, q_pe_stride_n, out_stride_qn);
  });
  return std::make_tuple(q_pe_out, k_pe_out);
}

std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
    at::Tensor& positions,
    at::Tensor& query,
    at::Tensor& key,
    int64_t head_size,
    at::Tensor& cos_sin_cache,
    bool is_neox) {
  RECORD_FUNCTION("sgl-kernel::rotary_embedding_cpu", std::vector<c10::IValue>({query, key}));

  CHECK_DIM(1, positions);
  const auto input_dim = query.dim();
  const auto input_dtype = query.scalar_type();
  TORCH_CHECK(
      input_dim == 2 || input_dim == 3,
      " Query/Key must be 2D [num_tokens, num_heads*head_size] or 3D [num_tokens, num_heads, head_size] tensor");
  CHECK_DIM(2, cos_sin_cache);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);
  int64_t rotary_dim = cos_sin_cache.size(1);
  if (input_dim == 3) {
    // TODO: add support for head_dim != rotary_dim case when input_dim=3
    CHECK_EQ(query.size(-1), rotary_dim);
    // TODO: add support for kv_head != 1
    CHECK_EQ(key.size(1), 1);
  }
  int64_t num_tokens = positions.numel();
  CHECK_EQ(key.size(0), num_tokens);
  CHECK_EQ(query.size(0), num_tokens);
  TORCH_CHECK(positions.scalar_type() == at::kLong, "expect positions to be int64, got ", positions.scalar_type());
  TORCH_CHECK(input_dtype == key.scalar_type(), "query and key must have the same data type");
  TORCH_CHECK(input_dtype == cos_sin_cache.scalar_type(), "query and cos_sin_cache must have the same data type");

  int64_t num_heads = input_dim == 2 ? query.size(-1) / head_size : query.size(1);
  int64_t num_kv_heads = input_dim == 2 ? key.size(-1) / head_size : key.size(1);
  int64_t key_stride_s = key.stride(0);
  int64_t query_stride_s = query.stride(0);
  // input stride of num head dim is meaningful only when input dim = 3
  int64_t query_stride_h = input_dim == 3 ? query.stride(1) : -1;

  at::Tensor query_out = at::empty_like(query);
  at::Tensor key_out = at::empty_like(key);
  int64_t query_out_stride_s = query_out.stride(0);
  int64_t key_out_stride_s = key_out.stride(0);
  // output stride of num head dim is meaningful only when input dim = 3
  int64_t query_out_stride_h = input_dim == 3 ? query_out.stride(1) : -1;

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_embedding_cpu", [&] {
    if (input_dim == 2) {
      if (is_neox) {
        rotary_embedding_neox_2D_kernel_impl<scalar_t>(
            positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(),
            cos_sin_cache.data_ptr<scalar_t>(), rotary_dim, query_stride_s, key_stride_s,
            num_heads, num_kv_heads, head_size, num_tokens);
      } else {
        rotary_embedding_2D_kernel_impl<scalar_t>(
            positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(),
            cos_sin_cache.data_ptr<scalar_t>(), rotary_dim, query_stride_s, key_stride_s,
            num_heads, num_kv_heads, head_size, num_tokens);
      }
      query_out = query;
      key_out = key;
    } else {
      TORCH_CHECK(
          is_neox == false,
          " Query/Key with 3D [num_tokens, num_heads, head_size] does not support neox rope yet");
      // TODO: add neox style support for rope impl with 3D inputs
      rotary_embedding_3D_kernel_impl<scalar_t>(
          query_out.data_ptr<scalar_t>(), key_out.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(), key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(),
          num_tokens, num_heads, num_kv_heads, head_size, rotary_dim,
          query_stride_s, query_out_stride_s, key_out_stride_s, key_stride_s,
          query_stride_h, query_out_stride_h);
    }
  });
-  return std::make_tuple(q_pe_out, k_pe_out);
+  return std::make_tuple(query_out, key_out);
}
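A minimal sketch of how the new rotary_embedding_cpu op can be driven from Python for the 2D neox layout, mirroring the test added in test/srt/cpu/test_rope.py below. The concrete head/sequence sizes here are assumed example values, and the RotaryEmbedding reference module is the one the test itself imports; nothing below is part of the commit.

import torch
import sgl_kernel  # registers torch.ops.sgl_kernel.rotary_embedding_cpu
from sglang.srt.layers.rotary_embedding import RotaryEmbedding

head_size, rotary_dim, num_heads, num_kv_heads, seq_len = 64, 64, 8, 2, 16
rope = RotaryEmbedding(head_size, rotary_dim, 2048, 10000, True, torch.bfloat16).to("cpu")

positions = torch.arange(seq_len)  # int64 positions, as the kernel expects
query = torch.randn(seq_len, num_heads * head_size, dtype=torch.bfloat16)
key = torch.randn(seq_len, num_kv_heads * head_size, dtype=torch.bfloat16)
query_ref, key_ref = query.clone(), key.clone()  # clone before: the 2D path rotates in place

# 2D [num_tokens, num_heads * head_size] inputs; the op returns the rotated query/key.
q_out, k_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
    positions, query, key, rope.head_size,
    rope.cos_sin_cache.to(torch.bfloat16), rope.is_neox_style,
)

q_native, k_native = rope.forward_native(positions, query_ref, key_ref)
torch.testing.assert_close(q_out, q_native, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(k_out, k_native, atol=1e-2, rtol=1e-2)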
sgl-kernel/csrc/cpu/topk.cpp

...
@@ -157,6 +157,101 @@ inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input)
  }
}

template <typename scalar_t, int NUM_EXPERTS>
void topk_sigmoid_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    int64_t num_tokens,
    int64_t topk,
    bool renormalize) {
  using Vec = at::vec::Vectorized<float>;
  const int64_t num_experts_per_group = NUM_EXPERTS;
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    alignas(64) float scores[NUM_EXPERTS];
    using elem_t = std::pair<float, int32_t>;
    std::vector<elem_t> queue(num_experts_per_group);

    for (int64_t i = begin; i < end; ++i) {
      at::vec::convert<scalar_t, float>(gating_output + i * NUM_EXPERTS, scores, NUM_EXPERTS);
      float gmax = at::vec::reduce_all<float>(
          [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, scores, num_experts_per_group);

      // find position of first max,
      // note that we may have multiple max values.
      int first_max_idx = -1;
      for (int64_t e = 0; e < num_experts_per_group; ++e) {
        if (scores[e] == gmax) {
          first_max_idx = e;
          break;
        }
      }
      // scalar sigmoid
      topk_weights[i] = 1.0 / (1.0 + exp(0.0 - gmax));
      topk_ids[i] = first_max_idx;

      if (renormalize) {
        float sum = 0.f;
        for (int64_t j = 0; j < topk; ++j) {
          sum += topk_weights[i * topk + j];
        }
        float scale = 1.f / sum;
        for (int64_t j = 0; j < topk; ++j) {
          topk_weights[i * topk + j] *= scale;
        }
      }
    }
  });
}

template <typename scalar_t, int NUM_EXPERTS>
void topk_softmax_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    int64_t num_tokens,
    int64_t topk,
    bool renormalize) {
  const int64_t num_experts_per_group = NUM_EXPERTS;
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    alignas(64) float scores[NUM_EXPERTS];
    using elem_t = std::pair<float, int32_t>;
    std::vector<elem_t> queue(num_experts_per_group);

    for (int64_t i = begin; i < end; ++i) {
      softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
      for (int64_t e = 0; e < num_experts_per_group; ++e) {
        queue[e] = {scores[e], e};
      }
      std::partial_sort(
          queue.begin(), queue.begin() + num_experts_per_group, queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool { return x.first > y.first; });

      for (int64_t j = 0; j < topk; ++j) {
        topk_weights[i * topk + j] = queue[j].first;
        topk_ids[i * topk + j] = queue[j].second;
      }
      if (renormalize) {
        float sum = 0.f;
        for (int64_t j = 0; j < topk; ++j) {
          sum += topk_weights[i * topk + j];
        }
        float scale = 1.f / sum;
        for (int64_t j = 0; j < topk; ++j) {
          topk_weights[i * topk + j] *= scale;
        }
      }
    }
  });
}

template <typename scalar_t, int SIZE>
inline void apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
...

@@ -293,6 +388,24 @@ void biased_grouped_topk_kernel_impl(
      topk_group, \
      renormalize);

#define LAUNCH_TOPK_SIGMOID_KERNEL(NE) \
  topk_sigmoid_kernel_impl<scalar_t, NE>( \
      topk_weights.data_ptr<float>(), \
      topk_ids.data_ptr<int32_t>(), \
      gating_output.data_ptr<scalar_t>(), \
      num_tokens, \
      topk, \
      renormalize);

#define LAUNCH_TOPK_SOFTMAX_KERNEL(NE) \
  topk_softmax_kernel_impl<scalar_t, NE>( \
      topk_weights.data_ptr<float>(), \
      topk_ids.data_ptr<int32_t>(), \
      gating_output.data_ptr<scalar_t>(), \
      num_tokens, \
      topk, \
      renormalize);

#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
  biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
      topk_weights.data_ptr<float>(), \
...

@@ -306,6 +419,114 @@ void biased_grouped_topk_kernel_impl(
}

}  // anonymous namespace

std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
  RECORD_FUNCTION("sgl-kernel::topk_sigmoid_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));

  CHECK_INPUT(gating_output);

  const auto st = hidden_states.scalar_type();
  CHECK_EQ(gating_output.scalar_type(), st);

  int64_t num_tokens = hidden_states.size(0);
  int64_t num_experts = gating_output.size(1);
  TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
  TORCH_CHECK(topk == 1, "topk_sigmoid only supports topk=1 case");

  at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
  at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));

  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_sigmoid_kernel", [&] {
    switch (num_experts) {
      case 1:   LAUNCH_TOPK_SIGMOID_KERNEL(1);   break;
      case 2:   LAUNCH_TOPK_SIGMOID_KERNEL(2);   break;
      case 4:   LAUNCH_TOPK_SIGMOID_KERNEL(4);   break;
      case 8:   LAUNCH_TOPK_SIGMOID_KERNEL(8);   break;
      case 16:  LAUNCH_TOPK_SIGMOID_KERNEL(16);  break;
      case 32:  LAUNCH_TOPK_SIGMOID_KERNEL(32);  break;
      case 64:  LAUNCH_TOPK_SIGMOID_KERNEL(64);  break;
      case 128: LAUNCH_TOPK_SIGMOID_KERNEL(128); break;
      case 160: LAUNCH_TOPK_SIGMOID_KERNEL(160); break;
      case 256: LAUNCH_TOPK_SIGMOID_KERNEL(256); break;
      default:
        TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
    }
  });
  return std::make_tuple(topk_weights, topk_ids);
}

std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
  RECORD_FUNCTION("sgl-kernel::topk_softmax_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));

  CHECK_INPUT(gating_output);

  const auto st = hidden_states.scalar_type();
  CHECK_EQ(gating_output.scalar_type(), st);

  int64_t num_tokens = hidden_states.size(0);
  int64_t num_experts = gating_output.size(1);
  TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");

  at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
  at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));

  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_softmax_cpu", [&] {
    switch (num_experts) {
      case 1:   LAUNCH_TOPK_SOFTMAX_KERNEL(1);   break;
      case 2:   LAUNCH_TOPK_SOFTMAX_KERNEL(2);   break;
      case 4:   LAUNCH_TOPK_SOFTMAX_KERNEL(4);   break;
      case 8:   LAUNCH_TOPK_SOFTMAX_KERNEL(8);   break;
      case 16:  LAUNCH_TOPK_SOFTMAX_KERNEL(16);  break;
      case 32:  LAUNCH_TOPK_SOFTMAX_KERNEL(32);  break;
      case 64:  LAUNCH_TOPK_SOFTMAX_KERNEL(64);  break;
      case 128: LAUNCH_TOPK_SOFTMAX_KERNEL(128); break;
      case 160: LAUNCH_TOPK_SOFTMAX_KERNEL(160); break;
      case 256: LAUNCH_TOPK_SOFTMAX_KERNEL(256); break;
      default:
        TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
    }
  });
  return std::make_tuple(topk_weights, topk_ids);
}

// grouped topk for DeepSeek V2
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
    at::Tensor& hidden_states,
...
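A minimal sketch of the routing math the two new topk ops implement, seen from the Python side. The eager reference below is illustrative only (shapes and tolerances are assumptions); the ops themselves are exercised against sglang's native routing functions in test/srt/cpu/test_topk.py further down.

import torch
import sgl_kernel  # registers torch.ops.sgl_kernel.topk_softmax_cpu / topk_sigmoid_cpu

M, E, topk = 8, 32, 3
hidden_states = torch.randn(M, 16, dtype=torch.bfloat16)        # only dtype and num_tokens are read
gating_output = torch.randn(M, E, dtype=torch.bfloat16) * 2 * M  # spread logits to avoid bf16 ties

# topk_softmax_cpu: softmax over experts, keep the top-k weights and ids,
# optionally renormalizing the kept weights to sum to 1.
weights, ids = torch.ops.sgl_kernel.topk_softmax_cpu(hidden_states, gating_output, topk, True)

ref = torch.softmax(gating_output.float(), dim=-1).topk(topk, dim=-1)
ref_weights = ref.values / ref.values.sum(dim=-1, keepdim=True)  # renormalize=True
torch.testing.assert_close(weights, ref_weights, atol=1e-2, rtol=1e-2)

# topk_sigmoid_cpu (topk must be 1): sigmoid of the max logit plus its expert id,
# the Llama4-style routing the tests pair with Llama4MoE.custom_routing_function.
w1, id1 = torch.ops.sgl_kernel.topk_sigmoid_cpu(hidden_states, gating_output, 1, False)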
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp

...
@@ -23,6 +23,9 @@ limitations under the License.
// silu_and_mul
at::Tensor silu_and_mul_cpu(at::Tensor& input);

// l2norm
at::Tensor l2norm_cpu(at::Tensor& input, double eps);

// rmsnorm
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
...

@@ -30,6 +33,11 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);

// topk
std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
...

@@ -185,8 +193,13 @@ void shm_allreduce(
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int64_t dim);

// rope
std::tuple<at::Tensor, at::Tensor>
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
    at::Tensor& positions, at::Tensor& query, at::Tensor& key, int64_t head_size,
    at::Tensor& cos_sin_cache, bool is_neox);

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // activation
...

@@ -196,10 +209,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // norm
  m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
  m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
  m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
  m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
  m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
  m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);

  // topk
  m.def("topk_sigmoid_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
  m.impl("topk_sigmoid_cpu", torch::kCPU, &topk_sigmoid_cpu);
  m.def("topk_softmax_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
  m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
  m.def(
      "grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
      "int topk_group) -> (Tensor, Tensor)");
...

@@ -294,8 +313,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.impl("shm_allgather", torch::kCPU, &shm_allgather);

  // rope
  m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
  m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
  m.def(
      "rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
      "bool is_neox) -> (Tensor, Tensor)");
  m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);
}

REGISTER_EXTENSION(common_ops)
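Once the extension is built, importing sgl_kernel loads common_ops and runs the TORCH_LIBRARY_FRAGMENT registration above, after which the new schemas resolve through the torch.ops namespace. A small sketch (the op names are exactly those defined in this file; the rest is an assumed check, not part of the commit):

import torch
import sgl_kernel  # loads the common_ops extension and registers the sgl_kernel fragment

# Entry points added or extended by this commit, looked up by registered name.
for name in ["l2norm_cpu", "topk_sigmoid_cpu", "topk_softmax_cpu", "rotary_embedding_cpu"]:
    packet = getattr(torch.ops.sgl_kernel, name)
    print(name, packet)  # resolves to an op overload packet if registration succeeded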
test/srt/cpu/test_norm.py

...
@@ -63,10 +63,24 @@ class TestNorm(CustomTestCase):
        self.assertTrue(torch.allclose(x, ref_x, atol=atol, rtol=rtol))
        self.assertTrue(torch.allclose(residual, ref_residual, atol=atol, rtol=rtol))

    def _l2norm_test(self, m, n, dtype):
        x = torch.randn([m, n], dtype=dtype)
        hidden_size = x.size(-1)
        fake_ones_weight = torch.ones(hidden_size, dtype=dtype)
        variance_epsilon = 1e-6

        out = torch.ops.sgl_kernel.l2norm_cpu(x, variance_epsilon)
        ref_out = self._forward_native(x, fake_ones_weight, variance_epsilon)

        atol = rtol = precision[ref_out.dtype]
        self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))

    def test_norm(self):
        for params in itertools.product(self.M, self.N, self.dtype):
            with self.subTest(m=params[0], n=params[1], dtype=params[2]):
                self._norm_test(*params)
                self._l2norm_test(*params)


if __name__ == "__main__":
...
test/srt/cpu/test_rope.py

...
@@ -4,7 +4,10 @@ import sgl_kernel
import torch
from utils import precision

-from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
+from sglang.srt.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding,
+    RotaryEmbedding,
+)
from sglang.test.test_utils import CustomTestCase
...

@@ -62,10 +65,13 @@ class TestROPE(CustomTestCase):
        )

        # fused rope kernel
-        q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_position_embedding_cpu(
-            positions, q_pe_clone, k_pe_clone, cos_sin_cache
-        )
+        q_pe_clone, k_pe_clone = torch.ops.sgl_kernel.rotary_embedding_cpu(
+            positions,
+            q_pe_clone,
+            k_pe_clone,
+            rope.head_size,
+            cos_sin_cache,
+            False,
+        )

        atol = rtol = precision[q_pe.dtype]
...

@@ -73,6 +79,98 @@ class TestROPE(CustomTestCase):
        self.assertTrue(torch.allclose(k_pe, k_pe_clone, atol=atol, rtol=rtol))
        torch.testing.assert_close(k_pe, k_pe_clone)

    def test_origin_rope(self):
        def single_test(
            head_size: int,
            rotary_dim: int,
            max_position_embeddings: int,
            base: int,
            is_neox_style: bool,
            dtype: torch.dtype,
            device: str,
            batch_size: int,
            seq_len: int,
            num_q_heads: int,
            num_kv_heads: int,
        ):
            torch.manual_seed(100)
            rope_ref = RotaryEmbedding(
                head_size,
                rotary_dim,
                max_position_embeddings,
                base,
                is_neox_style,
                dtype,
            ).to(device)
            pos_ids = torch.arange(seq_len, device=device).repeat(batch_size)
            query = torch.randn(
                batch_size * seq_len,
                num_q_heads * head_size,
                dtype=dtype,
                device=device,
            )
            key = torch.randn(
                batch_size * seq_len,
                num_kv_heads * head_size,
                dtype=dtype,
                device=device,
            )

            query_ref, key_ref = query.clone(), key.clone()
            query_cpu, key_cpu = query.clone(), key.clone()

            query_ref_out, key_ref_out = rope_ref.forward_native(pos_ids, query_ref, key_ref)
            query_cpu_out, key_cpu_out = torch.ops.sgl_kernel.rotary_embedding_cpu(
                pos_ids,
                query_cpu,
                key_cpu,
                rope_ref.head_size,
                rope_ref.cos_sin_cache.to(query.dtype),
                rope_ref.is_neox_style,
            )
            torch.testing.assert_close(query_ref_out, query_cpu_out, atol=1e-2, rtol=1e-2)
            torch.testing.assert_close(key_ref_out, key_cpu_out, atol=1e-2, rtol=1e-2)

        test_config = [
            (64, 64, 32, 8000, True, torch.bfloat16, "cpu", 32, 32, 1, 1),
            (256, 128, 4096, 10000, True, torch.bfloat16, "cpu", 2, 512, 32, 8),
            (512, 128, 311, 10000, True, torch.bfloat16, "cpu", 3, 39, 4, 2),
            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 32, 8),
            (128, 128, 2048, 10000, False, torch.bfloat16, "cpu", 2, 512, 16, 4),
            (512, 128, 311, 10000, False, torch.bfloat16, "cpu", 3, 39, 4, 2),
        ]

        for (
            head_size,
            rotary_dim,
            max_position_embeddings,
            base,
            is_neox_style,
            dtype,
            device,
            batch_size,
            seq_len,
            num_q_heads,
            num_kv_heads,
        ) in test_config:
            single_test(
                head_size,
                rotary_dim,
                max_position_embeddings,
                base,
                is_neox_style,
                dtype,
                device,
                batch_size,
                seq_len,
                num_q_heads,
                num_kv_heads,
            )


if __name__ == "__main__":
    unittest.main()
test/srt/cpu/test_topk.py

...
@@ -8,7 +8,9 @@ from utils import precision
from sglang.srt.layers.moe.topk import (
    biased_grouped_topk_impl as native_biased_grouped_topk,
)
from sglang.srt.layers.moe.topk import fused_topk_native as native_fused_topk
from sglang.srt.layers.moe.topk import grouped_topk as native_grouped_topk
from sglang.srt.models.llama4 import Llama4MoE
from sglang.test.test_utils import CustomTestCase
...

@@ -94,5 +96,86 @@ class TestBiasedGroupedTopK(CustomTestCase):
            self._run_single_test(122, 256, 8, 8, 2, renormalize, torch.bfloat16)


class TestTopK(CustomTestCase):
    def _run_single_test(self, M, E, topk, renormalize, dtype):
        torch.manual_seed(1998)
        # expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        ref_topk_weights, ref_topk_ids = native_fused_topk(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
        )

        # fused version
        topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
            hidden_states, gating_output, topk, renormalize
        )

        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_topk(self):
        for renormalize in [True, False]:
            self._run_single_test(123, 8, 2, renormalize, torch.bfloat16)
            self._run_single_test(123, 16, 3, renormalize, torch.bfloat16)
            self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
            self._run_single_test(123, 32, 3, renormalize, torch.bfloat16)
            self._run_single_test(123, 64, 6, renormalize, torch.bfloat16)
            self._run_single_test(123, 256, 4, renormalize, torch.bfloat16)
            self._run_single_test(123, 160, 6, renormalize, torch.bfloat16)


class TestCustomTopK(CustomTestCase):
    def _run_single_test(self, M, E, topk, renormalize, dtype, native_custom_f, fused_custom_f):
        torch.manual_seed(16)
        # expand gating_output by M, otherwise bfloat16 fall into same value aftering truncating
        hidden_states = torch.randn(M, 100, dtype=dtype)
        gating_output = torch.randn(M, E, dtype=dtype) * 2 * M

        ref_topk_weights, ref_topk_ids = native_custom_f(
            hidden_states.float(),
            gating_output.float(),
            topk,
            renormalize,
        )

        # fused version
        topk_weights, topk_ids = fused_custom_f(hidden_states, gating_output, topk, renormalize)

        res = torch.zeros(M, E, dtype=torch.float)
        ref = torch.zeros(M, E, dtype=torch.float)
        res.scatter_(1, topk_ids.long(), topk_weights)
        ref.scatter_(1, ref_topk_ids.long(), ref_topk_weights)
        torch.testing.assert_close(res, ref)

    def test_custom_topk(self):
        test_custom_functions = [
            (Llama4MoE.custom_routing_function, torch.ops.sgl_kernel.topk_sigmoid_cpu)
        ]
        for native_custom_f, fused_custom_f in test_custom_functions:
            self._run_single_test(123, 8, 1, False, torch.bfloat16, native_custom_f, fused_custom_f)
            self._run_single_test(123, 16, 1, False, torch.bfloat16, native_custom_f, fused_custom_f)
            self._run_single_test(123, 32, 1, False, torch.bfloat16, native_custom_f, fused_custom_f)


if __name__ == "__main__":
    unittest.main()