Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
889da130
Unverified
Commit
889da130
authored
Jul 25, 2024
by
Robert Shaw
Committed by
GitHub
Jul 25, 2024
Browse files
[ Misc ] `fp8-marlin` channelwise via `compressed-tensors` (#6524)
Co-authored-by:
mgoin
<
michael@neuralmagic.com
>
parent
b75e314f
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
219 additions
and
49 deletions
+219
-49
.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml
...te/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml
+11
-0
.buildkite/lm-eval-harness/configs/models-small.txt
.buildkite/lm-eval-harness/configs/models-small.txt
+1
-0
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
...ers/quantization/compressed_tensors/compressed_tensors.py
+51
-10
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
...ayers/quantization/compressed_tensors/schemes/__init__.py
+2
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
...n/compressed_tensors/schemes/compressed_tensors_scheme.py
+2
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
...pressed_tensors/schemes/compressed_tensors_unquantized.py
+2
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
...compressed_tensors/schemes/compressed_tensors_w4a16_24.py
+2
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
...ompressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
+105
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+6
-4
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+10
-9
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+2
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+22
-11
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+3
-11
No files found.
.buildkite/lm-eval-harness/configs/Qwen2-1.5B-Instruct-FP8W8.yaml
0 → 100644
View file @
889da130
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1
model_name
:
"
nm-testing/Qwen2-1.5B-Instruct-FP8W8"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.578
-
name
:
"
exact_match,flexible-extract"
value
:
0.585
limit
:
1000
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/models-small.txt
View file @
889da130
...
@@ -5,3 +5,4 @@ Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
...
@@ -5,3 +5,4 @@ Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
Qwen2-1.5B-Instruct-FP8W8.yaml
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py
View file @
889da130
...
@@ -10,7 +10,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
...
@@ -10,7 +10,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
W4A16SPARSE24_SUPPORTED_BITS
,
WNA16_SUPPORTED_BITS
,
CompressedTensorsScheme
,
CompressedTensorsUnquantized
,
CompressedTensorsScheme
,
CompressedTensorsUnquantized
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW4A16Sparse24
,
CompressedTensorsW8A8Fp8
,
CompressedTensorsW8A8Int8
,
CompressedTensorsWNA16
)
CompressedTensorsW8A8Int8
,
CompressedTensorsW8A16Fp8
,
CompressedTensorsWNA16
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
CompressionFormat
,
QuantizationArgs
,
QuantizationStrategy
,
QuantizationType
,
find_matched_target
,
is_activation_quantization_format
,
QuantizationType
,
find_matched_target
,
is_activation_quantization_format
,
...
@@ -100,14 +101,18 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -100,14 +101,18 @@ class CompressedTensorsConfig(QuantizationConfig):
def
get_config_filenames
(
cls
)
->
List
[
str
]:
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[]
return
[]
def
_check_scheme_supported
(
self
,
min_capability
:
int
):
def
_check_scheme_supported
(
self
,
min_capability
:
int
,
error
:
bool
=
True
)
->
bool
:
capability
=
current_platform
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
<
min_capability
:
supported
=
capability
>=
min_capability
if
error
and
not
supported
:
raise
RuntimeError
(
raise
RuntimeError
(
"Quantization scheme is not supported for "
,
"Quantization scheme is not supported for "
,
f
"the current GPU. Min capability:
{
min_capability
}
. "
,
f
"the current GPU. Min capability:
{
min_capability
}
. "
,
f
"Current capability:
{
capability
}
."
)
f
"Current capability:
{
capability
}
."
)
return
supported
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
def
_is_static_tensor_w8a8
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant
:
BaseModel
)
->
bool
:
...
@@ -170,6 +175,29 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -170,6 +175,29 @@ class CompressedTensorsConfig(QuantizationConfig):
# All conditions satisfied.
# All conditions satisfied.
return
True
return
True
def
_is_fp8_w8a16
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
# Confirm weights quantized.
if
weight_quant
is
None
:
return
False
# Confirm we have floating points.
if
weight_quant
.
type
!=
QuantizationType
.
FLOAT
:
return
False
# Confirm weight scheme is supported.
is_symmetric_weight
=
weight_quant
.
symmetric
is_static_weight
=
not
weight_quant
.
dynamic
is_per_tensor_or_channel_weight
=
(
weight_quant
.
strategy
in
[
QuantizationStrategy
.
TENSOR
,
QuantizationStrategy
.
CHANNEL
])
if
not
(
is_symmetric_weight
and
is_static_weight
and
is_per_tensor_or_channel_weight
):
return
False
# All conditions satisfied.
return
True
def
_is_wNa16_group_channel
(
self
,
weight_quant
:
BaseModel
,
def
_is_wNa16_group_channel
(
self
,
weight_quant
:
BaseModel
,
input_quant
:
BaseModel
)
->
bool
:
input_quant
:
BaseModel
)
->
bool
:
input_quant_none
=
input_quant
is
None
input_quant_none
=
input_quant
is
None
...
@@ -204,9 +232,23 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -204,9 +232,23 @@ class CompressedTensorsConfig(QuantizationConfig):
# Detect If Activation Quantization.
# Detect If Activation Quantization.
if
is_activation_quantization_format
(
self
.
quant_format
):
if
is_activation_quantization_format
(
self
.
quant_format
):
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
is_fp8_w8a8_supported
=
self
.
_check_scheme_supported
(
CompressedTensorsW8A8Fp8
.
get_min_capability
(),
error
=
False
)
if
is_fp8_w8a8_supported
:
return
CompressedTensorsW8A8Fp8
(
return
CompressedTensorsW8A8Fp8
(
strategy
=
weight_quant
.
strategy
,
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
(
not
input_quant
.
dynamic
))
is_static_input_scheme
=
(
not
input_quant
.
dynamic
))
else
:
return
CompressedTensorsW8A16Fp8
(
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
(
input_quant
and
not
input_quant
.
dynamic
))
if
self
.
_is_fp8_w8a16
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A16Fp8
(
strategy
=
weight_quant
.
strategy
,
is_static_input_scheme
=
(
input_quant
and
not
input_quant
.
dynamic
))
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
if
self
.
_is_static_tensor_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8
(
return
CompressedTensorsW8A8Int8
(
...
@@ -257,11 +299,10 @@ class CompressedTensorsConfig(QuantizationConfig):
...
@@ -257,11 +299,10 @@ class CompressedTensorsConfig(QuantizationConfig):
targets
=
self
.
target_scheme_map
.
keys
())
targets
=
self
.
target_scheme_map
.
keys
())
# Find the quant_scheme
# Find the quant_scheme
scheme
=
self
.
target_scheme_map
[
matched_target
]
scheme_dict
=
self
.
target_scheme_map
[
matched_target
]
scheme
=
self
.
_get_scheme_from_parts
(
return
self
.
_get_scheme_from_parts
(
weight_quant
=
scheme_dict
[
"weights"
],
weight_quant
=
scheme
[
"weights"
],
input_quant
=
scheme_dict
[
"input_activations"
])
input_quant
=
scheme
[
"input_activations"
])
# Raise error if device does not support the scheme
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
# (e.g. fp8 needs ada lovelace)
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py
View file @
889da130
...
@@ -4,6 +4,7 @@ from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
...
@@ -4,6 +4,7 @@ from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24
)
CompressedTensorsW4A16Sparse24
)
from
.compressed_tensors_w8a8_fp8
import
CompressedTensorsW8A8Fp8
from
.compressed_tensors_w8a8_fp8
import
CompressedTensorsW8A8Fp8
from
.compressed_tensors_w8a8_int8
import
CompressedTensorsW8A8Int8
from
.compressed_tensors_w8a8_int8
import
CompressedTensorsW8A8Int8
from
.compressed_tensors_w8a16_fp8
import
CompressedTensorsW8A16Fp8
from
.compressed_tensors_wNa16
import
(
WNA16_SUPPORTED_BITS
,
from
.compressed_tensors_wNa16
import
(
WNA16_SUPPORTED_BITS
,
CompressedTensorsWNA16
)
CompressedTensorsWNA16
)
...
@@ -11,6 +12,7 @@ __all__ = [
...
@@ -11,6 +12,7 @@ __all__ = [
"CompressedTensorsScheme"
,
"CompressedTensorsScheme"
,
"CompressedTensorsUnquantized"
,
"CompressedTensorsUnquantized"
,
"CompressedTensorsWNA16"
,
"CompressedTensorsWNA16"
,
"CompressedTensorsW8A16Fp8"
,
"CompressedTensorsW4A16Sparse24"
,
"CompressedTensorsW4A16Sparse24"
,
"CompressedTensorsW8A8Int8"
,
"CompressedTensorsW8A8Int8"
,
"CompressedTensorsW8A8Fp8"
,
"CompressedTensorsW8A8Fp8"
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py
View file @
889da130
...
@@ -12,8 +12,9 @@ class CompressedTensorsScheme(ABC):
...
@@ -12,8 +12,9 @@ class CompressedTensorsScheme(ABC):
of different quantization schemes supported by CompressedTensors.
of different quantization schemes supported by CompressedTensors.
"""
"""
@
classmethod
@
abstractmethod
@
abstractmethod
def
get_min_capability
(
self
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
"""
"""
Get minimum device capability.
Get minimum device capability.
"""
"""
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py
View file @
889da130
...
@@ -18,7 +18,8 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
...
@@ -18,7 +18,8 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation.
in a linear transformation.
"""
"""
def
get_min_capability
(
self
)
->
int
:
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# volta and up
# volta and up
return
70
return
70
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py
View file @
889da130
...
@@ -29,7 +29,8 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
...
@@ -29,7 +29,8 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
raise
ValueError
(
raise
ValueError
(
"group_size must be given when using strategy group"
)
"group_size must be given when using strategy group"
)
def
get_min_capability
(
self
)
->
int
:
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# ampere + up
# ampere + up
return
80
return
80
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
0 → 100644
View file @
889da130
from
typing
import
Callable
,
List
,
Optional
import
torch
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
CompressedTensorsScheme
)
from
vllm.model_executor.layers.quantization.compressed_tensors.utils
import
(
QuantizationStrategy
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
convert_to_channelwise
,
create_per_channel_scale_param
,
create_per_tensor_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
__all__
=
[
"CompressedTensorsW8A16Fp8"
]
SUPPORTED_STRATEGIES
=
[
QuantizationStrategy
.
CHANNEL
,
QuantizationStrategy
.
TENSOR
]
class
CompressedTensorsW8A16Fp8
(
CompressedTensorsScheme
):
def
__init__
(
self
,
strategy
:
str
,
is_static_input_scheme
:
bool
):
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# ampere and up
return
80
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales,
# we expand each scale to its shard's channels.
def
process_weights_after_loading
(
self
,
layer
)
->
None
:
if
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
ws_channelwise
=
convert_to_channelwise
(
layer
.
weight_scale
,
layer
.
logical_widths
)
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
ws_channelwise
,
requires_grad
=
False
)
# Weights must be transposed for marlin
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
prepare_fp8_layer_for_marlin
(
layer
,
strategy
=
"channel"
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size
:
int
,
output_partition_sizes
:
List
[
int
],
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
output_size_per_partition
=
sum
(
output_partition_sizes
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
orig_dtype
=
params_dtype
# WEIGHT
weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
requires_grad
=
False
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
"weight_loader"
:
weight_loader
,
})
# WEIGHT SCALE
layer_kwargs
=
{
"weight_loader"
:
weight_loader
}
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight_scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
elif
self
.
strategy
==
QuantizationStrategy
.
TENSOR
:
weight_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
else
:
raise
ValueError
(
f
"Unsupported weight strategy=
{
self
.
strategy
}
, "
f
"supported strategies are
{
SUPPORTED_STRATEGIES
}
"
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE (to deal with converted checkpoints)
if
self
.
is_static_input_scheme
:
input_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
layer_kwargs
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
apply_fp8_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
)
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
889da130
...
@@ -23,7 +23,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -23,7 +23,8 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
def
get_min_capability
(
self
)
->
int
:
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# lovelace and up
# lovelace and up
return
89
return
89
...
@@ -77,19 +78,20 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -77,19 +78,20 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
})
})
# WEIGHT SCALE
# WEIGHT SCALE
layer_kwargs
=
{
"weight_loader"
:
weight_loader
}
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
weight_scale
=
create_per_channel_scale_param
(
weight_scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
weight_loader
=
weight_loader
)
output_partition_sizes
,
**
layer_kwargs
)
else
:
else
:
assert
self
.
strategy
==
QuantizationStrategy
.
TENSOR
assert
self
.
strategy
==
QuantizationStrategy
.
TENSOR
weight_scale
=
create_per_tensor_scale_param
(
weight_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
weight_loader
=
weight_loader
)
output_partition_sizes
,
**
layer_kwargs
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE
# INPUT SCALE
if
self
.
is_static_input_scheme
:
if
self
.
is_static_input_scheme
:
input_scale
=
create_per_tensor_scale_param
(
input_scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
weight_loader
=
weight_loader
)
output_partition_sizes
,
**
layer_kwargs
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
def
apply_weights
(
self
,
def
apply_weights
(
self
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
889da130
...
@@ -19,7 +19,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -19,7 +19,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self
.
strategy
=
strategy
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
is_static_input_scheme
=
is_static_input_scheme
def
get_min_capability
(
self
)
->
int
:
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# turing and up
# turing and up
return
75
return
75
...
@@ -68,19 +69,19 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -68,19 +69,19 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
# WEIGHT SCALE
# WEIGHT SCALE
layer_kwargs
=
{
"weight_loader"
:
weight_loader
}
layer_kwargs
=
{
"weight_loader"
:
weight_loader
}
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
if
self
.
strategy
==
QuantizationStrategy
.
CHANNEL
:
scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
weight_
scale
=
create_per_channel_scale_param
(
**
layer_kwargs
)
output_partition_sizes
,
**
layer_kwargs
)
else
:
else
:
assert
self
.
strategy
==
QuantizationStrategy
.
TENSOR
assert
self
.
strategy
==
QuantizationStrategy
.
TENSOR
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
weight_
scale
=
create_per_tensor_scale_param
(
**
layer_kwargs
)
output_partition_sizes
,
**
layer_kwargs
)
layer
.
register_parameter
(
"weight_scale"
,
scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_
scale
)
# INPUT SCALE
# INPUT SCALE
if
self
.
is_static_input_scheme
:
if
self
.
is_static_input_scheme
:
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
input_
scale
=
create_per_tensor_scale_param
(
**
layer_kwargs
)
output_partition_sizes
,
**
layer_kwargs
)
layer
.
register_parameter
(
"input_scale"
,
scale
)
layer
.
register_parameter
(
"input_scale"
,
input_
scale
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
889da130
...
@@ -42,7 +42,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -42,7 +42,8 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
group_size
=
self
.
group_size
,
group_size
=
self
.
group_size
,
is_sym
=
True
)
is_sym
=
True
)
def
get_min_capability
(
self
)
->
int
:
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
# ampere and up
# ampere and up
return
80
return
80
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
889da130
...
@@ -18,8 +18,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
...
@@ -18,8 +18,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
apply_fp8_linear
,
create_per_tensor_scale_param
,
all_close_1d
,
apply_fp8_linear
,
convert_to_channelwise
,
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
...
@@ -179,11 +180,21 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -179,11 +180,21 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
input_scale
=
None
layer
.
input_scale
=
None
# If checkpoint is fp8,
requantize the separately quantized logical
# If checkpoint is fp8,
handle that there are N scales for N
#
weight
s in
to
a
single fp8 weight with a single weight sca
le
.
#
shard
s in a
fused modu
le
else
:
else
:
# Dequant -> Quant with max scale.
# If using marlin (w8a16), kernel uses channelwise weights,
max_w_scale
,
weight
=
requantize_with_max_scale
(
# so extend the weight scales to be channelwise.
if
self
.
use_marlin
:
weight
=
layer
.
weight
weight_scale
=
convert_to_channelwise
(
layer
.
weight_scale
,
layer
.
logical_widths
)
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
else
:
# Dequant -> Quant with max scale so we can run per tensor.
weight_scale
,
weight
=
requantize_with_max_scale
(
weight
=
layer
.
weight
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight_scale
=
layer
.
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
logical_widths
=
layer
.
logical_widths
,
...
@@ -191,7 +202,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -191,7 +202,7 @@ class Fp8LinearMethod(LinearMethodBase):
# Update layer with new values.
# Update layer with new values.
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w
_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight
_scale
,
requires_grad
=
False
)
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
889da130
...
@@ -46,7 +46,8 @@ def apply_fp8_marlin_linear(
...
@@ -46,7 +46,8 @@ def apply_fp8_marlin_linear(
return
output
.
reshape
(
out_shape
)
return
output
.
reshape
(
out_shape
)
def
prepare_fp8_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
)
->
None
:
def
prepare_fp8_layer_for_marlin
(
layer
:
torch
.
nn
.
Module
,
strategy
:
str
=
"tensor"
)
->
None
:
print_warning_once
(
print_warning_once
(
"Your GPU does not have native support for FP8 computation but "
"Your GPU does not have native support for FP8 computation but "
"FP8 quantization is being used. Weight-only FP8 compression will "
"FP8 quantization is being used. Weight-only FP8 compression will "
...
@@ -74,16 +75,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
...
@@ -74,16 +75,7 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
layer
.
weight
=
torch
.
nn
.
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
layer
.
weight
=
torch
.
nn
.
Parameter
(
marlin_qweight
,
requires_grad
=
False
)
# WEIGHT SCALES
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
scales
=
layer
.
weight_scale
.
to
(
layer
.
orig_dtype
)
# expand it to channelwise
is_channelwise
=
(
len
(
layer
.
weight_scale
.
shape
)
>
0
and
layer
.
weight_scale
.
shape
[
0
]
==
part_size_n
)
if
is_channelwise
:
scales
=
layer
.
weight_scale
else
:
scales
=
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
)
scales
=
scales
.
to
(
layer
.
orig_dtype
).
to
(
device
)
# Permute scales
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
part_size_k
,
size_k
=
part_size_k
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment