Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
9e90c9f7
Unverified
Commit
9e90c9f7
authored
Apr 11, 2025
by
chaow-amd
Committed by
GitHub
Apr 11, 2025
Browse files
[Bugfix] Fix bugs of running Quark quantized models (#16236)
Signed-off-by: chaow <chaow@amd.com>
parent
e9528f6d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
22 deletions
+67
-22
tests/quantization/test_quark.py
tests/quantization/test_quark.py
+37
-8
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
...cutor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+6
-4
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
...utor/layers/quantization/quark/schemes/quark_w8a8_int8.py
+24
-10
No files found.
tests/quantization/test_quark.py
View file @
9e90c9f7
...
...
@@ -4,17 +4,28 @@
Run `pytest tests/quantization/test_quark.py`.
"""
import
torch
import
pytest
from
vllm.model_executor.layers.quantization.quark.quark
import
(
# noqa: E501
QuarkLinearMethod
,
QuarkW8A8Fp8
)
QuarkLinearMethod
,
QuarkW8A8Fp8
,
QuarkW8A8Int8
)
from
vllm.platforms
import
current_platform
def
test_quark_fp8
(
vllm_runner
,
monkeypatch
):
# vllm_runner.apply_model() relies on V0 internals.
monkeypatch
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
@
pytest
.
fixture
(
scope
=
"function"
,
autouse
=
True
)
def
use_v0_only
(
monkeypatch
):
"""
This module relies on V0 internals, so set VLLM_USE_V1=0.
"""
monkeypatch
.
setenv
(
'VLLM_USE_V1'
,
'0'
)
@
pytest
.
mark
.
parametrize
(
'kv_cache_dtype'
,
[
'auto'
,
'fp8'
])
@
pytest
.
mark
.
parametrize
(
'tp'
,
[
1
])
def
test_quark_fp8_w_per_tensor_a_per_tensor
(
vllm_runner
,
kv_cache_dtype
,
tp
):
model_path
=
"amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
with
vllm_runner
(
model_path
)
as
llm
:
with
vllm_runner
(
model_path
,
kv_cache_dtype
=
kv_cache_dtype
,
tensor_parallel_size
=
tp
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
...
...
@@ -26,11 +37,29 @@ def test_quark_fp8(vllm_runner, monkeypatch):
if
isinstance
(
qkv_proj
.
scheme
,
QuarkW8A8Fp8
):
assert
len
(
qkv_proj
.
input_scale
.
shape
)
==
0
assert
qkv_proj
.
weight
.
dtype
is
torch
.
float8_e4m3fn
#assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
assert
qkv_proj
.
weight
.
dtype
is
current_platform
.
fp8_dtype
()
assert
len
(
qkv_proj
.
weight_scale
.
shape
)
==
0
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
@
pytest
.
mark
.
parametrize
(
'tp'
,
[
1
])
def
test_quark_int8_w_per_tensor_a_per_tensor
(
vllm_runner
,
tp
):
model_path
=
"amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
with
vllm_runner
(
model_path
,
tensor_parallel_size
=
tp
)
as
llm
:
def
check_model
(
model
):
layer
=
model
.
model
.
layers
[
0
]
qkv_proj
=
layer
.
self_attn
.
qkv_proj
assert
isinstance
(
qkv_proj
.
quant_method
,
QuarkLinearMethod
)
assert
isinstance
(
qkv_proj
.
scheme
,
QuarkW8A8Int8
)
llm
.
apply_model
(
check_model
)
output
=
llm
.
generate_greedy
(
"Hello my name is"
,
max_tokens
=
20
)
assert
output
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
View file @
9e90c9f7
...
...
@@ -21,7 +21,7 @@ class QuarkW8A8Fp8(QuarkScheme):
def
__init__
(
self
,
qscheme
:
str
,
is_static_input_scheme
:
Optional
[
bool
]):
self
.
qscheme
=
qscheme
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
fp8_linear
=
Fp8LinearOp
(
use_per_token_if_dynamic
=
True
)
self
.
fp8_linear
=
Fp8LinearOp
(
use_per_token_if_dynamic
=
False
)
self
.
out_dtype
=
torch
.
get_default_dtype
()
@
classmethod
...
...
@@ -41,10 +41,11 @@ class QuarkW8A8Fp8(QuarkScheme):
)
if
current_platform
.
is_fp8_fnuz
():
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
weight
,
max_w_scale
,
input_scale
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
max_w_scale
,
input_scale
=
layer
.
input_scale
)
input_scale
=
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
...
...
@@ -57,11 +58,12 @@ class QuarkW8A8Fp8(QuarkScheme):
weight
=
layer
.
weight
if
current_platform
.
is_fp8_fnuz
():
input_scale
=
getattr
(
layer
,
'input_scale'
,
None
)
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
layer
.
weight_scale
,
input_scale
=
layer
.
input_scale
)
input_scale
=
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
...
...
@@ -105,7 +107,7 @@ class QuarkW8A8Fp8(QuarkScheme):
# the newly added parameters
if
self
.
qscheme
==
"per_channel"
:
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
sum
(
output_partition_sizes
)
,
1
),
data
=
torch
.
empty
((
sum
(
output_partition_sizes
)),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
...
...
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py
View file @
9e90c9f7
...
...
@@ -35,7 +35,7 @@ class QuarkW8A8Int8(QuarkScheme):
input_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
weight_loader
:
Callable
,
**
kwargs
):
self
.
logical_widths
=
output_partition_sizes
layer
.
logical_widths
=
output_partition_sizes
scaled_mm_linear_kernel_config
=
ScaledMMLinearLayerConfig
(
is_channelwise
=
(
self
.
qscheme
==
"per_channel"
),
...
...
@@ -63,16 +63,28 @@ class QuarkW8A8Int8(QuarkScheme):
# WEIGHT SCALE
if
self
.
qscheme
==
"per_channel"
:
weight_scale
=
ChannelQuantScaleParameter
(
data
=
torch
.
empty
((
sum
(
output_partition_sizes
)
,
1
),
data
=
torch
.
empty
((
sum
(
output_partition_sizes
)),
dtype
=
torch
.
float32
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
ChannelQuantZPParameter
=
ChannelQuantScaleParameter
weight_zero_point
=
ChannelQuantZPParameter
(
data
=
torch
.
empty
((
sum
(
output_partition_sizes
)),
dtype
=
torch
.
int8
),
output_dim
=
0
,
weight_loader
=
weight_loader
)
else
:
assert
self
.
qscheme
==
"per_tensor"
weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
)
PerTensorZPParameter
=
PerTensorScaleParameter
weight_zero_point
=
PerTensorZPParameter
(
data
=
torch
.
empty
(
len
(
output_partition_sizes
),
dtype
=
torch
.
int8
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"weight_scale"
,
weight_scale
)
layer
.
register_parameter
(
"weight_zero_point"
,
weight_zero_point
)
# INPUT SCALE
if
self
.
is_static_input_scheme
:
...
...
@@ -81,14 +93,10 @@ class QuarkW8A8Int8(QuarkScheme):
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_scale"
,
input_scale
)
if
not
self
.
input_symmetric
:
# Note: quark stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
int8
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_zero_point"
,
input_zero_point
)
input_zero_point
=
BasevLLMParameter
(
data
=
torch
.
empty
(
1
,
dtype
=
torch
.
int8
),
weight_loader
=
weight_loader
)
layer
.
register_parameter
(
"input_zero_point"
,
input_zero_point
)
self
.
kernel
=
kernel_type
(
c
=
scaled_mm_linear_kernel_config
,
w_q_param_name
=
"weight"
,
...
...
@@ -100,6 +108,12 @@ class QuarkW8A8Int8(QuarkScheme):
# Checkpoints are serialized in quark format, which is
# different from the format the kernel may want. Handle repacking here.
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
register_parameter
(
"weight_zero_point"
,
None
)
delattr
(
layer
,
'weight_zero_point'
)
if
self
.
input_symmetric
:
layer
.
register_parameter
(
"input_zero_point"
,
None
)
delattr
(
layer
,
'input_zero_point'
)
self
.
kernel
.
process_weights_after_loading
(
layer
)
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment