change / sglang · Commit 95085d65 (Unverified)
Authored Mar 06, 2025 by Stefan He; committed by GitHub on Mar 06, 2025
[Refactor] Reducing code duplication across FP8 CUDA quantization kernels (#4163)
Parent: c7f25446

Showing 5 changed files with 32 additions and 64 deletions (+32 -64)
sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py        +2   -3
sgl-kernel/benchmark/bench_per_token_quant_fp8.py              +0   -3
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu    +0   -32
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu     +0   -26
sgl-kernel/src/sgl-kernel/include/utils.h                      +30  -0
sgl-kernel/benchmark/bench_per_token_group_quant_fp8.py
 import itertools
-import math
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Tuple

 import torch
 import triton
 import triton.language as tl
 from sgl_kernel import sgl_per_token_group_quant_fp8
-from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
+from sglang.srt.utils import is_hip

 is_hip_ = is_hip()
 fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
...
sgl-kernel/benchmark/bench_per_token_quant_fp8.py
...
@@ -40,9 +40,6 @@ def calculate_diff(batch_size: int, seq_len: int):
     scale_diff = torch.abs(vllm_scale - sglang_scale).mean().item()
     output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()

-    print(f"Scale difference: {scale_diff}")
-    print(f"Output difference: {output_diff}")

     if torch.allclose(
         vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
     ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
...
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
...
@@ -7,38 +7,6 @@
 #include "utils.h"

-#define WARP_SIZE 32
-#ifndef USE_ROCM
-#include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else
-#include <c10/util/Float8_e4m3fnuz.h>
-#include "amd/quant_utils.cuh"
-using FP8_TYPE = c10::Float8_e4m3fnuz;
-// Using the default max value from pytorch (240.0) will cause accuracy
-// issue when running dynamic quantization. Here use 224.0f for rocm.
-constexpr auto FP8_E4M3_MAX = 224.0f;
-#endif

-__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
-  float old;
-  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
-                     : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
-  return old;
-}

-__device__ __forceinline__ float warpReduceMax(float max_value) {
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
-  return max_value;
-}

 template <typename T>
 __global__ void per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output_s,
                                          const int64_t num_elements) {
...
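The body of per_tensor_absmax_kernel is collapsed in this view. For orientation only, a per-tensor absmax reduction built on the helpers that now live in utils.h typically follows the pattern sketched below; this is an illustrative reconstruction, not the kernel body from this commit, and the kernel name is hypothetical.

// Illustrative sketch (not the commit's code): grid-stride scan, warp-level
// max reduction, then a float atomic max into the single output slot.
// Assumes utils.h (WARP_SIZE, warpReduceMax, atomicMaxFloat) is included and
// that *output_s was zero-initialized before launch.
template <typename T>
__global__ void per_tensor_absmax_sketch(const T* __restrict__ input, float* __restrict__ output_s,
                                         const int64_t num_elements) {
  float max_value = 0.0f;
  // Each thread scans a strided slice of the flattened tensor.
  for (int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x; i < num_elements;
       i += (int64_t)gridDim.x * blockDim.x) {
    max_value = fmaxf(max_value, fabsf(static_cast<float>(input[i])));
  }
  // Combine lanes within each warp; lane 0 publishes its warp's maximum.
  max_value = warpReduceMax(max_value);
  if ((threadIdx.x % WARP_SIZE) == 0) {
    atomicMaxFloat(output_s, max_value);
  }
}

The per-tensor quantization scale is then conventionally derived as absmax / FP8_E4M3_MAX.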
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
 #include <ATen/cuda/CUDAContext.h>
-#include <c10/util/Float8_e4m3fn.h>
 #include <cmath>
 #include <cub/block/block_reduce.cuh>
...
@@ -7,31 +6,6 @@
 #include "utils.h"

-#define WARP_SIZE 32
-#ifndef USE_ROCM
-#include <c10/util/Float8_e4m3fn.h>
-using FP8_TYPE = c10::Float8_e4m3fn;
-C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
-#else
-#include <c10/util/Float8_e4m3fnuz.h>
-#include "amd/quant_utils.cuh"
-using FP8_TYPE = c10::Float8_e4m3fnuz;
-// Using the default max value from pytorch (240.0) will cause accuracy
-// issue when running dynamic quantization. Here use 224.0f for rocm.
-constexpr auto FP8_E4M3_MAX = 224.0f;
-#endif

-__device__ __forceinline__ float warpReduceMax(float max_value) {
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
-  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
-  return max_value;
-}

 template <typename T>
 __global__ void per_token_quant_fp8_kernel(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q,
                                            float* __restrict__ output_s, const int64_t hidden_dim,
...
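The remainder of this kernel is collapsed above. As a rough illustration of what per-token FP8 quantization does with the shared FP8_TYPE / FP8_E4M3_MAX definitions, a minimal one-warp-per-token sketch could look like the following; this is not the commit's code (judging by the cub include, the real kernel reduces the row maximum with cub::BlockReduce), and the kernel name and launch-layout assumptions are hypothetical.

// Illustrative sketch only: one warp quantizes one token (row) of length
// hidden_dim. Assumes utils.h is included and blockDim.x == WARP_SIZE.
template <typename T>
__global__ void per_token_quant_fp8_sketch(const T* __restrict__ input, FP8_TYPE* __restrict__ output_q,
                                           float* __restrict__ output_s, const int64_t hidden_dim) {
  const int64_t token = blockIdx.x;
  const T* row_in = input + token * hidden_dim;
  FP8_TYPE* row_out = output_q + token * hidden_dim;
  const float fp8_max = static_cast<float>(FP8_E4M3_MAX);

  // 1) Row-wise absolute maximum.
  float max_value = 0.0f;
  for (int64_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    max_value = fmaxf(max_value, fabsf(static_cast<float>(row_in[i])));
  }
  max_value = warpReduceMax(max_value);

  // 2) Scale chosen so the row maximum maps to the FP8 E4M3 maximum.
  const float scale = max_value / fp8_max;
  if (threadIdx.x == 0) output_s[token] = scale;

  // 3) Quantize: divide by the scale and clamp to the representable range.
  const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
  for (int64_t i = threadIdx.x; i < hidden_dim; i += blockDim.x) {
    float v = static_cast<float>(row_in[i]) * inv_scale;
    row_out[i] = static_cast<FP8_TYPE>(fminf(fmaxf(v, -fp8_max), fp8_max));
  }
}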
sgl-kernel/src/sgl-kernel/include/utils.h
...
@@ -95,3 +95,33 @@ inline int getSMVersion() {
   AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
 #define CEILDIV(x, y) (((x) + (y)-1) / (y))
 #define WARP_SIZE 32
+
+#ifndef USE_ROCM
+#include <c10/util/Float8_e4m3fn.h>
+using FP8_TYPE = c10::Float8_e4m3fn;
+C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
+#else
+#include <c10/util/Float8_e4m3fnuz.h>
+#include "amd/quant_utils.cuh"
+using FP8_TYPE = c10::Float8_e4m3fnuz;
+constexpr auto FP8_E4M3_MAX = 224.0f;
+#endif
+
+__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
+  float old;
+  old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
+                     : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value)));
+  return old;
+}
+
+__device__ __forceinline__ float warpReduceMax(float max_value) {
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
+  max_value = fmaxf(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
+  return max_value;
+}
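A note on the consolidated atomicMaxFloat: CUDA has no native float atomicMax, so the helper exploits the IEEE-754 bit layout. For non-negative floats, the signed-integer view of the bits is ordered the same way as the float values, so atomicMax on the int bits yields the float maximum; for negative values the unsigned view is ordered in reverse, which is why that branch uses atomicMin on unsigned bits. A tiny standalone host-side check of the non-negative ordering property (illustrative only, not part of the commit):

// Host-side sanity check: larger non-negative floats have larger signed
// integer bit patterns, the property atomicMaxFloat's positive branch relies on.
#include <cassert>
#include <cstdint>
#include <cstring>

static int32_t float_bits(float f) {
  int32_t i;
  std::memcpy(&i, &f, sizeof(i));  // same reinterpretation as __float_as_int
  return i;
}

int main() {
  assert(float_bits(0.5f) < float_bits(1.0f));
  assert(float_bits(1.0f) < float_bits(224.0f));
  assert(float_bits(224.0f) < float_bits(448.0f));
  return 0;
}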