Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
955b5191
Unverified
Commit
955b5191
authored
Aug 22, 2024
by
Dipika Sikka
Committed by
GitHub
Aug 22, 2024
Browse files
[Misc] update fp8 to use `vLLMParameter` (#7437)
parent
55d63b12
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
17 deletions
+51
-17
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+1
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+13
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+29
-15
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+8
-1
No files found.
tests/weight_loading/models.txt
View file @
955b5191
...
@@ -15,3 +15,4 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
...
@@ -15,3 +15,4 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
awq, casperhansen/mixtral-instruct-awq, main
awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
\ No newline at end of file
vllm/model_executor/layers/linear.py
View file @
955b5191
...
@@ -22,7 +22,7 @@ logger = init_logger(__name__)
...
@@ -22,7 +22,7 @@ logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED
=
[
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
]
]
...
@@ -349,6 +349,11 @@ class ColumnParallelLinear(LinearBase):
...
@@ -349,6 +349,11 @@ class ColumnParallelLinear(LinearBase):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
def
weight_loader_v2
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader_v2
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if
len
(
loaded_weight
.
shape
)
==
0
:
assert
loaded_weight
.
numel
()
==
1
loaded_weight
=
loaded_weight
.
reshape
(
1
)
param
.
load_column_parallel_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_column_parallel_weight
(
loaded_weight
=
loaded_weight
)
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
...
@@ -1021,6 +1026,13 @@ class RowParallelLinear(LinearBase):
...
@@ -1021,6 +1026,13 @@ class RowParallelLinear(LinearBase):
def
weight_loader_v2
(
self
,
param
:
BasevLLMParameter
,
def
weight_loader_v2
(
self
,
param
:
BasevLLMParameter
,
loaded_weight
:
torch
.
Tensor
):
loaded_weight
:
torch
.
Tensor
):
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
if
len
(
loaded_weight
.
shape
)
==
0
:
assert
loaded_weight
.
numel
()
==
1
loaded_weight
=
loaded_weight
.
reshape
(
1
)
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
)
param
.
load_row_parallel_weight
(
loaded_weight
=
loaded_weight
)
def
forward
(
self
,
input_
):
def
forward
(
self
,
input_
):
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
955b5191
...
@@ -19,9 +19,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -19,9 +19,10 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped
)
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
apply_fp8_linear
,
convert_to_channelwise
,
all_close_1d
,
apply_fp8_linear
,
convert_to_channelwise
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
cutlass_fp8_supported
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
requantize_with_max_scale
)
requantize_with_max_scale
)
from
vllm.model_executor.parameter
import
(
ModelWeightParameter
,
PerTensorScaleParameter
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
,
print_warning_once
from
vllm.utils
import
is_hip
,
print_warning_once
...
@@ -137,6 +138,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -137,6 +138,7 @@ class Fp8LinearMethod(LinearMethodBase):
):
):
del
input_size
,
output_size
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
layer
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
...
@@ -148,34 +150,41 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -148,34 +150,41 @@ class Fp8LinearMethod(LinearMethodBase):
weight_dtype
=
(
torch
.
float8_e4m3fn
weight_dtype
=
(
torch
.
float8_e4m3fn
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
params_dtype
)
params_dtype
)
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
dtype
=
weight_dtype
),
output_size_per_partition
,
requires_grad
=
False
)
input_size_per_partition
,
dtype
=
weight_dtype
),
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
**
extra_weight_attrs
,
"input_dim"
:
1
,
"output_dim"
:
0
,
})
# If checkpoint is serialized fp8, load them.
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
# Otherwise, wait until process_weights_after_loading.
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALE
# WEIGHT SCALE
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
**
extra_weight_attrs
)
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
scale
)
layer
.
register_parameter
(
"weight_scale"
,
scale
)
# INPUT ACTIVATION SCALE
# INPUT ACTIVATION SCALE
if
self
.
quant_config
.
activation_scheme
==
"static"
:
if
self
.
quant_config
.
activation_scheme
==
"static"
:
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
**
extra_weight_attrs
)
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"input_scale"
,
scale
)
layer
.
register_parameter
(
"input_scale"
,
scale
)
else
:
else
:
layer
.
register_parameter
(
"input_scale"
,
None
)
layer
.
register_parameter
(
"input_scale"
,
None
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
layer
.
weight
=
torch
.
nn
.
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
# If checkpoint not serialized fp8, quantize the weights.
# If checkpoint not serialized fp8, quantize the weights.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
...
@@ -197,6 +206,11 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -197,6 +206,11 @@ class Fp8LinearMethod(LinearMethodBase):
# If checkpoint is fp8, handle that there are N scales for N
# If checkpoint is fp8, handle that there are N scales for N
# shards in a fused module
# shards in a fused module
else
:
else
:
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
if
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
input_scale
=
torch
.
nn
.
Parameter
(
layer
.
input_scale
.
data
,
requires_grad
=
False
)
# If using marlin (w8a16), kernel uses channelwise weights,
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
# so extend the weight scales to be channelwise.
if
self
.
use_marlin
:
if
self
.
use_marlin
:
...
...
vllm/model_executor/parameter.py
View file @
955b5191
...
@@ -208,10 +208,17 @@ class PerTensorScaleParameter(BasevLLMParameter):
...
@@ -208,10 +208,17 @@ class PerTensorScaleParameter(BasevLLMParameter):
if
isinstance
(
shard_id
,
int
):
if
isinstance
(
shard_id
,
int
):
return
shard_id
return
shard_id
# if not int, assume shard_id for qkv
# map to int and return
assert
isinstance
(
shard_id
,
str
)
assert
isinstance
(
shard_id
,
str
)
assert
shard_id
in
self
.
qkv_idxs
assert
shard_id
in
self
.
qkv_idxs
return
self
.
qkv_idxs
[
shard_id
]
return
self
.
qkv_idxs
[
shard_id
]
# For row parallel layers, no sharding needed
# load weight into parameter as is
def
load_row_parallel_weight
(
self
,
*
args
,
**
kwargs
):
super
().
load_row_parallel_weight
(
*
args
,
**
kwargs
)
def
load_merged_column_weight
(
self
,
*
args
,
**
kwargs
):
def
load_merged_column_weight
(
self
,
*
args
,
**
kwargs
):
self
.
_load_into_shard_id
(
*
args
,
**
kwargs
)
self
.
_load_into_shard_id
(
*
args
,
**
kwargs
)
...
@@ -219,7 +226,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
...
@@ -219,7 +226,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
self
.
_load_into_shard_id
(
*
args
,
**
kwargs
)
self
.
_load_into_shard_id
(
*
args
,
**
kwargs
)
def
load_column_parallel_weight
(
self
,
*
args
,
**
kwargs
):
def
load_column_parallel_weight
(
self
,
*
args
,
**
kwargs
):
s
elf
.
_load_into_shard_id
(
*
args
,
**
kwargs
)
s
uper
().
load_row_parallel_weight
(
*
args
,
**
kwargs
)
def
_load_into_shard_id
(
self
,
loaded_weight
:
torch
.
Tensor
,
def
_load_into_shard_id
(
self
,
loaded_weight
:
torch
.
Tensor
,
shard_id
:
Union
[
str
,
int
],
**
kwargs
):
shard_id
:
Union
[
str
,
int
],
**
kwargs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment