Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5563a4de
Unverified
Commit
5563a4de
authored
Jun 05, 2024
by
Cody Yu
Committed by
GitHub
Jun 05, 2024
Browse files
[Model] Correct Mixtral FP8 checkpoint loading (#5231)
parent
ccd4f129
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
80 additions
and
35 deletions
+80
-35
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+4
-3
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+76
-32
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
5563a4de
...
@@ -300,14 +300,15 @@ def all_close_1d(x: torch.Tensor) -> bool:
...
@@ -300,14 +300,15 @@ def all_close_1d(x: torch.Tensor) -> bool:
def
per_tensor_quantize
(
tensor
:
torch
.
Tensor
,
def
per_tensor_quantize
(
tensor
:
torch
.
Tensor
,
inv_scale
:
float
)
->
torch
.
Tensor
:
inv_scale
:
Union
[
float
,
torch
.
Tensor
]
)
->
torch
.
Tensor
:
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
qweight
=
(
tensor
/
inv_scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
qweight
=
(
tensor
/
inv_scale
).
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)
return
qweight
.
to
(
torch
.
float8_e4m3fn
)
return
qweight
.
to
(
torch
.
float8_e4m3fn
)
def
per_tensor_dequantize
(
tensor
:
torch
.
Tensor
,
def
per_tensor_dequantize
(
inv_scale
:
float
)
->
torch
.
Tensor
:
tensor
:
torch
.
Tensor
,
inv_scale
:
Union
[
float
,
torch
.
Tensor
])
->
torch
.
Tensor
:
fake_qweight
=
tensor
.
to
(
torch
.
float16
)
fake_qweight
=
tensor
.
to
(
torch
.
float16
)
dq_weight
=
fake_qweight
*
inv_scale
dq_weight
=
fake_qweight
*
inv_scale
return
dq_weight
return
dq_weight
vllm/model_executor/models/mixtral.py
View file @
5563a4de
...
@@ -41,7 +41,9 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
...
@@ -41,7 +41,9 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8Config
,
per_tensor_dequantize
,
per_tensor_quantize
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
@@ -98,16 +100,16 @@ class MixtralMoE(nn.Module):
...
@@ -98,16 +100,16 @@ class MixtralMoE(nn.Module):
if
self
.
use_fp8
and
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
self
.
use_fp8
and
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
params_dtype
=
torch
.
float8_e4m3fn
params_dtype
=
torch
.
float8_e4m3fn
self
.
w13_weight
=
nn
.
Parameter
(
self
.
w13_weight
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
torch
.
empty
(
self
.
num_total_experts
,
2
*
self
.
intermediate_size
,
2
*
self
.
intermediate
_size
,
self
.
hidden
_size
,
self
.
hidden_size
,
dtype
=
params_dtype
)
,
dtype
=
params_dtyp
e
)
)
requires_grad
=
Fals
e
)
self
.
w2_weight
=
nn
.
Parameter
(
self
.
w2_weight
=
nn
.
Parameter
(
torch
.
empty
(
self
.
num_total_experts
,
torch
.
empty
(
self
.
num_total_experts
,
self
.
hidden_size
,
self
.
hidden
_size
,
self
.
intermediate
_size
,
self
.
intermediate_size
,
dtype
=
params_dtype
)
,
dtype
=
params_dtyp
e
)
)
requires_grad
=
Fals
e
)
set_weight_attrs
(
self
.
w13_weight
,
{
set_weight_attrs
(
self
.
w13_weight
,
{
"weight_loader"
:
self
.
weight_loader
,
"weight_loader"
:
self
.
weight_loader
,
...
@@ -124,7 +126,10 @@ class MixtralMoE(nn.Module):
...
@@ -124,7 +126,10 @@ class MixtralMoE(nn.Module):
if
self
.
use_fp8
:
if
self
.
use_fp8
:
# WEIGHT_SCALE (for fp8)
# WEIGHT_SCALE (for fp8)
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
self
.
w13_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
self
.
w13_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
2
,
dtype
=
torch
.
float32
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
requires_grad
=
False
)
self
.
w2_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
self
.
w2_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
...
@@ -148,11 +153,11 @@ class MixtralMoE(nn.Module):
...
@@ -148,11 +153,11 @@ class MixtralMoE(nn.Module):
raise
ValueError
(
raise
ValueError
(
"Found static activation scheme for checkpoint that "
"Found static activation scheme for checkpoint that "
"was not serialized fp8."
)
"was not serialized fp8."
)
self
.
a13_scale
=
nn
.
Parameter
(
torch
.
zero
s
(
self
.
a13_scale
=
nn
.
Parameter
(
torch
.
one
s
(
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
requires_grad
=
False
)
self
.
a2_scale
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
a2_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
requires_grad
=
False
)
set_weight_attrs
(
self
.
a13_scale
,
{
set_weight_attrs
(
self
.
a13_scale
,
{
...
@@ -175,8 +180,22 @@ class MixtralMoE(nn.Module):
...
@@ -175,8 +180,22 @@ class MixtralMoE(nn.Module):
shard_size
:
2
*
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
shard_size
:
2
*
shard_size
,
:]
=
loaded_weight
[
shard
,
:]
if
weight_name
.
endswith
(
"w2.weight"
):
if
weight_name
.
endswith
(
"w2.weight"
):
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
param_data
[
expert_id
,
:,
:]
=
loaded_weight
[:,
shard
]
if
"act_scale"
in
weight_name
or
"weight_scale"
in
weight_name
:
# Loading scales
if
"act_scale"
in
weight_name
or
"w2.weight_scale"
in
weight_name
:
if
param_data
[
expert_id
]
!=
1
and
(
param_data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
"act_scales of w1 and w3 of a layer "
f
"must be equal. But got
{
param_data
[
expert_id
]
}
"
f
"vs.
{
loaded_weight
}
"
)
param_data
[
expert_id
]
=
loaded_weight
param_data
[
expert_id
]
=
loaded_weight
elif
"weight_scale"
in
weight_name
:
# We have to keep the weight scales of w1 and w3 because
# we need to re-quantize w1/w3 weights after weight loading.
assert
"w1"
in
weight_name
or
"w3"
in
weight_name
shard_id
=
0
if
"w1"
in
weight_name
else
1
param_data
[
expert_id
][
shard_id
]
=
loaded_weight
def
process_weights_after_loading
(
self
):
def
process_weights_after_loading
(
self
):
# Fp8 is the only case where we need to process after loading.
# Fp8 is the only case where we need to process after loading.
...
@@ -189,6 +208,12 @@ class MixtralMoE(nn.Module):
...
@@ -189,6 +208,12 @@ class MixtralMoE(nn.Module):
dtype
=
torch
.
float8_e4m3fn
)
dtype
=
torch
.
float8_e4m3fn
)
w2_weight
=
torch
.
empty_like
(
self
.
w2_weight
.
data
,
w2_weight
=
torch
.
empty_like
(
self
.
w2_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
dtype
=
torch
.
float8_e4m3fn
)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
self
.
w13_scale
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_total_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
for
expert
in
range
(
self
.
num_total_experts
):
for
expert
in
range
(
self
.
num_total_experts
):
w13_weight
[
expert
,
:,
:],
self
.
w13_scale
[
w13_weight
[
expert
,
:,
:],
self
.
w13_scale
[
expert
]
=
ops
.
scaled_fp8_quant
(
expert
]
=
ops
.
scaled_fp8_quant
(
...
@@ -199,25 +224,44 @@ class MixtralMoE(nn.Module):
...
@@ -199,25 +224,44 @@ class MixtralMoE(nn.Module):
self
.
w13_weight
=
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
self
.
w13_weight
=
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
self
.
w2_weight
=
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
self
.
w2_weight
=
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
# If checkpoint is fp8 + static, cleanup act_scales.
else
:
# Since state_dict has an act_scale per expert but our kernels
# If checkpoint is fp8 + static, cleanup act_scales.
# are passed one act_scale shared across all experts.
# Since state_dict has an act_scale per expert but our kernels
elif
self
.
quant_config
.
activation_scheme
==
"static"
:
# are passed one act_scale shared across all experts.
if
self
.
a13_scale
is
None
or
self
.
a2_scale
is
None
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
raise
ValueError
(
if
self
.
a13_scale
is
None
or
self
.
a2_scale
is
None
:
"QuantConfig has static quantization, but found "
raise
ValueError
(
"activation scales are None."
)
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
(
not
all_close_1d
(
self
.
a13_scale
)
if
(
not
all_close_1d
(
self
.
a13_scale
)
or
not
all_close_1d
(
self
.
a2_scale
)):
or
not
all_close_1d
(
self
.
a2_scale
)):
print_warning_once
(
print_warning_once
(
"Found act_scales that are not equal for fp8 MoE layer. "
"Found act_scales that are not equal for "
"Using the maximum across experts for each layer. "
)
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
self
.
a13_scale
=
nn
.
Parameter
(
self
.
a13_scale
.
max
(),
self
.
a13_scale
=
nn
.
Parameter
(
self
.
a13_scale
.
max
(),
requires_grad
=
False
)
requires_grad
=
False
)
self
.
a2_scale
=
nn
.
Parameter
(
self
.
a2_scale
.
max
(),
self
.
a2_scale
=
nn
.
Parameter
(
self
.
a2_scale
.
max
(),
requires_grad
=
False
)
requires_grad
=
False
)
assert
self
.
w13_scale
is
not
None
shard_size
=
self
.
intermediate_size
max_w13_scales
=
self
.
w13_scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
self
.
num_total_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
self
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
self
.
w13_scale
[
expert_id
][
shard_id
])
self
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:]
=
per_tensor_quantize
(
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
shard_size
self
.
w13_scale
=
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
num_tokens
,
hidden_size
=
hidden_states
.
shape
num_tokens
,
hidden_size
=
hidden_states
.
shape
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment