Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
b3251e9f
Unverified
Commit
b3251e9f
authored
Mar 08, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Mar 08, 2025
Browse files
refine quant kernel code style (#4211)
parent
2cadd51d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
25 deletions
+18
-25
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
+1
-12
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
+1
-13
sgl-kernel/src/sgl-kernel/include/utils.h
sgl-kernel/src/sgl-kernel/include/utils.h
+16
-0
No files found.
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
View file @
b3251e9f
...
...
@@ -37,18 +37,7 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
max_value
=
fmaxf
(
max_value
,
fabsf
(
val
));
}
static
__shared__
float
warpLevelMaxs
[
WARP_SIZE
];
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
max_value
=
warpReduceMax
(
max_value
);
if
(
laneId
==
0
)
warpLevelMaxs
[
warpId
]
=
max_value
;
__syncthreads
();
max_value
=
(
threadIdx
.
x
<
blockDim
.
x
/
WARP_SIZE
)
?
warpLevelMaxs
[
laneId
]
:
0
;
if
(
warpId
==
0
)
max_value
=
warpReduceMax
(
max_value
);
max_value
=
blockReduceMax
(
max_value
);
if
(
tid
==
0
)
{
atomicMaxFloat
(
output_s
,
max_value
/
FP8_E4M3_MAX
);
...
...
sgl-kernel/src/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
View file @
b3251e9f
...
...
@@ -30,19 +30,7 @@ __global__ void per_token_quant_fp8_kernel(
max_value
=
fmaxf
(
max_value
,
fabsf
(
val
));
}
max_value
=
warpReduceMax
(
max_value
);
static
__shared__
float
warpLevelMaxs
[
WARP_SIZE
];
const
int
laneId
=
threadIdx
.
x
%
WARP_SIZE
;
const
int
warpId
=
threadIdx
.
x
/
WARP_SIZE
;
if
(
laneId
==
0
)
warpLevelMaxs
[
warpId
]
=
max_value
;
__syncthreads
();
if
(
warpId
==
0
)
{
max_value
=
(
threadIdx
.
x
<
blockDim
.
x
/
WARP_SIZE
)
?
warpLevelMaxs
[
laneId
]
:
0
;
max_value
=
warpReduceMax
(
max_value
);
}
max_value
=
blockReduceMax
(
max_value
);
__shared__
float
block_max
;
if
(
tid
==
0
)
{
...
...
sgl-kernel/src/sgl-kernel/include/utils.h
View file @
b3251e9f
...
...
@@ -124,4 +124,20 @@ __device__ __forceinline__ float warpReduceMax(float max_value) {
max_value
=
fmaxf
(
max_value
,
__shfl_xor_sync
(
0xffffffff
,
max_value
,
1
));
return
max_value
;
}
// Block-wide maximum built from two rounds of warpReduceMax with a shared
// memory hand-off in between.
//
// Each warp first reduces its own lanes; lane 0 of every warp stores that
// partial result in a shared scratch array indexed by warp number. After a
// block barrier, the first (blockDim.x / WARP_SIZE) threads pick up one
// partial each, every other thread picks up 0, and warp 0 reduces those values.
//
// Contract, as implemented here:
//   * Contains __syncthreads(), so every thread of the block must reach the
//     call.
//   * Only threads of warp 0 run the final reduction; threads of other warps
//     return the value they loaded (or 0), not the block maximum.
//   * The warp count uses truncating integer division, so a trailing partial
//     warp (blockDim.x not a multiple of WARP_SIZE) is not included.
//   * The fill value is 0, so the result is max(0, true maximum) whenever some
//     lane of warp 0 has no partial to load; this is exact for non-negative
//     inputs such as absolute values.
//   * There is no barrier after the scratch array is read. NOTE(review):
//     callers invoking this more than once per kernel presumably need their
//     own __syncthreads() between calls - confirm against call sites.
__device__ __forceinline__ float blockReduceMax(float max_value) {
  // One slot per warp.
  static __shared__ float warp_partials[WARP_SIZE];

  const int lane = threadIdx.x % WARP_SIZE;
  const int warp = threadIdx.x / WARP_SIZE;

  // Stage 1: reduce within each warp and publish the partial from lane 0.
  const float warp_max = warpReduceMax(max_value);
  if (lane == 0) {
    warp_partials[warp] = warp_max;
  }

  // Make every warp's partial visible before any thread reads the array.
  __syncthreads();

  // Stage 2: threads with an index below the warp count load one partial each;
  // all remaining threads contribute 0.
  float partial = 0.0f;
  if (threadIdx.x < blockDim.x / WARP_SIZE) {
    partial = warp_partials[lane];
  }

  // Only warp 0 combines the partials; other warps return what they hold.
  return (warp == 0) ? warpReduceMax(partial) : partial;
}
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment