Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
79a321af
Unverified
Commit
79a321af
authored
Mar 09, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Mar 08, 2025
Browse files
revert pr 3628 to pass test_mla ci (#4219)
parent
6eec3cdc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
30 additions
and
46 deletions
+30
-46
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
...nel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
+30
-46
No files found.
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_group_quant_fp8.cu
View file @
79a321af
...
...
@@ -2,18 +2,17 @@
#include <c10/util/Float8_e4m3fn.h>
#include <cmath>
#include <flashinfer/vec_dtypes.cuh>
#include "utils.h"
using
FP8_TYPE
=
c10
::
Float8_e4m3fn
;
__device__
__forceinline__
float
GroupReduce
(
float
val
,
const
int
tid
)
{
val
=
fmaxf
(
val
,
__shfl_xor_sync
(
0xffff
,
val
,
8
)
);
val
=
fmaxf
(
val
,
__shfl_xor_sync
(
0xffff
,
val
,
4
)
);
val
=
fmaxf
(
val
,
__shfl_xor_sync
(
0xffff
,
val
,
2
)
);
val
=
fmaxf
(
val
,
__shfl_xor_sync
(
0xffff
,
val
,
1
)
);
return
val
;
__device__
__forceinline__
float
GroupReduce
Max
(
volatile
float
*
smem
,
const
int
tid
)
{
smem
[
tid
]
=
fmaxf
(
smem
[
tid
],
smem
[
tid
+
8
]
);
if
(
tid
<
4
)
smem
[
tid
]
=
fmaxf
(
smem
[
tid
],
smem
[
tid
+
4
]
);
if
(
tid
<
2
)
smem
[
tid
]
=
fmaxf
(
smem
[
tid
],
smem
[
tid
+
2
]
);
if
(
tid
<
1
)
smem
[
tid
]
=
fmaxf
(
smem
[
tid
],
smem
[
tid
+
1
]
);
return
smem
[
0
]
;
}
template
<
typename
T
>
...
...
@@ -27,60 +26,45 @@ __global__ void per_token_group_quant_fp8_kernel(
const
float
fp8_min
,
const
float
fp8_max
)
{
const
int
groups_per_block
=
16
;
const
int
local_group_id
=
threadIdx
.
x
/
16
;
const
int
lane_id
=
threadIdx
.
x
%
16
;
const
int
block_group_id
=
blockIdx
.
x
*
groups_per_block
;
const
int
block_group_offset
=
(
block_group_id
+
local_group_id
)
*
group_size
;
const
int
tid
=
threadIdx
.
x
;
const
int
local_group_id
=
tid
/
16
;
const
int
local_tid
=
tid
%
16
;
__shared__
float
s_absmax
[
16
];
__shared__
float
s_absmax
[
16
]
[
17
]
;
float
local_absmax
=
eps
;
const
T
*
group_input
=
input
+
block_group_offset
;
FP8_TYPE
*
group_output
=
static_cast
<
FP8_TYPE
*>
(
output_q
)
+
block_group_offset
;
float
*
scale_output
=
output_s
+
block_group_id
+
local_group_id
;
constexpr
uint32_t
vec_size
=
16
/
sizeof
(
T
);
using
vec_t
=
flashinfer
::
vec_t
<
T
,
vec_size
>
;
const
int32_t
num_vec_elems
=
group_size
/
vec_size
;
if
(
block_group_id
+
local_group_id
<
num_groups
)
{
const
T
*
group_input
=
input
+
(
block_group_id
+
local_group_id
)
*
group_size
;
FP8_TYPE
*
group_output
=
static_cast
<
FP8_TYPE
*>
(
output_q
)
+
(
block_group_id
+
local_group_id
)
*
group_size
;
float
*
scale_output
=
output_s
+
block_group_id
+
local_group_id
;
for
(
int32_t
i
=
lane_id
;
i
<
num_vec_elems
;
i
+=
16
)
{
vec_t
input_vec
;
input_vec
.
cast_load
(
group_input
+
i
*
vec_size
);
#pragma unroll
for
(
uint32_t
j
=
0
;
j
<
vec_size
;
++
j
)
{
float
val
=
static_cast
<
float
>
(
input_vec
[
j
]);
for
(
int
i
=
local_tid
;
i
<
group_size
;
i
+=
16
)
{
float
val
=
static_cast
<
float
>
(
group_input
[
i
]);
float
abs_val
=
fabsf
(
val
);
local_absmax
=
fmaxf
(
local_absmax
,
abs_val
);
}
}
local_absmax
=
GroupReduce
(
local_absmax
,
lane_id
);
s_absmax
[
local_group_id
][
local_tid
]
=
local_absmax
;
__syncthreads
();
if
(
l
ane_id
==
0
)
{
s_absmax
[
local_group_id
]
=
local_
absmax
;
}
__syncthreads
();
if
(
l
ocal_tid
<
8
)
{
GroupReduceMax
(
&
s_absmax
[
local_group_id
]
[
0
],
local_
tid
)
;
}
__syncthreads
();
const
float
group_absmax
=
s_absmax
[
local_group_id
];
const
float
y_s
=
group_absmax
/
fp8_max
;
const
float
group_absmax
=
s_absmax
[
local_group_id
]
[
0
]
;
const
float
y_s
=
group_absmax
/
fp8_max
;
if
(
lane_id
==
0
)
{
*
scale_output
=
y_s
;
}
for
(
int32_t
i
=
lane_id
;
i
<
num_vec_elems
;
i
+=
16
)
{
vec_t
input_vec
;
input_vec
.
cast_load
(
group_input
+
i
*
vec_size
);
if
(
local_tid
==
0
)
{
*
scale_output
=
y_s
;
}
#pragma unroll
for
(
uint32_t
j
=
0
;
j
<
vec_size
;
++
j
)
{
float
val
=
static_cast
<
float
>
(
input_vec
[
j
]);
for
(
int
i
=
local_tid
;
i
<
group_size
;
i
+=
16
)
{
float
val
=
static_cast
<
float
>
(
group_input
[
i
]);
float
q_val
=
fminf
(
fmaxf
(
val
/
y_s
,
fp8_min
),
fp8_max
);
group_output
[
i
*
vec_size
+
j
]
=
FP8_TYPE
(
q_val
);
group_output
[
i
]
=
FP8_TYPE
(
q_val
);
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment