Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
753b29c0
Commit
753b29c0
authored
Apr 23, 2026
by
zhaosong
Committed by
zhangzbb
Apr 23, 2026
Browse files
[bugfix][fp8]Enable torch.compile for quant_fp8.
parent
2444e959
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
2 deletions
+30
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+29
-1
vllm/model_executor/layers/quantization/input_quant_fp8.py
vllm/model_executor/layers/quantization/input_quant_fp8.py
+1
-1
No files found.
vllm/_custom_ops.py
View file @
753b29c0
...
@@ -1900,6 +1900,28 @@ def scaled_fp4_experts_quant(
...
@@ -1900,6 +1900,28 @@ def scaled_fp4_experts_quant(
output_scales
=
output_scales
.
view
(
torch
.
float8_e4m3fn
)
output_scales
=
output_scales
.
view
(
torch
.
float8_e4m3fn
)
return
output
,
output_scales
return
output
,
output_scales
def _lightop_per_token_quant_fp8_impl(
    out: torch.Tensor,
    input: torch.Tensor,
    scales: torch.Tensor,
) -> None:
    """Run lightop's ``per_token_quant_fp8`` kernel on pre-allocated buffers.

    Nothing is returned; the kernel is handed ``out`` and ``scales`` to
    fill, which is why both are listed in ``mutates_args`` when this
    function is registered as a custom op.

    Args:
        out: Destination tensor for the quantized values.
        input: Tensor to quantize.
        scales: Destination tensor for the computed scales.
    """
    # Function-scope import: ``lightop`` is only needed once the op is
    # actually invoked, not when this module is imported.
    from lightop import op as lightop_ops

    lightop_ops.per_token_quant_fp8(out, input, scales)
def _lightop_per_token_quant_fp8_fake(
    out: torch.Tensor,
    input: torch.Tensor,
    scales: torch.Tensor,
) -> None:
    """Fake (no-op) counterpart of ``_lightop_per_token_quant_fp8_impl``.

    Passed as ``fake_impl`` when the custom op is registered. The real op
    returns nothing and only writes into the caller-provided ``out`` and
    ``scales`` buffers, so this stand-in has no result to produce and
    deliberately leaves every argument untouched.
    """
# Register the lightop per-token FP8 quantization kernel as a custom op so
# that scaled_fp8_quant can call it through
# torch.ops.vllm.lightop_per_token_quant_fp8 instead of calling the Python
# helper directly (this commit's stated goal is enabling torch.compile for
# quant_fp8). "out" and "scales" are declared as mutated because the op
# writes its results into those caller-allocated tensors and returns None.
direct_register_custom_op(
    "lightop_per_token_quant_fp8",
    _lightop_per_token_quant_fp8_impl,
    mutates_args=["out", "scales"],
    fake_impl=_lightop_per_token_quant_fp8_fake,
)
def
scaled_fp8_quant
(
def
scaled_fp8_quant
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
scale
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -1952,7 +1974,13 @@ def scaled_fp8_quant(
...
@@ -1952,7 +1974,13 @@ def scaled_fp8_quant(
dtype
=
torch
.
float32
)
dtype
=
torch
.
float32
)
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# torch.ops._C.dynamic_per_token_scaled_fp8_quant(
# output, input.contiguous(), scale, scale_ub)
# output, input.contiguous(), scale, scale_ub)
output
,
scale
=
per_token_quant_fp8
(
input
.
contiguous
())
# output, scale = per_token_quant_fp8(input.contiguous())
output
=
torch
.
empty_like
(
input
,
device
=
input
.
device
,
dtype
=
torch
.
float8_e4m3fn
)
scale
=
torch
.
empty
(
shape
[:
-
1
]
+
(
1
,
),
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
vllm
.
lightop_per_token_quant_fp8
(
output
,
input
,
scale
)
else
:
else
:
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
scale
=
torch
.
zeros
(
1
,
device
=
input
.
device
,
dtype
=
torch
.
float32
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
torch
.
ops
.
_C
.
dynamic_scaled_fp8_quant
(
output
,
input
,
scale
)
...
...
vllm/model_executor/layers/quantization/input_quant_fp8.py
View file @
753b29c0
...
@@ -52,7 +52,7 @@ class QuantFP8(CustomOp):
...
@@ -52,7 +52,7 @@ class QuantFP8(CustomOp):
column major format
column major format
:param compile_native: Manually compile forward_native if compile mode > None
:param compile_native: Manually compile forward_native if compile mode > None
"""
"""
super
().
__init__
(
compile_native
=
compile_native
)
super
().
__init__
(
compile_native
=
compile_native
,
enforce_enable
=
True
)
self
.
static
=
static
self
.
static
=
static
self
.
group_shape
=
group_shape
self
.
group_shape
=
group_shape
self
.
use_per_token_if_dynamic
=
group_shape
==
GroupShape
.
PER_TOKEN
self
.
use_per_token_if_dynamic
=
group_shape
==
GroupShape
.
PER_TOKEN
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment