Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
325f679f
Unverified
Commit
325f679f
authored
Jan 31, 2025
by
Robert Shaw
Committed by
GitHub
Jan 31, 2025
Browse files
[BugFix] Fix Torch.Compile For DeepSeek (#12594)
Co-authored-by:
simon-mo
<
xmo@berkeley.edu
>
parent
e3f7ff65
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
25 deletions
+29
-25
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+29
-25
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
325f679f
...
...
@@ -245,20 +245,24 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"input_scale"
,
None
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
#
Block quant doesn't need to process weights after loading
#
TODO(rob): refactor block quant into separate class.
if
self
.
block_quant
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
if
current_platform
.
is_rocm
():
weight
,
weight_scale
,
_
=
\
weight
,
weight_scale
_inv
,
_
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale_inv
,
input_scale
=
layer
.
input_scale
)
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight_scale_inv
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
weight_scale
=
layer
.
weight_scale_inv
)
else
:
weight
=
layer
.
weight
.
data
weight_scale_inv
=
layer
.
weight_scale_inv
.
data
# Torch.compile cannot use Parameter subclasses.
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight_scale_inv
=
Parameter
(
weight_scale_inv
,
requires_grad
=
False
)
return
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
# If checkpoint not serialized fp8, quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
...
...
@@ -507,8 +511,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
#
Block quant doesn't need to process weights after loading
#
TODO (rob): refactor block quant into separate class.
if
self
.
block_quant
:
assert
self
.
quant_config
.
activation_scheme
==
"dynamic"
if
current_platform
.
is_rocm
():
w13_weight
,
w13_weight_scale_inv
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
...
...
@@ -518,22 +523,21 @@ class Fp8MoEMethod(FusedMoEMethodBase):
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w2_weight
,
layer
.
w2_weight_scale_inv
,
layer
.
w2_input_scale
)
# Reset the parameter
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale_inv
=
torch
.
nn
.
Parameter
(
w13_weight_scale_inv
,
requires_grad
=
False
)
if
w13_input_scale
is
not
None
:
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
w13_input_scale
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
w2_weight_scale_inv
=
torch
.
nn
.
Parameter
(
w2_weight_scale_inv
,
requires_grad
=
False
)
if
w2_input_scale
is
not
None
:
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
w2_input_scale
,
requires_grad
=
False
)
else
:
w13_weight
=
layer
.
w13_weight
.
data
w13_weight_scale_inv
=
layer
.
w13_weight_scale_inv
.
data
w2_weight
=
layer
.
w2_weight
w2_weight_scale_inv
=
layer
.
w2_weight_scale_inv
# torch.compile() cannot use Parameter subclasses.
layer
.
w13_weight
=
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale_inv
=
Parameter
(
w13_weight_scale_inv
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
w2_weight_scale_inv
=
Parameter
(
w2_weight_scale_inv
,
requires_grad
=
False
)
return
# If checkpoint is fp16, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# If rocm, use float8_e4m3fnuz as dtype
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment