Unverified commit 1c50e100, authored by li haoyang and committed by GitHub
Browse files

[Bugfix] fix quark ptpc (#20251)


Signed-off-by: Haoyang Li <Haoyang.Li@amd.com>
Co-authored-by: Haoyang Li <307790822@qq.com>
parent 3ee56e26
...@@ -312,11 +312,7 @@ class QuarkConfig(QuantizationConfig): ...@@ -312,11 +312,7 @@ class QuarkConfig(QuantizationConfig):
is_fp8_w8a8_supported = self._check_scheme_supported( is_fp8_w8a8_supported = self._check_scheme_supported(
QuarkW8A8Fp8.get_min_capability(), error=False) QuarkW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported: if is_fp8_w8a8_supported:
weight_qscheme = cast(str, weight_config.get("qscheme")) return QuarkW8A8Fp8(weight_config, input_config)
input_static = (input_config is not None and
not cast(bool, input_config.get("is_dynamic")))
return QuarkW8A8Fp8(qscheme=weight_qscheme,
is_static_input_scheme=input_static)
elif self._is_static_tensor_w8a8(weight_config, input_config): elif self._is_static_tensor_w8a8(weight_config, input_config):
weight_qscheme = cast(str, weight_config.get("qscheme")) weight_qscheme = cast(str, weight_config.get("qscheme"))
return QuarkW8A8Int8(qscheme=weight_qscheme, return QuarkW8A8Int8(qscheme=weight_qscheme,
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional from typing import Any, Callable, Optional, cast
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
...@@ -19,10 +19,19 @@ __all__ = ["QuarkW8A8Fp8"] ...@@ -19,10 +19,19 @@ __all__ = ["QuarkW8A8Fp8"]
class QuarkW8A8Fp8(QuarkScheme): class QuarkW8A8Fp8(QuarkScheme):
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]): def __init__(self, weight_config: dict[str, Any],
self.qscheme = qscheme input_config: Optional[dict[str, Any]]):
self.is_static_input_scheme = is_static_input_scheme self.weight_qscheme = cast(str, weight_config.get("qscheme"))
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False) self.is_static_input_scheme: bool = False
self.input_qscheme: Optional[str] = None
if input_config is not None:
self.is_static_input_scheme = not cast(
bool, input_config.get("is_dynamic"))
self.input_qscheme = cast(str, input_config.get("qscheme"))
self.use_per_token_if_dynamic = (not self.is_static_input_scheme \
and self.input_qscheme == "per_channel")
self.fp8_linear = Fp8LinearOp(
use_per_token_if_dynamic=self.use_per_token_if_dynamic)
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
@classmethod @classmethod
...@@ -34,7 +43,7 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -34,7 +43,7 @@ class QuarkW8A8Fp8(QuarkScheme):
# If per tensor, when we have a fused module (e.g. QKV) with per # If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel), # tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor # requantize so we can always run per tensor
if self.qscheme == "per_tensor": if self.weight_qscheme == "per_tensor":
if current_platform.is_rocm(): if current_platform.is_rocm():
input_scale = getattr(layer, 'input_scale', None) input_scale = getattr(layer, 'input_scale', None)
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
...@@ -58,7 +67,7 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -58,7 +67,7 @@ class QuarkW8A8Fp8(QuarkScheme):
layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose. # If channelwise, scales are already lined up, so just transpose.
elif self.qscheme == "per_channel": elif self.weight_qscheme == "per_channel":
weight = layer.weight weight = layer.weight
if current_platform.is_fp8_fnuz(): if current_platform.is_fp8_fnuz():
...@@ -73,13 +82,15 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -73,13 +82,15 @@ class QuarkW8A8Fp8(QuarkScheme):
requires_grad=False) requires_grad=False)
else: else:
weight_scale = layer.weight_scale.data weight_scale = layer.weight_scale.data
if self.use_per_token_if_dynamic:
weight_scale = weight_scale.view(-1, 1)
layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter # required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else: else:
raise ValueError(f"Unknown quantization scheme {self.qscheme}") raise ValueError(
f"Unknown quantization scheme {self.weight_qscheme}")
# INPUT SCALE # INPUT SCALE
if self.is_static_input_scheme: if self.is_static_input_scheme:
...@@ -109,14 +120,14 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -109,14 +120,14 @@ class QuarkW8A8Fp8(QuarkScheme):
# WEIGHT SCALE # WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return # TODO: update create_xxx_parameter functions to return
# the newly added parameters # the newly added parameters
if self.qscheme == "per_channel": if self.weight_qscheme == "per_channel":
weight_scale = ChannelQuantScaleParameter( weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes)), data=torch.empty((sum(output_partition_sizes)),
dtype=torch.float32), dtype=torch.float32),
output_dim=0, output_dim=0,
weight_loader=weight_loader) weight_loader=weight_loader)
else: else:
assert self.qscheme == "per_tensor" assert self.weight_qscheme == "per_tensor"
weight_scale = PerTensorScaleParameter(data=torch.empty( weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32), len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader) weight_loader=weight_loader)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment