Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9364f74e
Unverified
Commit
9364f74e
authored
Jul 20, 2024
by
Robert Shaw
Committed by
GitHub
Jul 20, 2024
Browse files
[ Kernel ] Enable `fp8-marlin` for `fbgemm-fp8` models (#6606)
parent
06d6c5fe
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
44 additions
and
3 deletions
+44
-3
.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
.../configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
+11
-0
.buildkite/lm-eval-harness/configs/models-large.txt
.buildkite/lm-eval-harness/configs/models-large.txt
+1
-0
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+25
-1
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
...el_executor/layers/quantization/utils/marlin_utils_fp8.py
+7
-2
No files found.
.buildkite/lm-eval-harness/configs/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
0 → 100644
View file @
9364f74e
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
model_name
:
"
nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
tasks
:
-
name
:
"
gsm8k"
metrics
:
-
name
:
"
exact_match,strict-match"
value
:
0.905
-
name
:
"
exact_match,flexible-extract"
value
:
0.905
limit
:
1000
num_fewshot
:
5
.buildkite/lm-eval-harness/configs/models-large.txt
View file @
9364f74e
Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
Meta-Llama-3-70B-Instruct.yaml
Mixtral-8x7B-Instruct-v0.1.yaml
Qwen2-57B-A14-Instruct.yaml
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
9364f74e
...
...
@@ -9,9 +9,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
,
create_per_channel_scale_param
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
logger
=
init_logger
(
__name__
)
...
...
@@ -31,6 +34,12 @@ class FBGEMMFp8Config(QuantizationConfig):
self
.
ignore_list
=
ignore_list
self
.
input_scale_ub
=
input_scale_ub
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"fbgemm_fp8"
...
...
@@ -41,7 +50,7 @@ class FBGEMMFp8Config(QuantizationConfig):
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
8
9
return
8
0
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
...
...
@@ -143,11 +152,26 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
weight
=
layer
.
weight
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
if
self
.
quant_config
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
)
# Activations not quantized for marlin.
del
layer
.
input_scale_ub
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
quant_config
.
use_marlin
:
return
apply_fp8_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
)
return
apply_fp8_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
...
...
vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py
View file @
9364f74e
...
...
@@ -76,8 +76,13 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales
=
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
).
to
(
layer
.
orig_dtype
).
to
(
device
)
is_channelwise
=
layer
.
weight_scale
.
shape
[
0
]
==
part_size_n
if
is_channelwise
:
scales
=
layer
.
weight_scale
else
:
scales
=
layer
.
weight_scale
.
repeat
(
1
,
part_size_n
)
scales
=
scales
.
to
(
layer
.
orig_dtype
).
to
(
device
)
# Permute scales
marlin_scales
=
marlin_permute_scales
(
s
=
scales
,
size_k
=
part_size_k
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment