Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0eb0757b
Unverified
Commit
0eb0757b
authored
Jul 23, 2024
by
Michael Goin
Committed by
GitHub
Jul 23, 2024
Browse files
[Misc] Add ignored layers for `fp8` quantization (#6657)
parent
38c4b7e8
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
57 additions
and
47 deletions
+57
-47
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
..._executor/layers/quantization/compressed_tensors/utils.py
+5
-9
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+3
-36
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+11
-2
vllm/model_executor/layers/quantization/utils/quant_utils.py
vllm/model_executor/layers/quantization/utils/quant_utils.py
+38
-0
No files found.
vllm/model_executor/layers/quantization/compressed_tensors/utils.py
View file @
0eb0757b
...
@@ -5,6 +5,9 @@ from typing import Any, Dict, Iterable, Optional
...
@@ -5,6 +5,9 @@ from typing import Any, Dict, Iterable, Optional
from
pydantic
import
BaseModel
,
Field
from
pydantic
import
BaseModel
,
Field
from
torch.nn
import
Module
from
torch.nn
import
Module
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
FUSED_LAYER_NAME_MAPPING
)
class
CompressionFormat
(
Enum
):
class
CompressionFormat
(
Enum
):
dense
=
"dense"
dense
=
"dense"
...
@@ -86,13 +89,6 @@ def is_activation_quantization_format(format: str) -> bool:
...
@@ -86,13 +89,6 @@ def is_activation_quantization_format(format: str) -> bool:
return
format
in
_ACTIVATION_QUANTIZATION_FORMATS
return
format
in
_ACTIVATION_QUANTIZATION_FORMATS
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
def
should_ignore_layer
(
layer_name
:
Optional
[
str
],
def
should_ignore_layer
(
layer_name
:
Optional
[
str
],
ignore
:
Iterable
[
str
])
->
bool
:
ignore
:
Iterable
[
str
])
->
bool
:
if
layer_name
is
None
:
if
layer_name
is
None
:
...
@@ -106,8 +102,8 @@ def should_ignore_layer(layer_name: Optional[str],
...
@@ -106,8 +102,8 @@ def should_ignore_layer(layer_name: Optional[str],
# in the safetensors checkpoint. So, we convert the name
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
# each shard of the fused layer has the same scheme.
if
proj_name
in
_
FUSED_LAYER_NAME_MAPPING
:
if
proj_name
in
FUSED_LAYER_NAME_MAPPING
:
shard_proj_names
=
_
FUSED_LAYER_NAME_MAPPING
[
proj_name
]
shard_proj_names
=
FUSED_LAYER_NAME_MAPPING
[
proj_name
]
# Convert fused_name --> [shard_names]
# Convert fused_name --> [shard_names]
shard_names
=
[
shard_names
=
[
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
0eb0757b
...
@@ -11,6 +11,8 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -11,6 +11,8 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_channel_scale_param
)
apply_fp8_linear
,
create_per_channel_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -18,14 +20,6 @@ from vllm.platforms import current_platform
...
@@ -18,14 +20,6 @@ from vllm.platforms import current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
class
FBGEMMFp8Config
(
QuantizationConfig
):
class
FBGEMMFp8Config
(
QuantizationConfig
):
"""Config class for FBGEMM Fp8."""
"""Config class for FBGEMM Fp8."""
...
@@ -62,37 +56,10 @@ class FBGEMMFp8Config(QuantizationConfig):
...
@@ -62,37 +56,10 @@ class FBGEMMFp8Config(QuantizationConfig):
input_scale_ub
=
cls
.
get_from_keys
(
config
,
[
"activation_scale_ub"
])
input_scale_ub
=
cls
.
get_from_keys
(
config
,
[
"activation_scale_ub"
])
return
cls
(
ignore_list
=
ignore_list
,
input_scale_ub
=
input_scale_ub
)
return
cls
(
ignore_list
=
ignore_list
,
input_scale_ub
=
input_scale_ub
)
def
_is_layer_skipped
(
self
,
prefix
:
str
)
->
bool
:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name
=
prefix
.
split
(
"."
)[
-
1
]
if
proj_name
in
_FUSED_LAYER_NAME_MAPPING
:
shard_prefixes
=
[
prefix
.
replace
(
proj_name
,
shard_proj_name
)
for
shard_proj_name
in
_FUSED_LAYER_NAME_MAPPING
[
proj_name
]
]
is_skipped
=
None
for
shard_prefix
in
shard_prefixes
:
is_shard_skipped
=
shard_prefix
in
self
.
ignore_list
if
is_skipped
is
None
:
is_skipped
=
is_shard_skipped
elif
is_shard_skipped
!=
is_skipped
:
raise
ValueError
(
f
"Detected some but not all shards of
{
prefix
}
"
"are quantized. All shards of fused layers "
"to have the same precision."
)
else
:
is_skipped
=
prefix
in
self
.
ignore_list
assert
is_skipped
is
not
None
return
is_skipped
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
if
self
.
_
is_layer_skipped
(
prefix
):
if
is_layer_skipped
(
prefix
,
self
.
ignore_list
):
return
UnquantizedLinearMethod
()
return
UnquantizedLinearMethod
()
return
FBGEMMFp8LinearMethod
(
self
)
return
FBGEMMFp8LinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
0eb0757b
...
@@ -8,12 +8,15 @@ from vllm import _custom_ops as ops
...
@@ -8,12 +8,15 @@ from vllm import _custom_ops as ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
from
vllm.model_executor.layers.fused_moe
import
(
FusedMoE
,
FusedMoEMethodBase
,
fused_moe
)
fused_moe
)
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
apply_fp8_linear
,
create_per_tensor_scale_param
,
all_close_1d
,
apply_fp8_linear
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
...
@@ -33,6 +36,7 @@ class Fp8Config(QuantizationConfig):
...
@@ -33,6 +36,7 @@ class Fp8Config(QuantizationConfig):
self
,
self
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
activation_scheme
:
str
=
"dynamic"
,
activation_scheme
:
str
=
"dynamic"
,
ignored_layers
:
Optional
[
List
[
str
]]
=
None
,
)
->
None
:
)
->
None
:
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
if
is_checkpoint_fp8_serialized
:
if
is_checkpoint_fp8_serialized
:
...
@@ -42,6 +46,7 @@ class Fp8Config(QuantizationConfig):
...
@@ -42,6 +46,7 @@ class Fp8Config(QuantizationConfig):
raise
ValueError
(
raise
ValueError
(
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
f
"Unsupported activation scheme
{
activation_scheme
}
"
)
self
.
activation_scheme
=
activation_scheme
self
.
activation_scheme
=
activation_scheme
self
.
ignored_layers
=
ignored_layers
or
[]
@
classmethod
@
classmethod
def
get_name
(
cls
)
->
str
:
def
get_name
(
cls
)
->
str
:
...
@@ -64,14 +69,18 @@ class Fp8Config(QuantizationConfig):
...
@@ -64,14 +69,18 @@ class Fp8Config(QuantizationConfig):
quant_method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
])
quant_method
=
cls
.
get_from_keys
(
config
,
[
"quant_method"
])
is_checkpoint_fp8_serialized
=
(
"fp8"
in
quant_method
)
is_checkpoint_fp8_serialized
=
(
"fp8"
in
quant_method
)
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
activation_scheme
=
cls
.
get_from_keys
(
config
,
[
"activation_scheme"
])
ignored_layers
=
cls
.
get_from_keys_or
(
config
,
[
"ignored_layers"
],
None
)
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
return
cls
(
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
,
activation_scheme
=
activation_scheme
)
activation_scheme
=
activation_scheme
,
ignored_layers
=
ignored_layers
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
if
is_layer_skipped
(
prefix
,
self
.
ignored_layers
):
return
UnquantizedLinearMethod
()
return
Fp8LinearMethod
(
self
)
return
Fp8LinearMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
elif
isinstance
(
layer
,
FusedMoE
):
return
Fp8MoEMethod
(
self
)
return
Fp8MoEMethod
(
self
)
...
...
vllm/model_executor/layers/quantization/utils/quant_utils.py
View file @
0eb0757b
"""This file is used for /tests and /benchmarks"""
"""This file is used for /tests and /benchmarks"""
from
typing
import
List
import
numpy
import
numpy
import
torch
import
torch
SUPPORTED_NUM_BITS
=
[
4
,
8
]
SUPPORTED_NUM_BITS
=
[
4
,
8
]
SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
SUPPORTED_GROUP_SIZES
=
[
-
1
,
32
,
64
,
128
]
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
FUSED_LAYER_NAME_MAPPING
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
]
}
def
is_layer_skipped
(
prefix
:
str
,
ignored_layers
:
List
[
str
])
->
bool
:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name
=
prefix
.
split
(
"."
)[
-
1
]
if
proj_name
in
FUSED_LAYER_NAME_MAPPING
:
shard_prefixes
=
[
prefix
.
replace
(
proj_name
,
shard_proj_name
)
for
shard_proj_name
in
FUSED_LAYER_NAME_MAPPING
[
proj_name
]
]
is_skipped
=
None
for
shard_prefix
in
shard_prefixes
:
is_shard_skipped
=
shard_prefix
in
ignored_layers
if
is_skipped
is
None
:
is_skipped
=
is_shard_skipped
elif
is_shard_skipped
!=
is_skipped
:
raise
ValueError
(
f
"Detected some but not all shards of
{
prefix
}
"
"are quantized. All shards of fused layers "
"to have the same precision."
)
else
:
is_skipped
=
prefix
in
ignored_layers
assert
is_skipped
is
not
None
return
is_skipped
def
get_pack_factor
(
num_bits
):
def
get_pack_factor
(
num_bits
):
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
assert
num_bits
in
SUPPORTED_NUM_BITS
,
f
"Unsupported num_bits =
{
num_bits
}
"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment