Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
90bb2be2
Unverified
Commit
90bb2be2
authored
Mar 07, 2025
by
Rex
Committed by
GitHub
Mar 07, 2025
Browse files
Minor improvement to per_tensor_quant_fp8 (#4197)
parent
b93ef5e5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
10 deletions
+3
-10
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
+3
-10
No files found.
sgl-kernel/src/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu
View file @
90bb2be2
...
...
@@ -57,13 +57,9 @@ per_tensor_absmax_kernel(const T* __restrict__ input, float* __restrict__ output
template
<
typename
T
>
__global__
void
per_tensor_quant_fp8_kernel
(
const
T
*
__restrict__
input
,
FP8_TYPE
*
__restrict__
output
,
const
float
*
__restrict__
scale
,
const
int64_t
num_elements
)
{
const
T
*
__restrict__
input
,
FP8_TYPE
*
__restrict__
output
,
const
float
scale_val
,
const
int64_t
num_elements
)
{
const
int
gid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
grid_size
=
blockDim
.
x
*
gridDim
.
x
;
const
float
scale_val
=
1.0
f
/
(
*
scale
);
constexpr
uint32_t
vec_size
=
16
/
sizeof
(
T
);
using
vec_t
=
flashinfer
::
vec_t
<
T
,
vec_size
>
;
...
...
@@ -125,12 +121,9 @@ void sgl_per_tensor_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch
per_tensor_absmax_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
scalar_t
*>
(
input
.
data_ptr
()),
static_cast
<
float
*>
(
output_s
.
data_ptr
()),
num_elements
);
}
float
scale_val
=
1.0
f
/
(
*
static_cast
<
float
*>
(
output_s
.
data_ptr
()));
per_tensor_quant_fp8_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
static_cast
<
scalar_t
*>
(
input
.
data_ptr
()),
static_cast
<
FP8_TYPE
*>
(
output_q
.
data_ptr
()),
static_cast
<
float
*>
(
output_s
.
data_ptr
()),
num_elements
);
static_cast
<
scalar_t
*>
(
input
.
data_ptr
()),
static_cast
<
FP8_TYPE
*>
(
output_q
.
data_ptr
()),
scale_val
,
num_elements
);
return
true
;
});
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment