change / sglang / Commits / 3efbdf68
"src/vscode:/vscode.git/clone" did not exist on "909742dbd6873052995dc6cd5f4150ff238015d2"
Commit 3efbdf68 (unverified)
Authored Feb 14, 2025 by Xiaoyu Zhang; committed by GitHub, Feb 14, 2025
fix sgl-kernel codestyle (#3563)
Parent: 6cc30955

Showing 3 changed files with 34 additions and 29 deletions (+34 −29)
+20 −15  sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu
+8  −6   sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
+6  −8   sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu
sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu
@@ -33,11 +33,11 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
                                                   const int batch_size,
                                                   const int num_heads,
                                                   const int qk_dim,
                                                   const int v_dim) {
   extern __shared__ char smem[];
-  T* q_shared = reinterpret_cast<T*>(smem);
-  T* k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
-  T* v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
-  float* new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
-  T* output_shared =
+  T* __restrict__ q_shared = reinterpret_cast<T*>(smem);
+  T* __restrict__ k_shared = reinterpret_cast<T*>(smem + qk_dim * sizeof(T));
+  T* __restrict__ v_shared = reinterpret_cast<T*>(smem + 2 * qk_dim * sizeof(T));
+  float* __restrict__ new_kv_shared = reinterpret_cast<float*>(smem + (2 * qk_dim + v_dim) * sizeof(T));
+  T* __restrict__ output_shared =
       reinterpret_cast<T*>(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float));
   const int32_t tid = threadIdx.x;
@@ -51,6 +51,7 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
   const int32_t v_offset = b * num_heads * v_dim + h * v_dim;
   const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim;
 
+  // Load q, k, v into shared memory
   for (int d = tid; d < qk_dim; d += blockDim.x) {
     q_shared[d] = q[qk_offset + d];
     k_shared[d] = k[qk_offset + d];
@@ -63,33 +64,36 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
   const float ratio = expf(-1.0f * slope[h]);
 
+  // Compute new_kv
   for (int d = tid; d < qk_dim; d += blockDim.x) {
-    T k_val = k_shared[d];
+    const T k_val = k_shared[d];
     for (int e = 0; e < v_dim; ++e) {
-      int past_kv_idx = kv_offset + d * v_dim + e;
-      T v_val = v_shared[e];
-      float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
-      int shared_idx = d * (v_dim + 1) + e;
+      const int past_kv_idx = kv_offset + d * v_dim + e;
+      const T v_val = v_shared[e];
+      const float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val;
+      const int shared_idx = d * (v_dim + 1) + e;
       new_kv_shared[shared_idx] = new_val;
     }
   }
 
   __syncthreads();
 
+  // Store new_kv to global memory
   for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) {
-    int d = idx / v_dim;
-    int e = idx % v_dim;
-    int shared_idx = d * (v_dim + 1) + e;
-    int global_idx = kv_offset + idx;
+    const int d = idx / v_dim;
+    const int e = idx % v_dim;
+    const int shared_idx = d * (v_dim + 1) + e;
+    const int global_idx = kv_offset + idx;
     new_kv[global_idx] = new_kv_shared[shared_idx];
   }
 
   __syncthreads();
 
+  // Compute output
   for (int e = tid; e < v_dim; e += blockDim.x) {
     float sum = 0.0f;
     for (int d = 0; d < qk_dim; ++d) {
-      int shared_idx = d * (v_dim + 1) + e;
+      const int shared_idx = d * (v_dim + 1) + e;
       sum += q_shared[d] * new_kv_shared[shared_idx];
     }
     output_shared[e] = static_cast<T>(sum);
   }
@@ -97,6 +101,7 @@ __global__ void lightning_attention_decode_kernel(const T* __restrict__ q,
   __syncthreads();
 
+  // Store output to global memory
   if (tid == 0) {
     for (int e = 0; e < v_dim; ++e) {
       output[v_offset + e] = output_shared[e];
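Taken together, the hunks in this file are annotation-only cleanups: `__restrict__` on the shared-memory pointers, `const` on locals that are assigned exactly once, and comments marking each phase of the kernel. As a rough illustration, here is a hypothetical toy kernel (not code from this commit) written with the same conventions:

#include <cuda_runtime.h>

// Toy kernel in the style the commit moves toward: __restrict__ promises
// the compiler that x and y never alias, and const marks values assigned
// exactly once, so loads can be hoisted and reordered more freely.
__global__ void axpy_kernel(const float* __restrict__ x, float* __restrict__ y, int n, float a) {
  // Standard grid-stride loop over the n elements.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
    const float xi = x[i];  // const local, as in "const T k_val = k_shared[d];" above
    y[i] = a * xi + y[i];
  }
}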
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
@@ -25,8 +25,9 @@ limitations under the License.
 #define WARP_SIZE 32
 
 template <typename scalar_t>
-__global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
-                                      int32_t* cumsum_buffer, size_t numel) {
+__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
+                                                    int32_t* __restrict__ sorted_token_ids,
+                                                    int32_t* __restrict__ cumsum_buffer, size_t numel) {
   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t stride = blockDim.x * gridDim.x;
@@ -38,9 +39,10 @@ __global__ void moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, int32_t*
 }
 
 template <typename scalar_t>
-__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids,
-                                            int32_t* expert_ids, int32_t* total_tokens_post_pad,
-                                            int32_t num_experts, int32_t block_size, size_t numel,
-                                            int32_t* cumsum) {
+__global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
+                                            int32_t* __restrict__ sorted_token_ids,
+                                            int32_t* __restrict__ expert_ids,
+                                            int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
+                                            int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
   __shared__ int32_t shared_counts[WARP_SIZE][8];
   const int warp_id = threadIdx.x / WARP_SIZE;
@@ -106,7 +108,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
   const int max_blocks = 65535;
   const int actual_blocks = std::min(num_blocks, max_blocks);
 
-  auto sort_kernel = moe_token_sort_kernel<scalar_t>;
+  auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
   sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
       topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
       cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
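Beyond the qualifier cleanups, this file renames `moe_token_sort_kernel` to the more descriptive `count_and_sort_expert_tokens_kernel` and updates the launch site to match. For orientation, here is a simplified sketch of the grid-stride scatter such a kernel typically performs. It assumes `cumsum_buffer` was pre-filled with each expert's start offset by an earlier pass; it is not the actual sgl-kernel implementation:

#include <cuda_runtime.h>
#include <cstdint>
#include <cstddef>

// Simplified sketch, NOT the sgl-kernel implementation: place each token
// into its expert's contiguous output range by atomically bumping that
// expert's cursor in cumsum_buffer.
__global__ void scatter_tokens_by_expert(const int32_t* __restrict__ topk_ids,
                                         int32_t* __restrict__ sorted_token_ids,
                                         int32_t* __restrict__ cumsum_buffer, size_t numel) {
  // Same grid-stride setup as the tid/stride lines in the hunk above.
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = blockDim.x * gridDim.x;
  for (size_t i = tid; i < numel; i += stride) {
    const int32_t expert = topk_ids[i];
    const int32_t slot = atomicAdd(&cumsum_buffer[expert], 1);  // claim next free slot
    sorted_token_ids[slot] = static_cast<int32_t>(i);
  }
}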
sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu
@@ -7,13 +7,11 @@
 using FP8_TYPE = c10::Float8_e4m3fn;
 
-__device__ __forceinline__ float WarpReduce(volatile float* smem, const int tid) {
-  if (tid < 8) {
-    smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
-    if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
-    if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
-    if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
-  }
-  return smem[0];
-}
+__device__ __forceinline__ float GroupReduce(volatile float* smem, const int tid) {
+  smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
+  if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
+  if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
+  if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
+  return smem[0];
+}
@@ -53,7 +51,7 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
   // Perform reduction within each group
   if (local_tid < 8) {
-    WarpReduce(&s_absmax[local_group_id][0], local_tid);
+    GroupReduce(&s_absmax[local_group_id][0], local_tid);
   }
 
   __syncthreads();
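The rename from `WarpReduce` to `GroupReduce` better reflects what the helper does: it folds 16 shared-memory partials down to a single maximum using threads 0 through 7 of a 16-wide group, with `volatile` and intra-warp lockstep standing in for explicit synchronization. A minimal hypothetical caller (assumed 16-element group layout; not the actual quantization kernel) could drive it like this:

#include <cuda_runtime.h>
#include <math.h>

// Same tree-max shape as GroupReduce above: the caller guards with
// (lane < 8), the first fold is unconditional, and each step halves the
// active range. Correct only while the 16-thread group sits in one warp,
// as in the original code.
__device__ __forceinline__ float group_max16(volatile float* smem, int lane) {
  smem[lane] = fmaxf(smem[lane], smem[lane + 8]);
  if (lane < 4) smem[lane] = fmaxf(smem[lane], smem[lane + 4]);
  if (lane < 2) smem[lane] = fmaxf(smem[lane], smem[lane + 2]);
  if (lane < 1) smem[lane] = fmaxf(smem[lane], smem[lane + 1]);
  return smem[0];
}

// Hypothetical driver: each 16-thread group computes the absmax of its
// 16-element slice of x, and lane 0 of the group writes the result.
__global__ void groupwise_absmax(const float* __restrict__ x, float* __restrict__ out, int num_groups) {
  __shared__ float s[8][16];                // 8 groups of 16 lanes per 128-thread block
  const int group_in_block = threadIdx.x / 16;
  const int lane = threadIdx.x % 16;
  const int group = blockIdx.x * 8 + group_in_block;
  if (group >= num_groups) return;
  s[group_in_block][lane] = fabsf(x[group * 16 + lane]);
  if (lane < 8) group_max16(&s[group_in_block][0], lane);  // guarded call, as in the diff
  if (lane == 0) out[group] = s[group_in_block][0];
}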