Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5223199e
Unverified
Commit
5223199e
authored
Aug 07, 2024
by
Michael Goin
Committed by
GitHub
Aug 07, 2024
Browse files
[Bugfix][FP8] Fix dynamic FP8 Marlin quantization (#7219)
parent
fde47d3b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
5 deletions
+33
-5
tests/quantization/test_fp8.py
tests/quantization/test_fp8.py
+15
-4
vllm/envs.py
vllm/envs.py
+8
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+10
-1
No files found.
tests/quantization/test_fp8.py
View file @
5223199e
...
@@ -9,6 +9,7 @@ from tests.quantization.utils import is_quant_method_supported
...
@@ -9,6 +9,7 @@ from tests.quantization.utils import is_quant_method_supported
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8KVCacheMethod
,
from
vllm.model_executor.layers.quantization.fp8
import
(
Fp8KVCacheMethod
,
Fp8LinearMethod
)
Fp8LinearMethod
)
from
vllm.platforms
import
current_platform
MODELS
=
[
MODELS
=
[
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV"
,
"neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV"
,
...
@@ -20,7 +21,12 @@ MODELS = [
...
@@ -20,7 +21,12 @@ MODELS = [
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"model_id"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"model_id"
,
MODELS
)
def
test_model_load_and_run
(
vllm_runner
,
model_id
:
str
):
@
pytest
.
mark
.
parametrize
(
"force_marlin"
,
[
False
,
True
])
def
test_model_load_and_run
(
vllm_runner
,
model_id
:
str
,
force_marlin
:
bool
,
monkeypatch
)
->
None
:
if
force_marlin
:
monkeypatch
.
setenv
(
"VLLM_TEST_FORCE_FP8_MARLIN"
,
"1"
)
with
vllm_runner
(
model_id
)
as
llm
:
with
vllm_runner
(
model_id
)
as
llm
:
# note: this does not test accuracy, just that we can run through
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
# see lm-eval tests for accuracy
...
@@ -61,7 +67,12 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
...
@@ -61,7 +67,12 @@ def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
@
pytest
.
mark
.
skipif
(
not
is_quant_method_supported
(
"fp8"
),
reason
=
"FP8 is not supported on this GPU type."
)
reason
=
"FP8 is not supported on this GPU type."
)
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
,
"fp8"
])
@
pytest
.
mark
.
parametrize
(
"kv_cache_dtype"
,
[
"auto"
,
"fp8"
])
def
test_load_fp16_model
(
vllm_runner
,
kv_cache_dtype
:
str
)
->
None
:
@
pytest
.
mark
.
parametrize
(
"force_marlin"
,
[
False
,
True
])
def
test_load_fp16_model
(
vllm_runner
,
kv_cache_dtype
:
str
,
force_marlin
:
bool
,
monkeypatch
)
->
None
:
if
force_marlin
:
monkeypatch
.
setenv
(
"VLLM_TEST_FORCE_FP8_MARLIN"
,
"1"
)
with
vllm_runner
(
"facebook/opt-125m"
,
with
vllm_runner
(
"facebook/opt-125m"
,
quantization
=
"fp8"
,
quantization
=
"fp8"
,
kv_cache_dtype
=
kv_cache_dtype
)
as
llm
:
kv_cache_dtype
=
kv_cache_dtype
)
as
llm
:
...
@@ -75,9 +86,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None:
...
@@ -75,9 +86,9 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str) -> None:
assert
attn
.
_k_scale
==
1.0
assert
attn
.
_k_scale
==
1.0
assert
attn
.
_v_scale
==
1.0
assert
attn
.
_v_scale
==
1.0
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
if
capability
>=
89
:
if
capability
>=
89
and
not
force_marlin
:
# For GPUs with hardware support, we keep weights in fp8
# For GPUs with hardware support, we keep weights in fp8
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
assert
fc1
.
weight
.
dtype
==
torch
.
float8_e4m3fn
else
:
else
:
...
...
vllm/envs.py
View file @
5223199e
...
@@ -52,6 +52,7 @@ if TYPE_CHECKING:
...
@@ -52,6 +52,7 @@ if TYPE_CHECKING:
CMAKE_BUILD_TYPE
:
Optional
[
str
]
=
None
CMAKE_BUILD_TYPE
:
Optional
[
str
]
=
None
VERBOSE
:
bool
=
False
VERBOSE
:
bool
=
False
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
VLLM_ALLOW_LONG_MAX_MODEL_LEN
:
bool
=
False
VLLM_TEST_FORCE_FP8_MARLIN
:
bool
=
False
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -342,6 +343,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -342,6 +343,13 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_ALLOW_LONG_MAX_MODEL_LEN"
,
"0"
).
strip
().
lower
()
in
(
os
.
environ
.
get
(
"VLLM_ALLOW_LONG_MAX_MODEL_LEN"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
(
"1"
,
"true"
)),
# If set, forces FP8 Marlin to be used for FP8 quantization regardless
# of the hardware support for FP8 compute.
"VLLM_TEST_FORCE_FP8_MARLIN"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_TEST_FORCE_FP8_MARLIN"
,
"0"
).
strip
().
lower
()
in
(
"1"
,
"true"
)),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
5223199e
...
@@ -4,6 +4,7 @@ import torch
...
@@ -4,6 +4,7 @@ import torch
from
torch.nn
import
Module
from
torch.nn
import
Module
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
import
vllm.envs
as
envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
from
vllm.model_executor.layers.fused_moe
import
FusedMoE
,
FusedMoEMethodBase
...
@@ -118,7 +119,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -118,7 +119,7 @@ class Fp8LinearMethod(LinearMethodBase):
# kernel for fast weight-only FP8 quantization
# kernel for fast weight-only FP8 quantization
capability
=
current_platform
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
self
.
use_marlin
=
capability
<
89
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -174,6 +175,14 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -174,6 +175,14 @@ class Fp8LinearMethod(LinearMethodBase):
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
scale
=
None
)
# If using marlin (w8a16), kernel uses channelwise weights,
# so extend the weight scales to be channelwise.
if
self
.
use_marlin
:
assert
weight_scale
.
numel
()
==
1
weight_scale
=
convert_to_channelwise
(
weight_scale
.
expand
(
len
(
layer
.
logical_widths
)),
layer
.
logical_widths
)
# Update the layer with the new values.
# Update the layer with the new values.
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
qweight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
weight_scale
,
requires_grad
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment