Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c09dade2
Unverified
Commit
c09dade2
authored
Jun 08, 2024
by
Michael Goin
Committed by
GitHub
Jun 08, 2024
Browse files
[Misc][Breaking] Change FP8 checkpoint format from act_scale -> input_scale (#5353)
parent
8ea5e44a
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
23 deletions
+23
-23
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+15
-15
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+8
-8
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
c09dade2
...
@@ -171,10 +171,10 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -171,10 +171,10 @@ class Fp8LinearMethod(LinearMethodBase):
output_partition_sizes
=
output_partition_sizes
,
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
)
**
extra_weight_attrs
)
# ACTIVATION SCALE
#
INPUT
ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
self
.
_create_scale_param
(
self
.
_create_scale_param
(
scale_name
=
"
ac
t_scale"
,
scale_name
=
"
inpu
t_scale"
,
layer
=
layer
,
layer
=
layer
,
output_partition_sizes
=
output_partition_sizes
,
output_partition_sizes
=
output_partition_sizes
,
**
extra_weight_attrs
)
**
extra_weight_attrs
)
...
@@ -207,7 +207,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -207,7 +207,7 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
logical_widths
=
None
layer
.
logical_widths
=
None
layer
.
ac
t_scale
=
None
layer
.
inpu
t_scale
=
None
return
return
# If checkpoint is fp8, requantize the separately quantized logical
# If checkpoint is fp8, requantize the separately quantized logical
...
@@ -232,18 +232,18 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -232,18 +232,18 @@ class Fp8LinearMethod(LinearMethodBase):
weight
=
layer
.
weight
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
#
ACT_
SCALE
#
INPUT ACTIVATION
SCALE
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the
ac
t_scales (since they are equal).
# Static: set to max of the
inpu
t_scales (since they are equal).
if
self
.
quant_config
.
activation_scheme
==
"dynamic"
:
if
self
.
quant_config
.
activation_scheme
==
"dynamic"
:
layer
.
ac
t_scale
=
None
layer
.
inpu
t_scale
=
None
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
if
not
all_close_1d
(
layer
.
ac
t_scale
):
if
not
all_close_1d
(
layer
.
inpu
t_scale
):
raise
ValueError
(
raise
ValueError
(
"All the
ac
t_scales for the logical weights of a
layer
"
"All the
inpu
t_scales for the logical weights of a "
f
"must be equal. But got
{
layer
.
ac
t_scale
}
"
)
f
"
layer
must be equal. But got
{
layer
.
inpu
t_scale
}
"
)
layer
.
ac
t_scale
=
Parameter
(
layer
.
ac
t_scale
.
max
(),
layer
.
inpu
t_scale
=
Parameter
(
layer
.
inpu
t_scale
.
max
(),
requires_grad
=
False
)
requires_grad
=
False
)
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Unknown scheme
{
self
.
quant_config
.
activation_scheme
}
"
)
f
"Unknown scheme
{
self
.
quant_config
.
activation_scheme
}
"
)
...
@@ -254,11 +254,11 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -254,11 +254,11 @@ class Fp8LinearMethod(LinearMethodBase):
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.
ac
t_scale is None and x_scale computed from x.
# If dynamic, layer.
inpu
t_scale is None and x_scale computed from x.
# If static,
layer.
ac
t_scale is scalar and x_scale
set to ac
t_scale.
# If static, layer.
inpu
t_scale is scalar and x_scale
is inpu
t_scale.
if
bias
is
None
and
self
.
cutlass_fp8_supported
:
if
bias
is
None
and
self
.
cutlass_fp8_supported
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
ac
t_scale
)
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
inpu
t_scale
)
# Fused GEMM_DQ
# Fused GEMM_DQ
output
=
ops
.
cutlass_scaled_mm_dq
(
output
=
ops
.
cutlass_scaled_mm_dq
(
...
@@ -271,7 +271,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -271,7 +271,7 @@ class Fp8LinearMethod(LinearMethodBase):
else
:
else
:
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
qinput
,
x_scale
=
ops
.
scaled_fp8_quant
(
x
,
layer
.
ac
t_scale
,
layer
.
inpu
t_scale
,
batch_dim_padding
=
17
)
batch_dim_padding
=
17
)
# Fused GEMM_DQ -- note we padded the input above because
# Fused GEMM_DQ -- note we padded the input above because
...
...
vllm/model_executor/models/mixtral.py
View file @
c09dade2
...
@@ -147,7 +147,7 @@ class MixtralMoE(nn.Module):
...
@@ -147,7 +147,7 @@ class MixtralMoE(nn.Module):
"weight_loader"
:
self
.
weight_loader
,
"weight_loader"
:
self
.
weight_loader
,
})
})
#
AC
T_SCALE (for fp8)
#
INPU
T_SCALE (for fp8)
if
quant_config
.
activation_scheme
==
"static"
:
if
quant_config
.
activation_scheme
==
"static"
:
if
not
quant_config
.
is_checkpoint_fp8_serialized
:
if
not
quant_config
.
is_checkpoint_fp8_serialized
:
raise
ValueError
(
raise
ValueError
(
...
@@ -182,11 +182,11 @@ class MixtralMoE(nn.Module):
...
@@ -182,11 +182,11 @@ class MixtralMoE(nn.Module):
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
# Loading scales
# Loading scales
if
"
ac
t_scale"
in
weight_name
or
"w2.weight_scale"
in
weight_name
:
if
"
inpu
t_scale"
in
weight_name
or
"w2.weight_scale"
in
weight_name
:
if
param_data
[
expert_id
]
!=
1
and
(
param_data
[
expert_id
]
-
if
param_data
[
expert_id
]
!=
1
and
(
param_data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
raise
ValueError
(
"
ac
t_scales of w1 and w3 of a layer "
"
inpu
t_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param_data
[
expert_id
]
}
"
f
"must be equal. But got
{
param_data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
f
"vs.
{
loaded_weight
}
"
)
param_data
[
expert_id
]
=
loaded_weight
param_data
[
expert_id
]
=
loaded_weight
...
@@ -225,9 +225,9 @@ class MixtralMoE(nn.Module):
...
@@ -225,9 +225,9 @@ class MixtralMoE(nn.Module):
self
.
w2_weight
=
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
self
.
w2_weight
=
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
else
:
else
:
# If checkpoint is fp8 + static, cleanup
ac
t_scales.
# If checkpoint is fp8 + static, cleanup
inpu
t_scales.
# Since state_dict has an
ac
t_scale per expert but our kernels
# Since state_dict has an
inpu
t_scale per expert but our kernels
# are passed one
ac
t_scale shared across all experts.
# are passed one
inpu
t_scale shared across all experts.
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
a13_scale
is
None
or
self
.
a2_scale
is
None
:
if
self
.
a13_scale
is
None
or
self
.
a2_scale
is
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -237,7 +237,7 @@ class MixtralMoE(nn.Module):
...
@@ -237,7 +237,7 @@ class MixtralMoE(nn.Module):
if
(
not
all_close_1d
(
self
.
a13_scale
)
if
(
not
all_close_1d
(
self
.
a13_scale
)
or
not
all_close_1d
(
self
.
a2_scale
)):
or
not
all_close_1d
(
self
.
a2_scale
)):
print_warning_once
(
print_warning_once
(
"Found
ac
t_scales that are not equal for "
"Found
inpu
t_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
"for each layer. "
)
...
@@ -576,7 +576,7 @@ class MixtralForCausalLM(nn.Module):
...
@@ -576,7 +576,7 @@ class MixtralForCausalLM(nn.Module):
# These are the activation scales for the experts
# These are the activation scales for the experts
# (param_name, weight_name, expert_id)
# (param_name, weight_name, expert_id)
(
"a13_scale"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"a2_scale"
,
(
"a13_scale"
if
weight_name
in
[
"w1"
,
"w3"
]
else
"a2_scale"
,
f
"experts.
{
expert_id
}
.
{
weight_name
}
.
ac
t_scale"
,
expert_id
)
f
"experts.
{
expert_id
}
.
{
weight_name
}
.
inpu
t_scale"
,
expert_id
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
expert_id
in
range
(
self
.
config
.
num_local_experts
)
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
for
weight_name
in
[
"w1"
,
"w2"
,
"w3"
]
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment