sglang / Commits / 55a7ec38

use warp shuffle style reduce and flashinfer vectorize (#3628)

Unverified commit 55a7ec38, authored Feb 19, 2025 by Xiaoyu Zhang, committed via GitHub on Feb 19, 2025. Parent: fe0673f1
Showing 2 changed files with 48 additions and 42 deletions:

  sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py   (+1, -1)
  sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu   (+47, -41)
sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py

@@ -186,7 +186,7 @@ configs = list(itertools.product(batch_size_range, seq_len_range, group_size_ran
 def benchmark(batch_size, seq_len, group_size, provider):
     dtype = torch.bfloat16
     device = torch.device("cuda")
-    hidden_dim = group_size * 2
+    hidden_dim = 7168
     x = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=dtype)
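Context for the one-line change above: tying hidden_dim to group_size meant every benchmarked token had exactly two quantization groups no matter which group size was swept; fixing it at 7168 makes the group count per token scale realistically with group size. A quick arithmetic check (the group sizes here are illustrative, not necessarily the benchmark's actual sweep):

// Illustrative only: groups per token before and after this change.
#include <cstdio>

int main() {
  const int hidden_dim = 7168;            // fixed by this commit
  const int group_sizes[] = {64, 128};    // example group sizes, not from the diff
  for (int gs : group_sizes) {
    printf("group_size=%3d: old layout had 2 groups/token, new layout has %d\n",
           gs, hidden_dim / gs);
  }
  return 0;
}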
sgl-kernel/src/sgl-kernel/csrc/per_token_group_quant_fp8.cu
@@ -2,17 +2,18 @@
 #include <c10/util/Float8_e4m3fn.h>
 #include <cmath>
+#include <flashinfer/vec_dtypes.cuh>
 #include "utils.h"

 using FP8_TYPE = c10::Float8_e4m3fn;

-__device__ __forceinline__ float GroupReduce(volatile float* smem, const int tid) {
-  smem[tid] = fmaxf(smem[tid], smem[tid + 8]);
-  if (tid < 4) smem[tid] = fmaxf(smem[tid], smem[tid + 4]);
-  if (tid < 2) smem[tid] = fmaxf(smem[tid], smem[tid + 2]);
-  if (tid < 1) smem[tid] = fmaxf(smem[tid], smem[tid + 1]);
-  return smem[0];
-}
+__device__ __forceinline__ float GroupReduce(float val, const int tid) {
+  val = fmaxf(val, __shfl_xor_sync(0xffff, val, 8));
+  val = fmaxf(val, __shfl_xor_sync(0xffff, val, 4));
+  val = fmaxf(val, __shfl_xor_sync(0xffff, val, 2));
+  val = fmaxf(val, __shfl_xor_sync(0xffff, val, 1));
+  return val;
+}

 template <typename T>
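Why this rewrite matters: the old GroupReduce staged partial maxima through shared memory, with volatile pointer semantics and (tid < k) guards at each step, while __shfl_xor_sync exchanges values register-to-register between lanes, so the 16-lane maximum costs four instructions with no shared-memory traffic. A minimal standalone sketch of the same butterfly pattern (names are illustrative; the sketch passes a full-warp mask, and XOR offsets 8/4/2/1 never cross an aligned 16-lane boundary, so both 16-lane groups of a warp reduce independently):

// Sketch, not the commit's code: butterfly max-reduction over each
// aligned 16-lane group of a warp using __shfl_xor_sync.
#include <cstdio>

__device__ __forceinline__ float group_max_16(float val) {
  // After the four XOR steps, every lane in a 16-lane half holds
  // that half's maximum.
  for (int offset = 8; offset > 0; offset >>= 1) {
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, offset));
  }
  return val;
}

__global__ void demo() {
  float v = static_cast<float>(threadIdx.x);  // lane index as the value
  float m = group_max_16(v);
  if (threadIdx.x % 16 == 0) {
    printf("group starting at lane %2d: max = %f\n", threadIdx.x, m);
  }
}

int main() {
  demo<<<1, 32>>>();
  cudaDeviceSynchronize();
  return 0;
}

Compiled with nvcc, the 32-thread launch prints max = 15 for lanes 0-15 and max = 31 for lanes 16-31, matching the two independent groups per warp.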
@@ -21,54 +22,60 @@ __global__ void per_token_group_quant_fp8_kernel(const T* __restrict__ input, vo
     const int num_groups,
     const float eps,
     const float fp8_min,
     const float fp8_max) {
   const int groups_per_block = 16;
+  const int local_group_id = threadIdx.x / 16;
+  const int lane_id = threadIdx.x % 16;
   const int block_group_id = blockIdx.x * groups_per_block;
-  const int tid = threadIdx.x;
-  const int local_group_id = tid / 16;  // Each 16 threads handle one group
-  const int local_tid = tid % 16;       // Thread ID within the group
+  const int block_group_offset = (block_group_id + local_group_id) * group_size;

-  __shared__ float s_absmax[16][17];  // Use 17 instead of 16 to avoid bank conflicts
+  __shared__ float s_absmax[16];

-  // Local maximum value for each thread
   float local_absmax = eps;

-  // Ensure this block doesn't process out-of-bounds groups
   if (block_group_id + local_group_id < num_groups) {
-    // Calculate input/output pointers for current group
-    const T* group_input = input + (block_group_id + local_group_id) * group_size;
-    FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + (block_group_id + local_group_id) * group_size;
-    float* scale_output = output_s + block_group_id + local_group_id;
+    const T* group_input = input + block_group_offset;
+    FP8_TYPE* group_output = static_cast<FP8_TYPE*>(output_q) + block_group_offset;
+    float* scale_output = output_s + block_group_id + local_group_id;
+
+    constexpr uint32_t vec_size = 16 / sizeof(T);
+    using vec_t = flashinfer::vec_t<T, vec_size>;
+    const int32_t num_vec_elems = group_size / vec_size;

-    // Calculate local maximum absolute value
-    for (int i = local_tid; i < group_size; i += 16) {
-      float val = static_cast<float>(group_input[i]);
+    for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
+      vec_t input_vec;
+      input_vec.cast_load(group_input + i * vec_size);
+#pragma unroll
+      for (uint32_t j = 0; j < vec_size; ++j) {
+        float val = static_cast<float>(input_vec[j]);
         float abs_val = fabsf(val);
         local_absmax = fmaxf(local_absmax, abs_val);
+      }
     }

-    // Store in shared memory
-    s_absmax[local_group_id][local_tid] = local_absmax;
-    __syncthreads();
+    local_absmax = GroupReduce(local_absmax, lane_id);

-    // Perform reduction within each group
-    if (local_tid < 8) {
-      GroupReduce(&s_absmax[local_group_id][0], local_tid);
-    }
+    if (lane_id == 0) {
+      s_absmax[local_group_id] = local_absmax;
+    }
     __syncthreads();

-    // Get the maximum value for this group
-    const float group_absmax = s_absmax[local_group_id][0];
-    const float y_s = group_absmax / fp8_max;
+    const float group_absmax = s_absmax[local_group_id];
+    const float y_s = group_absmax / fp8_max;

-    // Only the first thread in each group writes the scale
-    if (local_tid == 0) {
-      *scale_output = y_s;
-    }
+    if (lane_id == 0) {
+      *scale_output = y_s;
+    }

-    // Quantize the data
-    for (int i = local_tid; i < group_size; i += 16) {
-      float val = static_cast<float>(group_input[i]);
+    for (int32_t i = lane_id; i < num_vec_elems; i += 16) {
+      vec_t input_vec;
+      input_vec.cast_load(group_input + i * vec_size);
+#pragma unroll
+      for (uint32_t j = 0; j < vec_size; ++j) {
+        float val = static_cast<float>(input_vec[j]);
         float q_val = fminf(fmaxf(val / y_s, fp8_min), fp8_max);
-      group_output[i] = FP8_TYPE(q_val);
+        group_output[i * vec_size + j] = FP8_TYPE(q_val);
+      }
     }
   }
 }
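The other half of the commit is the flashinfer::vec_t load path: each of the 16 lanes in a group now issues one 16-byte load per iteration (vec_size = 16 / sizeof(T), i.e. 8 elements for a 2-byte T like bf16) and unrolls over the packed values, instead of issuing strided scalar loads. A rough sketch of the same idea without the flashinfer dependency, using plain float4 reinterpretation (names are illustrative; it assumes a 16-byte-aligned half input whose length is a multiple of 8, whereas vec_t abstracts the aligned vector load over arbitrary element types):

// Sketch only: 128-bit vectorized absmax pass. Writes one partial
// maximum per thread; a group reduction (as sketched earlier) would follow.
#include <cuda_fp16.h>

__global__ void absmax_partials(const half* __restrict__ input,
                                float* __restrict__ partials, int n) {
  constexpr int vec_size = 16 / sizeof(half);  // 8 halves per 128-bit load
  const float4* input4 = reinterpret_cast<const float4*>(input);
  float local_absmax = 0.0f;
  for (int i = threadIdx.x; i < n / vec_size; i += blockDim.x) {
    float4 raw = input4[i];  // single 128-bit memory transaction
    const half* h = reinterpret_cast<const half*>(&raw);
#pragma unroll
    for (int j = 0; j < vec_size; ++j) {
      local_absmax = fmaxf(local_absmax, fabsf(__half2float(h[j])));
    }
  }
  partials[threadIdx.x] = local_absmax;  // per-thread partial maximum
}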
@@ -83,9 +90,8 @@ void sgl_per_token_group_quant_fp8(torch::Tensor input, torch::Tensor output_q,
   CHECK_EQ(input.numel() % group_size, 0);

-  // Each block processes 16 groups, adjust grid size accordingly
   dim3 grid((num_groups + 15) / 16);
-  dim3 block(256);  // Keep 256 threads, each 16 threads handle one group
+  dim3 block(256);

   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
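For reference, the launch geometry is unchanged (256 threads = 16 groups of 16 lanes per block, grid rounded up over num_groups), and so is the scale math each group applies: y_s = group_absmax / fp8_max, with each value clamped to [fp8_min, fp8_max] after dividing by y_s. A host-side sketch of that arithmetic with toy numbers (448 is the e4m3fn finite maximum, matching FP8_TYPE above; the eps seed and sample values are made up, and the final rounding onto the FP8 grid is omitted):

// Sketch only: per-group FP8 scale computation on a toy group.
#include <cmath>
#include <cstdio>

int main() {
  const float fp8_max = 448.0f, fp8_min = -448.0f;     // e4m3fn finite range
  const float eps = 1e-10f;                            // assumed epsilon seed
  const float group[4] = {0.5f, -3.2f, 1.7f, -0.1f};   // toy group values
  float absmax = eps;
  for (float v : group) absmax = std::fmax(absmax, std::fabs(v));
  const float y_s = absmax / fp8_max;                  // per-group scale
  for (float v : group) {
    float q = std::fmin(std::fmax(v / y_s, fp8_min), fp8_max);
    std::printf("x=% .2f -> q=% .1f, dequantized % .4f\n", v, q, q * y_s);
  }
  return 0;
}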