Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6a384d5c
Unverified
Commit
6a384d5c
authored
Mar 22, 2025
by
Chunan Zeng
Committed by
GitHub
Mar 22, 2025
Browse files
Speed up per token and per tensor quant by 15% (#4639)
parent
f69e0696
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
32 deletions
+29
-32
sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
+12
-14
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
+17
-18
No files found.
sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
View file @
6a384d5c
...
...
@@ -54,42 +54,40 @@ __global__ void per_tensor_quant_fp8_kernel(
const
int
grid_size
=
blockDim
.
x
*
gridDim
.
x
;
const
float
scale_val
=
1.0
f
/
(
*
scale
);
constexpr
uint32_t
vec_size
=
16
/
sizeof
(
T
);
using
vec_t
=
flashinfer
::
vec_t
<
T
,
vec_size
>
;
// We want to store 128 bits of data at a time. 16 = 128 / 8 bits
// Load is already vectorized, so 16 elements work for T.
const
uint32_t
VEC_SIZE
=
16
;
using
vec_t
=
flashinfer
::
vec_t
<
T
,
VEC_SIZE
>
;
const
int32_t
num_vec_elems
=
num_elements
/
vec_size
;
const
int32_t
num_vec_elems
=
num_elements
/
VEC_SIZE
;
for
(
int32_t
i
=
gid
;
i
<
num_vec_elems
;
i
+=
grid_size
)
{
vec_t
input_vec
;
input_vec
.
cast_load
(
input
+
i
*
vec_size
);
input_vec
.
cast_load
(
input
+
i
*
VEC_SIZE
);
FP8_TYPE
output_arr
[
vec_size
];
FP8_TYPE
output_arr
[
VEC_SIZE
];
#pragma unroll
for
(
uint32_t
j
=
0
;
j
<
vec_size
;
++
j
)
{
for
(
uint32_t
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
float
val
=
fmax
(
fmin
(
static_cast
<
float
>
(
input_vec
[
j
])
*
scale_val
,
FP8_E4M3_MAX
),
-
FP8_E4M3_MAX
);
#ifndef USE_ROCM
output_arr
[
j
]
=
static_cast
<
FP8_TYPE
>
(
val
);
#else
output_arr
[
j
]
=
c10
::
Float8_e4m3fnuz
(
__hip_cvt_float_to_fp8
(
val
ue
,
fp8
::
fp8_type
::
__default_saturation
,
fp8
::
fp8_type
::
__default_interpret
),
__hip_cvt_float_to_fp8
(
val
,
fp8
::
fp8_type
::
__default_saturation
,
fp8
::
fp8_type
::
__default_interpret
),
c10
::
Float8_e4m3fnuz
::
from_bits
());
#endif
}
#pragma unroll
for
(
uint32_t
j
=
0
;
j
<
vec_size
;
++
j
)
{
output
[
i
*
vec_size
+
j
]
=
output_arr
[
j
];
}
*
(
uint4
*
)(
output
+
i
*
VEC_SIZE
)
=
*
(
uint4
*
)
output_arr
;
}
const
int32_t
remaining_start
=
num_vec_elems
*
vec_size
;
const
int32_t
remaining_start
=
num_vec_elems
*
VEC_SIZE
;
for
(
int32_t
idx
=
remaining_start
+
gid
;
idx
<
num_elements
;
idx
+=
grid_size
)
{
float
val
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
static_cast
<
float
>
(
input
[
idx
])
*
scale_val
,
FP8_E4M3_MAX
));
#ifndef USE_ROCM
output
[
idx
]
=
static_cast
<
FP8_TYPE
>
(
val
);
#else
output
[
idx
]
=
c10
::
Float8_e4m3fnuz
(
__hip_cvt_float_to_fp8
(
val
ue
,
fp8
::
fp8_type
::
__default_saturation
,
fp8
::
fp8_type
::
__default_interpret
),
__hip_cvt_float_to_fp8
(
val
,
fp8
::
fp8_type
::
__default_saturation
,
fp8
::
fp8_type
::
__default_interpret
),
c10
::
Float8_e4m3fnuz
::
from_bits
());
#endif
}
...
...
sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
View file @
6a384d5c
...
...
@@ -24,17 +24,19 @@ __global__ void per_token_quant_fp8_kernel(
float
max_value
=
0.0
f
;
constexpr
uint32_t
vec_size
=
16
/
sizeof
(
T
);
using
vec_t
=
flashinfer
::
vec_t
<
T
,
vec_size
>
;
const
int32_t
num_vec_elems
=
hidden_dim
/
vec_size
;
// We want to store 128 bits of data at a time. 16 = 128 / 8 bits
// Load is already vectorized, so 16 elements work for T.
const
uint32_t
VEC_SIZE
=
16
;
using
vec_t
=
flashinfer
::
vec_t
<
T
,
VEC_SIZE
>
;
const
int32_t
num_vec_elems
=
hidden_dim
/
VEC_SIZE
;
// Find max using vectorized loads
for
(
int32_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
block_dim
)
{
vec_t
input_vec
;
input_vec
.
cast_load
(
token_input
+
i
*
vec_size
);
input_vec
.
cast_load
(
token_input
+
i
*
VEC_SIZE
);
#pragma unroll
for
(
uint32_t
j
=
0
;
j
<
vec_size
;
++
j
)
{
for
(
uint32_t
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
float
val
=
static_cast
<
float
>
(
input_vec
[
j
]);
max_value
=
fmaxf
(
max_value
,
fabsf
(
val
));
}
...
...
@@ -42,24 +44,24 @@ __global__ void per_token_quant_fp8_kernel(
max_value
=
blockReduceMax
(
max_value
);
__shared__
float
block_max
;
__shared__
float
scale
;
if
(
tid
==
0
)
{
block_max
=
max_value
/
FP8_E4M3_MAX
;
output_s
[
token_idx
]
=
block_max
;
scale
=
max_value
/
FP8_E4M3_MAX
;
output_s
[
token_idx
]
=
scale
;
}
__syncthreads
();
const
float
scale_
val
=
1.0
f
/
block_max
;
const
float
scale_
inv
=
1.0
f
/
scale
;
// Quantize using vectorized loads
for
(
int32_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
block_dim
)
{
vec_t
input_vec
;
input_vec
.
cast_load
(
token_input
+
i
*
vec_size
);
input_vec
.
cast_load
(
token_input
+
i
*
VEC_SIZE
);
FP8_TYPE
output_arr
[
vec_size
];
FP8_TYPE
output_arr
[
VEC_SIZE
];
#pragma unroll
for
(
uint32_t
j
=
0
;
j
<
vec_size
;
++
j
)
{
float
val
=
fmaxf
(
fminf
(
static_cast
<
float
>
(
input_vec
[
j
])
*
scale_
val
,
FP8_E4M3_MAX
),
-
FP8_E4M3_MAX
);
for
(
uint32_t
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
float
val
=
fmaxf
(
fminf
(
static_cast
<
float
>
(
input_vec
[
j
])
*
scale_
inv
,
FP8_E4M3_MAX
),
-
FP8_E4M3_MAX
);
#ifndef USE_ROCM
output_arr
[
j
]
=
static_cast
<
FP8_TYPE
>
(
val
);
#else
...
...
@@ -69,10 +71,7 @@ __global__ void per_token_quant_fp8_kernel(
#endif
}
#pragma unroll
for
(
uint32_t
j
=
0
;
j
<
vec_size
;
++
j
)
{
token_output
[
i
*
vec_size
+
j
]
=
output_arr
[
j
];
}
*
(
uint4
*
)(
token_output
+
i
*
VEC_SIZE
)
=
*
(
uint4
*
)
output_arr
;
}
}
...
...
@@ -85,7 +84,7 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
const
int64_t
num_tokens
=
input_sizes
[
0
];
const
int64_t
hidden_dim
=
input_sizes
[
1
];
TORCH_CHECK
(
hidden_dim
%
8
==
0
,
"Hidden dimension must be divisible by
8
, but got "
,
hidden_dim
);
TORCH_CHECK
(
hidden_dim
%
16
==
0
,
"Hidden dimension must be divisible by
16
, but got "
,
hidden_dim
);
const
int
block_size
=
256
;
const
int
num_blocks
=
num_tokens
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment