"tests/compile/distributed/test_async_tp.py" did not exist on "287f527f5403bb42a32136cf6c802faeb92a09ef"
Unverified Commit bb239a73 authored by fxmarty-amd's avatar fxmarty-amd Committed by GitHub
Browse files

[Bugfix] Fix quark fp8 format loading on AMD GPUs (#12612)


Signed-off-by: default avatarFelix Marty <felmarty@amd.com>
Signed-off-by: default avatarkewang2 <kewang2@amd.com>
Co-authored-by: default avatarkewang2 <kewang2@amd.com>
parent a463555d
......@@ -5,6 +5,7 @@ Run `pytest tests/quantization/test_quark.py`.
"""
import pytest
import torch
from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
......@@ -63,3 +64,28 @@ def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
def test_quark_fp8_parity(vllm_runner):
quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"
llm_kwargs = {
"tensor_parallel_size": 1,
"enforce_eager": True,
"gpu_memory_utilization": 0.1
}
with (vllm_runner(quark_model_id, **llm_kwargs) as
quark_handle, vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
quark_model = (quark_handle.model.llm_engine.model_executor.
driver_worker.model_runner.model)
quark_state_dict = quark_model.state_dict()
fp8_model = (fp8_handle.model.llm_engine.model_executor.driver_worker.
model_runner.model)
fp8_state_dict = fp8_model.state_dict()
assert fp8_state_dict.keys() == quark_state_dict.keys()
for key in fp8_state_dict:
assert torch.equal(fp8_state_dict[key], quark_state_dict[key])
......@@ -34,21 +34,24 @@ class QuarkW8A8Fp8(QuarkScheme):
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
if self.qscheme == "per_tensor":
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
if current_platform.is_fp8_fnuz():
if current_platform.is_rocm():
input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
else:
max_w_scale = layer.weight_scale
weight = layer.weight
max_w_scale, weight = requantize_with_max_scale(
weight=weight,
weight_scale=max_w_scale,
logical_widths=layer.logical_widths,
)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment