Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
587a5c60
Commit
587a5c60
authored
Apr 22, 2026
by
zhaosong1
Browse files
[feature] add scaled_fp8_quant_weight for online ptpc_fp8 quant.
parent
f1eb27b8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
65 additions
and
2 deletions
+65
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+64
-1
vllm/model_executor/layers/quantization/ptpc_fp8.py
vllm/model_executor/layers/quantization/ptpc_fp8.py
+1
-1
No files found.
vllm/_custom_ops.py
View file @
587a5c60
...
...
@@ -1386,6 +1386,68 @@ def scaled_fp8_quant(
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
scale_ub: Optional upper bound for scaling factor in dynamic
per token case
num_token_padding: If specified, pad the first dimension
of the output to at least this value.
use_per_token_if_dynamic: Whether to do per_tensor or per_token
in the dynamic quantization case.
Returns:
tuple[torch.Tensor, torch.Tensor]: The output tensor in FP8 and
scaling factor.
"""
# This code assumes batch_dim and num_tokens are flattened
assert
(
input
.
ndim
==
2
)
shape
:
Union
[
tuple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
# For ROCm on MI300, the output fp8 dtype is torch.float8_e4m3fnuz
out_dtype
:
torch
.
dtype
=
current_platform
.
fp8_dtype
()
if
num_token_padding
:
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
])
if
output
is
None
:
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
out_dtype
)
else
:
assert
num_token_padding
is
None
,
\
"padding not supported if output passed in"
assert
output
.
dtype
==
out_dtype
if
scale
is
None
:
if
use_per_token_if_dynamic
:
scale
=
torch
.
empty
((
shape
[
0
],
1
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
output
,
scale
=
per_token_quant_fp8
(
input
.
contiguous
())
else
:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
else
:
assert
scale
.
numel
()
==
1
,
f
"
{
scale
.
shape
}
"
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
return
output
,
scale
def
scaled_fp8_quant_weight
(
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
num_token_padding
:
Optional
[
int
]
=
None
,
scale_ub
:
Optional
[
torch
.
Tensor
]
=
None
,
use_per_token_if_dynamic
:
bool
=
False
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Quantize input tensor to FP8 and return quantized tensor and scale.
This function supports both static and dynamic quantization: If you
provide the scale, it will use static scaling and if you omit it,
the scale will be determined dynamically. The function also allows
optional padding of the output tensors for downstream kernels that
will benefit from padding.
Args:
input: The input tensor to be quantized to FP8
scale: Optional scaling factor for the FP8 quantization
...
...
@@ -1421,7 +1483,7 @@ def scaled_fp8_quant(
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_per_token_scaled_fp8_quant
(
output
,
input
.
contiguous
(),
scale
,
scale_ub
)
# per_token_quant_fp8 has precision problem.
# per_token_quant_fp8 has a precision problem for online weight quant.
# output, scale = per_token_quant_fp8(input.contiguous())
else
:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
...
...
@@ -1431,6 +1493,7 @@ def scaled_fp8_quant(
torch
.
ops
.
_C
.
static_scaled_fp8_quant
(
output
,
input
,
scale
)
return
output
,
scale
# gptq allspark
def
allspark_repack_weight
(
qweight
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/ptpc_fp8.py
View file @
587a5c60
...
...
@@ -107,7 +107,7 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
assert
layer
.
weight
.
data
.
dtype
==
torch
.
bfloat16
,
\
f
"Currently torch._scaled_mm (hipBLASLt) rowwise gemm only support output dtype of bfloat16.
{
str
(
layer
.
weight
.
data
.
dtype
)
}
is specified."
# noqa: E501
# Quantize the weights.
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
_weight
(
layer
.
weight
,
scale
=
None
,
use_per_token_if_dynamic
=
True
)
# Update the layer with the new values.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment