Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
873edda6
Unverified
Commit
873edda6
authored
Sep 25, 2024
by
Michael Goin
Committed by
GitHub
Sep 25, 2024
Browse files
[Misc] Support FP8 MoE for compressed-tensors (#8588)
parent
64840dfa
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
226 additions
and
8 deletions
+226
-8
tests/weight_loading/models-large.txt
tests/weight_loading/models-large.txt
+1
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+7
-2
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+215
-3
vllm/model_executor/models/phimoe.py
vllm/model_executor/models/phimoe.py
+2
-2
No files found.
tests/weight_loading/models-large.txt
View file @
873edda6
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W8A16-quantized, main
compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main
vllm/model_executor/layers/fused_moe/layer.py
View file @
873edda6
...
@@ -323,10 +323,12 @@ class FusedMoE(torch.nn.Module):
...
@@ -323,10 +323,12 @@ class FusedMoE(torch.nn.Module):
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
loaded_weight
:
torch
.
Tensor
,
weight_name
:
str
,
shard_id
:
str
,
expert_id
:
int
)
->
None
:
shard_id
:
str
,
expert_id
:
int
)
->
None
:
# compressed-tensors represents weights on disk which are flipped
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO (mgoin): check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
(
loaded_weight
=
loaded_weight
.
t
().
contiguous
()
if
(
self
.
quant_method
.
__class__
.
__name__
self
.
quant_method
.
__class__
.
__name__
==
"CompressedTensorsMoEMethod"
)
else
loaded_weight
==
"CompressedTensors
WNA16
MoEMethod"
)
else
loaded_weight
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
if
shard_id
not
in
(
"w1"
,
"w2"
,
"w3"
):
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
raise
ValueError
(
f
"shard_id must be ['w1','w2','w3'] but "
...
@@ -353,6 +355,9 @@ class FusedMoE(torch.nn.Module):
...
@@ -353,6 +355,9 @@ class FusedMoE(torch.nn.Module):
# Case input scale: input_scale loading is only supported for fp8
# Case input scale: input_scale loading is only supported for fp8
if
"input_scale"
in
weight_name
:
if
"input_scale"
in
weight_name
:
# this is needed for compressed-tensors only
loaded_weight
=
loaded_weight
.
to
(
param
.
data
.
device
)
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
if
param
.
data
[
expert_id
]
!=
1
and
(
param
.
data
[
expert_id
]
-
loaded_weight
).
abs
()
>
1e-5
:
loaded_weight
).
abs
()
>
1e-5
:
raise
ValueError
(
raise
ValueError
(
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
873edda6
...
@@ -73,7 +73,7 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -73,7 +73,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if
isinstance
(
layer
,
Attention
):
if
isinstance
(
layer
,
Attention
):
return
CompressedTensorsKVCacheMethod
(
self
)
return
CompressedTensorsKVCacheMethod
(
self
)
if
isinstance
(
layer
,
FusedMoE
):
if
isinstance
(
layer
,
FusedMoE
):
return
CompressedTensorsMoEMethod
(
self
)
return
CompressedTensorsMoEMethod
.
get_moe_method
(
self
)
return
None
return
None
@
classmethod
@
classmethod
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
873edda6
...
@@ -5,12 +5,16 @@ from typing import Callable, List, Optional
...
@@ -5,12 +5,16 @@ from typing import Callable, List, Optional
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
FusedMoeWeightScaleSupported
)
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
WNA16_SUPPORTED_BITS
)
WNA16_SUPPORTED_BITS
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
)
CompressionFormat
,
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
is_hip
,
print_warning_once
class
GPTQMarlinState
(
Enum
):
class
GPTQMarlinState
(
Enum
):
...
@@ -18,11 +22,219 @@ class GPTQMarlinState(Enum):
...
@@ -18,11 +22,219 @@ class GPTQMarlinState(Enum):
READY
=
enum
.
auto
()
READY
=
enum
.
auto
()
__all__
=
[
"CompressedTensorsMoEMethod"
]
__all__
=
[
"CompressedTensorsMoEMethod"
,
"CompressedTensorsW8A8Fp8MoEMethod"
,
"CompressedTensorsWNA16MoEMethod"
]
class
CompressedTensorsMoEMethod
(
FusedMoEMethodBase
):
class
CompressedTensorsMoEMethod
(
FusedMoEMethodBase
):
@
staticmethod
def
get_moe_method
(
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
)
->
"CompressedTensorsMoEMethod"
:
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
weight_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
input_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
if
quant_config
.
_is_wNa16_group_channel
(
weight_quant
,
input_quant
):
return
CompressedTensorsWNA16MoEMethod
(
quant_config
)
elif
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Fp8MoEMethod
(
quant_config
)
else
:
raise
RuntimeError
(
f
"Unsupported FusedMoe scheme:
{
weight_quant
}
,
{
input_quant
}
"
)
class
CompressedTensorsW8A8Fp8MoEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
self
,
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
):
self
.
quant_config
=
quant_config
self
.
weight_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
self
.
input_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
if
not
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TENSOR
):
raise
ValueError
(
"For FP8 Fused MoE layers, only per-tensor scales"
"for weights and activations are supported. Found "
f
"
{
self
.
weight_quant
}
,
{
self
.
input_quant
}
"
)
self
.
static_input_scales
=
not
self
.
input_quant
.
dynamic
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
=
torch
.
float8_e4m3fn
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
if
self
.
static_input_scales
:
w13_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
set_weight_attrs
(
w13_input_scale
,
extra_weight_attrs
)
w2_input_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
set_weight_attrs
(
w2_input_scale
,
extra_weight_attrs
)
else
:
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if
self
.
static_input_scales
:
if
(
layer
.
w13_input_scale
is
None
or
layer
.
w2_input_scale
is
None
):
raise
ValueError
(
"QuantConfig has static quantization, but found "
"activation scales are None."
)
if
(
not
all_close_1d
(
layer
.
w13_input_scale
)
or
not
all_close_1d
(
layer
.
w2_input_scale
)):
print_warning_once
(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. "
)
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
if
is_hip
():
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w13_weight
,
layer
.
w13_weight_scale
,
layer
.
w13_input_scale
)
w2_weight
,
w2_weight_scale
,
w2_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w2_weight
,
layer
.
w2_weight_scale
,
layer
.
w2_input_scale
)
# Reset the parameter
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
w13_weight_scale
,
requires_grad
=
False
)
if
w13_input_scale
is
not
None
:
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
w13_input_scale
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
w2_weight_scale
,
requires_grad
=
False
)
if
w2_input_scale
is
not
None
:
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
w2_input_scale
,
requires_grad
=
False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert
layer
.
w13_weight_scale
is
not
None
shard_size
=
layer
.
intermediate_size_per_partition
max_w13_scales
=
layer
.
w13_weight_scale
.
max
(
dim
=
1
).
values
for
expert_id
in
range
(
layer
.
num_experts
):
start
=
0
for
shard_id
in
range
(
2
):
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
layer
.
w13_weight_scale
[
expert_id
][
shard_id
])
layer
.
w13_weight
[
expert_id
][
start
:
start
+
shard_size
,
:],
_
=
ops
.
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
shard_size
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe
import
fused_experts
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
use_fp8_w8a8
=
True
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
)
class
CompressedTensorsWNA16MoEMethod
(
CompressedTensorsMoEMethod
):
def
__init__
(
def
__init__
(
self
,
self
,
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
quant_config
:
"CompressedTensorsConfig"
# type: ignore # noqa E501
...
...
vllm/model_executor/models/phimoe.py
View file @
873edda6
...
@@ -321,13 +321,13 @@ class PhiMoEAttention(nn.Module):
...
@@ -321,13 +321,13 @@ class PhiMoEAttention(nn.Module):
self
.
total_num_heads
,
self
.
total_num_heads
,
self
.
total_num_kv_heads
,
self
.
total_num_kv_heads
,
bias
=
True
,
bias
=
True
,
quant_config
=
None
,
quant_config
=
quant_config
,
)
)
self
.
o_proj
=
RowParallelLinear
(
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
total_num_heads
*
self
.
head_dim
,
hidden_size
,
hidden_size
,
bias
=
True
,
bias
=
True
,
quant_config
=
None
,
quant_config
=
quant_config
,
)
)
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
self
.
head_dim
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment