sglang / Commits / 0c8dab9e

Commit 0c8dab9e (Unverified), authored Jul 23, 2025 by Yuan Luo, committed by GitHub on Jul 23, 2025

[sgl-kernel] Opt per_token_quant_fp8 with warp reduce (#8130)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>

parent f39037ff

Showing 1 changed file with 106 additions and 16 deletions:

sgl-kernel/csrc/gemm/per_token_quant_fp8.cu (+106, -16)
#include <ATen/cuda/CUDAContext.h>

#include <cmath>
#include <cub/block/block_reduce.cuh>
#include <flashinfer/vec_dtypes.cuh>

#include "utils.h"

static constexpr int kWarpSize = 32;

// ---------------------------------------------------------------------------
// 1. Warp-local, no shared memory
//    • One warp handles one token.
//    • Eight tokens per 256-thread CTA.
// ---------------------------------------------------------------------------
template <typename T, int kTokensPerCTA = 8, int kVecSize = 16>
__global__ void per_token_quant_fp8_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  const int warp_id = threadIdx.x / kWarpSize;        // 0-7 (8 warps)
  const int lane_id = threadIdx.x & (kWarpSize - 1);  // 0-31
  const int token_id = blockIdx.x * kTokensPerCTA + warp_id;
  if (token_id >= num_tokens) return;

  // Global tensors for this token
  const T* token_input = input + token_id * hidden_dim;
  FP8_TYPE* token_output = output_q + token_id * hidden_dim;
  float* token_scale = output_s + token_id;

  //
  // Pass-1: Perform a warp reduce to find the max_value of a token's hidden_dim
  //
  float max_value = 0.f;
  using vec_t = flashinfer::vec_t<T, kVecSize>;
  const int32_t num_vec_elems = hidden_dim / kVecSize;

  for (int32_t i = lane_id; i < num_vec_elems; i += kWarpSize) {
    vec_t input_vec;
    input_vec.cast_load(token_input + i * kVecSize);
#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      max_value = fmaxf(max_value, fabsf(static_cast<float>(input_vec[j])));
    }
  }
  float warp_max = warpReduceMax(max_value);

  // warpReduceMax (from utils.h) is expected to return the warp-wide max on
  // every lane (see the sketch after this kernel), so the scale can live in a
  // register; a single CTA-wide __shared__ scalar would be raced by the eight
  // warps, which each quantize a different token.
  const float scale = warp_max / FP8_E4M3_MAX;

  // Lane 0 writes this token's scale to global memory.
  if (lane_id == 0) {
    token_scale[0] = scale;
  }
  float scale_inv = (scale == 0.f) ? 0.f : 1.0f / scale;
  //
  // Pass-2: quantize and write back
  //
  for (int i = lane_id; i < num_vec_elems; i += kWarpSize) {
    vec_t input_vec;
    input_vec.cast_load(token_input + i * kVecSize);
    FP8_TYPE output_arr[kVecSize];
#pragma unroll
    for (uint32_t j = 0; j < kVecSize; ++j) {
      float val = static_cast<float>(input_vec[j]) * scale_inv;
      val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX);
#ifndef USE_ROCM
      output_arr[j] = static_cast<FP8_TYPE>(val);
#else
      output_arr[j] = c10::Float8_e4m3fnuz(
          __hip_cvt_float_to_fp8(val, fp8::fp8_type::__default_saturation, fp8::fp8_type::__default_interpret),
          c10::Float8_e4m3fnuz::from_bits());
#endif
    }
    *(uint4*)(token_output + i * kVecSize) = *(uint4*)output_arr;
  }
}
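warpReduceMax is declared in utils.h and its definition is not part of this diff; the register-resident scale above relies on it returning the warp-wide maximum to every lane. The following is a minimal sketch of the usual XOR-butterfly shape of such a helper, assuming that broadcast behavior (the actual utils.h definition may differ):

// Sketch only: assumed shape of utils.h's warpReduceMax, not the committed code.
// Each XOR-shuffle round halves the number of distinct values; after five
// rounds every lane of the warp holds the same maximum.
__inline__ __device__ float warpReduceMax(float val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1) {
    val = fmaxf(val, __shfl_xor_sync(0xffffffff, val, mask));
  }
  return val;
}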
// ---------------------------------------------------------------------------
// 2. Baseline kernel (1 token / CTA, CUB block reduce)
// ---------------------------------------------------------------------------
template <typename T>
__global__ void per_token_quant_fp8_small_batch_kernel(
    const T* __restrict__ input,
    FP8_TYPE* __restrict__ output_q,
    float* __restrict__ output_s,
    const int64_t hidden_dim,
    const int64_t num_tokens) {
  const int token_idx = blockIdx.x;
  if (token_idx >= num_tokens) return;

  ...
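The remainder of this kernel is collapsed in the diff (unchanged lines). Per the section comment it finds each token's maximum with a CUB block reduce; the standalone sketch below shows that pattern under assumed, illustrative names (block_max_sketch, thread_max), not the committed code:

#include <cub/block/block_reduce.cuh>

// Illustrative sketch of a 1-token-per-CTA absolute-max reduction with CUB.
template <int kThreads = 256>
__global__ void block_max_sketch(const float* __restrict__ row, int64_t n, float* __restrict__ out) {
  using BlockReduce = cub::BlockReduce<float, kThreads>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  // Each thread folds a strided slice of the row into a partial |x| max.
  float thread_max = 0.f;
  for (int64_t i = threadIdx.x; i < n; i += kThreads) {
    thread_max = fmaxf(thread_max, fabsf(row[i]));
  }

  // Block-wide reduction; the result is valid on thread 0 only.
  float block_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
  if (threadIdx.x == 0) {
    *out = block_max;
  }
}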
@@ -79,28 +156,41 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s)
  CHECK_INPUT(input);
  CHECK_INPUT(output_q);
  CHECK_INPUT(output_s);

  const auto input_sizes = input.sizes();
  const int64_t num_tokens = input_sizes[0];
  const int64_t hidden_dim = input_sizes[1];
  TORCH_CHECK(hidden_dim % 16 == 0, "Hidden dimension must be divisible by 16, but got ", hidden_dim);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Hard-code sm_count
  int sm_count = 132;
  constexpr int TOKENS_PER_CTA = 8;
  const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);

  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), scalar_t, [&] {
    if (use_warp_kernel) {
      // -------- warp-local ---------------------------------------------------
      constexpr int THREADS = TOKENS_PER_CTA * kWarpSize;  // 256
      dim3 grid((num_tokens + TOKENS_PER_CTA - 1) / TOKENS_PER_CTA);
      dim3 block(THREADS);

      per_token_quant_fp8_kernel<scalar_t, TOKENS_PER_CTA, 16><<<grid, block, 0, stream>>>(
          static_cast<const scalar_t*>(input.data_ptr()),
          static_cast<FP8_TYPE*>(output_q.data_ptr()),
          static_cast<float*>(output_s.data_ptr()),
          hidden_dim,
          num_tokens);
    } else {
      // -------- baseline -----------------------------------------------------
      constexpr int THREADS = 256;
      dim3 grid(num_tokens);
      dim3 block(THREADS);

      per_token_quant_fp8_small_batch_kernel<scalar_t><<<grid, block, 0, stream>>>(
          static_cast<const scalar_t*>(input.data_ptr()),
          static_cast<FP8_TYPE*>(output_q.data_ptr()),
          static_cast<float*>(output_s.data_ptr()),
          hidden_dim,
          num_tokens);
    }
    return true;
  });
}
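For intuition on the dispatch heuristic above: with the hard-coded sm_count of 132 (the SM count of an NVIDIA H100, presumably the tuning target), the warp-local kernel is selected once num_tokens >= 132 * 2 * 8 = 2112, i.e. when the grid of 8-token CTAs is large enough to put roughly two CTAs on every SM; smaller batches fall back to the baseline kernel, which spends a full 256-thread CTA on each token. A minimal host-side usage sketch follows; the symbol name sgl_per_token_quant_fp8 matches the hunk header, but mapping FP8_TYPE to torch::kFloat8_e4m3fn is an assumption about the CUDA build, not shown in this diff:

#include <torch/torch.h>

// Hypothetical caller: per-token FP8 quantization of a [num_tokens, hidden_dim]
// half-precision tensor. Dtypes below are assumptions, not from this diff.
void quantize_example() {
  const int64_t num_tokens = 4096;  // >= 2112, so the warp-local path is taken
  const int64_t hidden_dim = 4096;  // must be divisible by 16
  auto opts = torch::TensorOptions().device(torch::kCUDA);

  torch::Tensor input = torch::randn({num_tokens, hidden_dim}, opts.dtype(torch::kHalf));
  torch::Tensor output_q = torch::empty({num_tokens, hidden_dim}, opts.dtype(torch::kFloat8_e4m3fn));
  torch::Tensor output_s = torch::empty({num_tokens}, opts.dtype(torch::kFloat));

  sgl_per_token_quant_fp8(input, output_q, output_s);  // one scale per token in output_s
}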