Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b382a7f2
Unverified
Commit
b382a7f2
authored
Feb 26, 2025
by
Woosuk Kwon
Committed by
GitHub
Feb 26, 2025
Browse files
[BugFix] Make FP8 Linear compatible with torch.compile (#13918)
Signed-off-by:
Woosuk Kwon
<
woosuk.kwon@berkeley.edu
>
parent
4cb6fa0a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
4 deletions
+21
-4
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+1
-4
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+20
-0
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
b382a7f2
...
@@ -369,12 +369,9 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -369,12 +369,9 @@ class Fp8LinearMethod(LinearMethodBase):
size_k
=
layer
.
input_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
)
bias
=
bias
)
# Note: lazy import to avoid triton import error.
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_w8a8_block_fp8_linear
)
if
self
.
block_quant
:
if
self
.
block_quant
:
assert
self
.
quant_config
.
weight_block_size
is
not
None
assert
self
.
quant_config
.
weight_block_size
is
not
None
return
apply_w8a8_block_fp8_linear
(
return
torch
.
ops
.
vllm
.
apply_w8a8_block_fp8_linear
(
input
=
x
,
input
=
x
,
weight
=
layer
.
weight
,
weight
=
layer
.
weight
,
block_size
=
self
.
quant_config
.
weight_block_size
,
block_size
=
self
.
quant_config
.
weight_block_size
,
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
b382a7f2
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
,
CUTLASS_FP8_SUPPORTED
,
apply_fp8_linear
)
CUTLASS_BLOCK_FP8_SUPPORTED
,
CUTLASS_FP8_SUPPORTED
,
apply_fp8_linear
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -81,6 +82,25 @@ def apply_w8a8_block_fp8_linear(
...
@@ -81,6 +82,25 @@ def apply_w8a8_block_fp8_linear(
return
output
.
to
(
dtype
=
input
.
dtype
).
view
(
*
output_shape
)
return
output
.
to
(
dtype
=
input
.
dtype
).
view
(
*
output_shape
)
def
apply_w8a8_block_fp8_linear_fake
(
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
block_size
:
List
[
int
],
weight_scale
:
torch
.
Tensor
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
return
torch
.
empty
(
output_shape
,
dtype
=
input
.
dtype
,
device
=
input
.
device
)
direct_register_custom_op
(
op_name
=
"apply_w8a8_block_fp8_linear"
,
op_func
=
apply_w8a8_block_fp8_linear
,
mutates_args
=
[],
fake_impl
=
apply_w8a8_block_fp8_linear_fake
,
)
# Unify the interface between `apply_w8a8_block_fp8_linear` and
# Unify the interface between `apply_w8a8_block_fp8_linear` and
# `apply_fp8_linear`
# `apply_fp8_linear`
# NOTE(lucas): this is quite messy, we should think through this more formally
# NOTE(lucas): this is quite messy, we should think through this more formally
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment