Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e16fa99a
Unverified
Commit
e16fa99a
authored
Sep 03, 2024
by
Dipika Sikka
Committed by
GitHub
Sep 03, 2024
Browse files
[Misc] Update fbgemmfp8 to use `vLLMParameters` (#7972)
Co-authored-by:
Michael Goin
<
michael@neuralmagic.com
>
parent
61f4a93d
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
41 deletions
+22
-41
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+1
-1
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+21
-13
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+0
-27
No files found.
vllm/model_executor/layers/linear.py
View file @
e16fa99a
...
@@ -26,7 +26,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
...
@@ -26,7 +26,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"CompressedTensorsLinearMethod"
,
"AWQMarlinLinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"AWQLinearMethod"
,
"GPTQMarlinLinearMethod"
,
"Fp8LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"MarlinLinearMethod"
,
"QQQLinearMethod"
,
"GPTQMarlin24LinearMethod"
,
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
"TPUInt8LinearMethod"
,
"GPTQLinearMethod"
,
"FBGEMMFp8LinearMethod"
]
]
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
e16fa99a
...
@@ -15,8 +15,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
...
@@ -15,8 +15,9 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
is_layer_skipped
)
is_layer_skipped
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_channel_scale_param
)
apply_fp8_linear
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
ModelWeightParameter
)
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -85,6 +86,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -85,6 +86,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
**
extra_weight_attrs
,
):
):
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
del
input_size
,
output_size
del
input_size
,
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
...
@@ -95,20 +97,21 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -95,20 +97,21 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
layer
.
orig_dtype
=
params_dtype
layer
.
orig_dtype
=
params_dtype
# WEIGHT
# WEIGHT
weight
=
Parameter
(
torch
.
empty
(
output_size_per_partition
,
weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
output_size_per_partition
,
input_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
float8_e4m3fn
),
dtype
=
torch
.
float8_e4m3fn
),
requires_grad
=
False
)
input_dim
=
1
,
output_dim
=
0
,
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight"
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
{
"input_dim"
:
1
,
"output_dim"
:
0
,
**
extra_weight_attrs
,
})
# WEIGHT SCALE
# WEIGHT SCALE
weight_scale
=
create_per_channel_scale_param
(
output_partition_sizes
,
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
(
**
extra_weight_attrs
)
(
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
weight_scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
# INPUT SCALE UPPER BOUND
# INPUT SCALE UPPER BOUND
...
@@ -118,6 +121,11 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -118,6 +121,11 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
layer
.
input_scale_ub
=
input_scale_ub
layer
.
input_scale_ub
=
input_scale_ub
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# required by torch.compile
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
layer
.
weight
.
data
,
requires_grad
=
False
)
weight
=
layer
.
weight
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
e16fa99a
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
...
@@ -38,31 +36,6 @@ def all_close_1d(x: torch.Tensor) -> bool:
...
@@ -38,31 +36,6 @@ def all_close_1d(x: torch.Tensor) -> bool:
return
all
(
torch
.
allclose
(
x
[
0
],
x
[
i
])
for
i
in
range
(
x
.
shape
[
0
]))
return
all
(
torch
.
allclose
(
x
[
0
],
x
[
i
])
for
i
in
range
(
x
.
shape
[
0
]))
def
create_per_tensor_scale_param
(
output_partition_sizes
:
List
[
int
],
**
extra_weight_attrs
,
)
->
Parameter
:
scale
=
Parameter
(
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
set_weight_attrs
(
scale
,
{
"needs_scalar_to_array"
:
True
,
**
extra_weight_attrs
})
return
scale
def
create_per_channel_scale_param
(
output_partition_sizes
:
List
[
int
],
**
extra_weight_attrs
)
->
Parameter
:
scale
=
Parameter
(
torch
.
empty
((
sum
(
output_partition_sizes
),
1
),
dtype
=
torch
.
float32
),
requires_grad
=
False
)
scale
[:]
=
torch
.
finfo
(
torch
.
float32
).
min
set_weight_attrs
(
scale
,
{
"output_dim"
:
0
,
**
extra_weight_attrs
})
return
scale
def
convert_to_channelwise
(
def
convert_to_channelwise
(
weight_scale
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
logical_widths
:
List
[
int
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
logical_widths
:
List
[
int
])
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment