Unverified Commit 5f7fab88 authored by vllmellm's avatar vllmellm Committed by GitHub
Browse files

[ROCm][FEAT] Integrate aiter gemm w8a8 ptpc (#33773)


Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
parent 343f6523
...@@ -1856,6 +1856,7 @@ class TestFP8Layer(torch.nn.Module): ...@@ -1856,6 +1856,7 @@ class TestFP8Layer(torch.nn.Module):
out_dtype=out_dtype, out_dtype=out_dtype,
force_kernel=force_kernel, force_kernel=force_kernel,
) )
self.kernel.process_weights_after_loading(self)
def is_quant_fp8_enabled(self) -> bool: def is_quant_fp8_enabled(self) -> bool:
return self.kernel.quant_fp8.enabled() return self.kernel.quant_fp8.enabled()
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import functools import functools
from collections.abc import Callable from collections.abc import Callable
import pandas as pd
import torch import torch
from torch._ops import OpOverload from torch._ops import OpOverload
...@@ -56,6 +57,29 @@ def is_aiter_found_and_supported() -> bool: ...@@ -56,6 +57,29 @@ def is_aiter_found_and_supported() -> bool:
return False return False
@functools.cache
def _load_gemm_tuned_configs(
    q_dtype_w: torch.dtype, csv_path: str
) -> set[tuple[int, int, int]]:
    """Load the tuned GEMM shapes listed in a CSV file for one weight dtype.

    The CSV is read with pandas and must provide ``q_dtype_w``, ``N``, ``K``
    and ``M`` columns. Only rows whose ``q_dtype_w`` cell equals
    ``str(q_dtype_w)`` are kept.

    The result is memoized per ``(q_dtype_w, csv_path)`` pair by
    ``functools.cache``, so a failed load is remembered as well.

    Args:
        q_dtype_w: Quantized weight dtype used to filter the rows.
        csv_path: Path of the CSV file holding the tuned configurations.

    Returns:
        The set of ``(N, K, M)`` tuples found for ``q_dtype_w``. An empty set
        is returned when the file cannot be read or parsed.
    """
    try:
        df = pd.read_csv(csv_path)
        df = df[df["q_dtype_w"] == str(q_dtype_w)]
        # Building a set already removes duplicate shapes, so no explicit
        # de-duplication of the rows is needed beforehand.
        return set(zip(df["N"].astype(int), df["K"].astype(int), df["M"].astype(int)))
    except Exception as err:
        # Deliberately best-effort: callers treat "no configs" as "not tuned".
        # The empty result is cached, so report the reason once instead of
        # silently hiding why no tuned shape is ever found.
        import logging

        logging.getLogger(__name__).warning(
            "Could not load tuned GEMM configs for %s from %r: %s",
            q_dtype_w,
            csv_path,
            err,
        )
        return set()
def _check_kernel_tuned(N: int, K: int, q_dtype_w: torch.dtype, csv_path: str) -> bool:
    """Tell whether any probed ``M`` has a tuned ``(N, K, M)`` entry.

    The tuned shapes come from ``_load_gemm_tuned_configs(q_dtype_w, csv_path)``.

    Args:
        N: First element of the looked-up shape tuple.
        K: Second element of the looked-up shape tuple.
        q_dtype_w: Quantized weight dtype forwarded to the config loader.
        csv_path: Path of the tuned-config CSV forwarded to the config loader.

    Returns:
        True when at least one probed ``M`` yields a known ``(N, K, M)``.
    """
    tuned_shapes = _load_gemm_tuned_configs(q_dtype_w, csv_path)

    # Probed M values: 1, 2, 4, every multiple of 8 up to 512, then 1024 and
    # 1536, and finally the powers of two from 2**11 through 2**18.
    probed_m = [1, 2, 4, *range(8, 513, 8), 1024, 1536]
    probed_m.extend(2**exponent for exponent in range(11, 19))

    for m_value in probed_m:
        if (N, K, m_value) in tuned_shapes:
            return True
    return False
def if_aiter_supported(func: Callable) -> Callable: def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if """Decorator that only executes the function if
ROCm AITER package is supported and enabled on gfx9 archs. ROCm AITER package is supported and enabled on gfx9 archs.
...@@ -468,7 +492,7 @@ def _rocm_aiter_mla_decode_fwd_fake( ...@@ -468,7 +492,7 @@ def _rocm_aiter_mla_decode_fwd_fake(
pass pass
def _rocm_aiter_gemm_a8w8_impl( def _rocm_aiter_w8a8_gemm_impl(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
...@@ -485,7 +509,7 @@ def _rocm_aiter_gemm_a8w8_impl( ...@@ -485,7 +509,7 @@ def _rocm_aiter_gemm_a8w8_impl(
return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype) return gemm_a8w8_CK(A, B, As, Bs, bias, output_dtype)
def _rocm_aiter_gemm_a8w8_fake( def _rocm_aiter_w8a8_gemm_fake(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
...@@ -499,6 +523,35 @@ def _rocm_aiter_gemm_a8w8_fake( ...@@ -499,6 +523,35 @@ def _rocm_aiter_gemm_a8w8_fake(
return Y return Y
def _rocm_aiter_preshuffled_per_token_w8a8_gemm_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Run AITER's ``gemm_a8w8_bpreshuffle`` and add the optional bias.

    Args:
        A: First operand forwarded to the AITER kernel.
        B: Second operand forwarded to the AITER kernel.
        As: Scale tensor forwarded alongside ``A``.
        Bs: Scale tensor forwarded alongside ``B``.
        bias: Optional tensor added in place to the kernel output.
        output_dtype: Output dtype requested from the kernel.

    Returns:
        The kernel output, with ``bias`` added in place when it is given.
    """
    from aiter import gemm_a8w8_bpreshuffle

    # The kernel is always called with ``None`` in its bias slot; any bias is
    # applied afterwards with an in-place add on the kernel output.
    result = gemm_a8w8_bpreshuffle(A, B, As, Bs, None, output_dtype)
    if bias is None:
        return result
    result.add_(bias)
    return result
def _rocm_aiter_preshuffled_per_token_w8a8_gemm_fake(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
m = A.shape[0]
n = B.shape[0]
return torch.empty(m, n, dtype=output_dtype, device=A.device)
def _rocm_aiter_triton_gemm_a8w8_blockscale_impl( def _rocm_aiter_triton_gemm_a8w8_blockscale_impl(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
...@@ -1313,11 +1366,15 @@ class rocm_aiter_ops: ...@@ -1313,11 +1366,15 @@ class rocm_aiter_ops:
) )
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_gemm_a8w8", op_name="rocm_aiter_w8a8_gemm",
op_func=_rocm_aiter_gemm_a8w8_impl, op_func=_rocm_aiter_w8a8_gemm_impl,
mutates_args=[], fake_impl=_rocm_aiter_w8a8_gemm_fake,
fake_impl=_rocm_aiter_gemm_a8w8_fake, )
dispatch_key=current_platform.dispatch_key,
direct_register_custom_op(
op_name="_rocm_aiter_preshuffled_per_token_w8a8_gemm",
op_func=_rocm_aiter_preshuffled_per_token_w8a8_gemm_impl,
fake_impl=_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake,
) )
direct_register_custom_op( direct_register_custom_op(
...@@ -1493,7 +1550,18 @@ class rocm_aiter_ops: ...@@ -1493,7 +1550,18 @@ class rocm_aiter_ops:
) )
@staticmethod @staticmethod
def gemm_a8w8( def w8a8_gemm(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_w8a8_gemm(A, B, As, Bs, bias, output_dtype)
@staticmethod
def preshuffled_per_token_w8a8_gemm(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
As: torch.Tensor, As: torch.Tensor,
...@@ -1501,7 +1569,9 @@ class rocm_aiter_ops: ...@@ -1501,7 +1569,9 @@ class rocm_aiter_ops:
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops.vllm.rocm_aiter_gemm_a8w8(A, B, As, Bs, bias, output_dtype) return torch.ops.vllm._rocm_aiter_preshuffled_per_token_w8a8_gemm(
A, B, As, Bs, bias, output_dtype
)
@staticmethod @staticmethod
def triton_gemm_a8w8_blockscale( def triton_gemm_a8w8_blockscale(
...@@ -1920,6 +1990,24 @@ class rocm_aiter_ops: ...@@ -1920,6 +1990,24 @@ class rocm_aiter_ops:
(8192, 3584), (8192, 3584),
] ]
@staticmethod
def is_shuffled_per_token_w8a8_gemm_tuned(
    N: int, K: int, q_dtype_w: torch.dtype
) -> bool:
    """Check for a tuned config of the pre-shuffled per-token w8a8 GEMM.

    The lookup uses the CSV path exposed by AITER as
    ``AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE`` and is
    delegated to ``_check_kernel_tuned``.
    """
    import aiter.ops.gemm_op_a8w8 as aiter_gemm_a8w8_ops

    aiter_configs = aiter_gemm_a8w8_ops.AITER_CONFIGS
    return _check_kernel_tuned(
        N, K, q_dtype_w, aiter_configs.AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE
    )
@staticmethod
def is_per_token_w8a8_gemm_tuned(N: int, K: int, q_dtype_w: torch.dtype) -> bool:
    """Check for a tuned config of the per-token w8a8 GEMM.

    The lookup uses the CSV path exposed by AITER as
    ``AITER_CONFIGS.AITER_CONFIG_GEMM_A8W8_FILE`` and is delegated to
    ``_check_kernel_tuned``.
    """
    import aiter.ops.gemm_op_a8w8 as aiter_gemm_a8w8_ops

    aiter_configs = aiter_gemm_a8w8_ops.AITER_CONFIGS
    return _check_kernel_tuned(
        N, K, q_dtype_w, aiter_configs.AITER_CONFIG_GEMM_A8W8_FILE
    )
@staticmethod @staticmethod
def shuffle_weight( def shuffle_weight(
tensor: torch.Tensor, layout: tuple[int, int] = (16, 16) tensor: torch.Tensor, layout: tuple[int, int] = (16, 16)
......
...@@ -106,6 +106,8 @@ from vllm.model_executor.kernels.linear.scaled_mm import ( ...@@ -106,6 +106,8 @@ from vllm.model_executor.kernels.linear.scaled_mm import (
from vllm.model_executor.kernels.linear.scaled_mm.aiter import ( from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterFp8BlockScaledMMKernel, AiterFp8BlockScaledMMKernel,
AiterInt8ScaledMMLinearKernel, AiterInt8ScaledMMLinearKernel,
AiterPerTokenFp8ScaledMMLinearKernel,
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel,
) )
from vllm.model_executor.kernels.linear.scaled_mm.cpu import ( from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel, CPUInt8ScaledMMLinearKernel,
...@@ -165,6 +167,8 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = ...@@ -165,6 +167,8 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
ChannelWiseTorchFP8ScaledMMLinearKernel, ChannelWiseTorchFP8ScaledMMLinearKernel,
], ],
PlatformEnum.ROCM: [ PlatformEnum.ROCM: [
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel,
AiterPerTokenFp8ScaledMMLinearKernel,
ROCmFP8ScaledMMLinearKernel, ROCmFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel, PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel, RowWiseTorchFP8ScaledMMLinearKernel,
...@@ -360,18 +364,18 @@ def choose_scaled_mm_linear_kernel( ...@@ -360,18 +364,18 @@ def choose_scaled_mm_linear_kernel(
def init_fp8_linear_kernel( def init_fp8_linear_kernel(
activation_quant_key: QuantKey, activation_quant_key: QuantKey,
weight_quant_key: QuantKey, weight_quant_key: QuantKey,
weight_shape: tuple[int, int],
input_dtype: torch.dtype, input_dtype: torch.dtype,
out_dtype: torch.dtype, out_dtype: torch.dtype,
force_kernel: type[_KernelT] | None = None, weight_shape: tuple[int, int],
force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
module_name: str | None = None, module_name: str | None = None,
) -> FP8ScaledMMLinearKernel | Fp8BlockScaledMMLinearKernel: ) -> FP8ScaledMMLinearKernel | Fp8BlockScaledMMLinearKernel:
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig( scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
weight_quant_key=weight_quant_key, weight_quant_key=weight_quant_key,
activation_quant_key=activation_quant_key, activation_quant_key=activation_quant_key,
weight_shape=weight_shape,
input_dtype=input_dtype, input_dtype=input_dtype,
out_dtype=out_dtype, out_dtype=out_dtype,
weight_shape=weight_shape,
) )
if activation_quant_key.scale.group_shape.is_per_group(): if activation_quant_key.scale.group_shape.is_per_group():
...@@ -725,6 +729,8 @@ __all__ = [ ...@@ -725,6 +729,8 @@ __all__ = [
"FP8ScaledMMLinearLayerConfig", "FP8ScaledMMLinearLayerConfig",
"Int8ScaledMMLinearLayerConfig", "Int8ScaledMMLinearLayerConfig",
"ScaledMMLinearLayerConfig", "ScaledMMLinearLayerConfig",
"AiterPreshuffledPerTokenFp8ScaledMMLinearKernel",
"AiterPerTokenFp8ScaledMMLinearKernel",
"NvFp4LinearKernel", "NvFp4LinearKernel",
"NvFp4LinearLayerConfig", "NvFp4LinearLayerConfig",
"AiterInt8ScaledMMLinearKernel", "AiterInt8ScaledMMLinearKernel",
......
...@@ -5,18 +5,27 @@ ...@@ -5,18 +5,27 @@
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import (
rocm_aiter_ops,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
) )
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .BlockScaledMMLinearKernel import ( from .BlockScaledMMLinearKernel import (
Fp8BlockScaledMMLinearKernel, Fp8BlockScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
) )
from .cutlass import CutlassInt8ScaledMMLinearKernel from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearLayerConfig,
)
logger = init_logger(__name__)
class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel): class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
...@@ -113,7 +122,154 @@ class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel): ...@@ -113,7 +122,154 @@ class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
# a to be [M, K] # a to be [M, K]
# b to be [N, K] # b to be [N, K]
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format # CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype) return rocm_aiter_ops.w8a8_gemm(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
class AiterPreshuffledPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled-MM linear kernel using AITER's pre-shuffled per-token GEMM.

    The layer weight is transposed, made contiguous and passed through
    ``rocm_aiter_ops.shuffle_weight`` after loading; the matmul itself goes
    through ``rocm_aiter_ops.preshuffled_per_token_w8a8_gemm``.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether this kernel can run in the current environment.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``.
        """
        if not current_platform.is_rocm():
            return False, "requires ROCm."

        if not rocm_aiter_ops.is_linear_fp8_enabled():
            reason = (
                "requires setting `VLLM_ROCM_USE_AITER=1` "
                "and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
                "`VLLM_ROCM_USE_AITER_LINEAR` default is True."
            )
            return False, reason

        # Any failure while importing aiter is treated as "not installed".
        try:
            import aiter  # noqa: F401
        except Exception:
            return False, "requires aiter library to be installed."

        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Report whether this kernel can serve the layer configuration ``c``.

        Returns:
            ``(True, None)`` when every requirement holds, otherwise
            ``(False, reason)`` for the first requirement that fails.
        """
        per_token_per_channel = (
            c.activation_quant_key.scale.group_shape.is_per_token()
            and c.weight_quant_key.scale.group_shape.is_per_channel()
        )

        weight_shape = c.weight_shape
        if weight_shape is None:
            return False, "weight_shape is required for Aiter kernels"
        N, K = weight_shape
        fp8_dtype = current_platform.fp8_dtype()

        if c.out_dtype is not torch.bfloat16:
            return False, "requires bfloat16 output dtype."

        if not per_token_per_channel:
            reason = (
                "requires per token activation scales and per channel weight scales."
            )
            return False, reason

        if N % 16 != 0 or K % 16 != 0:
            reason = (
                f"requires N and K dimensions divisible by 16, received "
                f"N={N} and K={K}."
            )
            return False, reason

        # Aiter's shuffled per-token GEMM performs better than torch only when
        # it is tuned, so shapes without a tuned configuration are rejected.
        if rocm_aiter_ops.is_shuffled_per_token_w8a8_gemm_tuned(N, K, fp8_dtype):
            return True, None

        reason = (
            f"requires a tuned configuration for N: {N} and K: {K} "
            f"and fp8 dtype {fp8_dtype}."
        )
        return False, reason

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace the layer's weight with its transposed, shuffled form."""
        weight_name, *_ = self.layer_param_names
        weight, *_ = self._get_layer_params(layer)

        shuffled = rocm_aiter_ops.shuffle_weight(weight.t().contiguous())
        new_weight = torch.nn.Parameter(shuffled.data, requires_grad=False)
        replace_parameter(layer, weight_name, new_weight)

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run the pre-shuffled per-token w8a8 GEMM.

        ``output_shape`` is accepted for interface compatibility and is not
        used here.
        """
        gemm = rocm_aiter_ops.preshuffled_per_token_w8a8_gemm
        return gemm(A, B, As, Bs, bias, out_dtype)
class AiterPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    """FP8 scaled-MM linear kernel using AITER's per-token w8a8 GEMM.

    The layer weight is replaced by its transpose after loading; the matmul
    itself goes through ``rocm_aiter_ops.w8a8_gemm``.
    """

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report environment support; same rules as the pre-shuffled kernel."""
        preshuffled_kernel = AiterPreshuffledPerTokenFp8ScaledMMLinearKernel
        return preshuffled_kernel.is_supported(compute_capability)

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        """Report whether this kernel can serve the layer configuration ``c``.

        Returns:
            ``(True, None)`` when every requirement holds, otherwise
            ``(False, reason)`` for the first requirement that fails.
        """
        per_token_per_channel = (
            c.activation_quant_key.scale.group_shape.is_per_token()
            and c.weight_quant_key.scale.group_shape.is_per_channel()
        )

        weight_shape = c.weight_shape
        if weight_shape is None:
            return False, "weight_shape is required for Aiter kernels"
        N, K = weight_shape
        fp8_dtype = current_platform.fp8_dtype()

        if not per_token_per_channel:
            reason = (
                "requires per token activation scales and per channel weight scales."
            )
            return False, reason

        # Aiter's per-token GEMM performs better than torch only when it is
        # tuned, so shapes without a tuned configuration are rejected.
        if rocm_aiter_ops.is_per_token_w8a8_gemm_tuned(N, K, fp8_dtype):
            return True, None

        reason = (
            f"requires a tuned configuration for N: {N} and K: {K} "
            f"and fp8 dtype {fp8_dtype}."
        )
        return False, reason

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Replace the layer's weight with its transpose."""
        weight_name, *_ = self.layer_param_names
        weight, *_ = self._get_layer_params(layer)

        transposed = torch.nn.Parameter(weight.t(), requires_grad=False)
        replace_parameter(layer, weight_name, transposed)

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        """Run the per-token w8a8 GEMM.

        ``output_shape`` is accepted for interface compatibility and is not
        used here.
        """
        gemm = rocm_aiter_ops.w8a8_gemm
        return gemm(A, B, As, Bs, bias, out_dtype)
class AiterFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel): class AiterFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
......
...@@ -31,8 +31,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -31,8 +31,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
create_fp8_quant_key, create_fp8_quant_key,
kFp8DynamicTokenSym, kFp8DynamicTokenSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym, kFp8StaticTensorSym,
kFp8StaticTokenSym,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported, cutlass_block_fp8_supported,
...@@ -47,7 +47,7 @@ activation_quant_key_mapping = { ...@@ -47,7 +47,7 @@ activation_quant_key_mapping = {
DYNAMIC_QUANT: kFp8DynamicTokenSym, DYNAMIC_QUANT: kFp8DynamicTokenSym,
} }
weight_quant_key_mapping = { weight_quant_key_mapping = {
QuantizationStrategy.CHANNEL: kFp8StaticTokenSym, QuantizationStrategy.CHANNEL: kFp8StaticChannelSym,
QuantizationStrategy.TENSOR: kFp8StaticTensorSym, QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
} }
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -67,7 +67,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -67,7 +67,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled() self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
assert not self.is_static_input_scheme assert not self.is_static_input_scheme
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0]) self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
self.weight_quant_key = create_fp8_quant_key( self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(*self.weight_block_size) static=True, group_shape=GroupShape(*self.weight_block_size)
) )
...@@ -76,7 +75,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -76,7 +75,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
) )
else: else:
self.activation_quant_key = activation_quant_key_mapping[ self.activation_quant_key = activation_quant_key_mapping[
is_static_input_scheme self.is_static_input_scheme
] ]
self.weight_quant_key = weight_quant_key_mapping[self.strategy] self.weight_quant_key = weight_quant_key_mapping[self.strategy]
...@@ -138,9 +137,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -138,9 +137,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key, activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key, weight_quant_key=self.weight_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype, input_dtype=self.input_dtype,
out_dtype=self.out_dtype, out_dtype=self.out_dtype,
weight_shape=(output_size_per_partition, input_size_per_partition),
module_name=self.__class__.__name__, module_name=self.__class__.__name__,
) )
......
...@@ -175,6 +175,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): ...@@ -175,6 +175,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
# Activations not quantized for marlin. # Activations not quantized for marlin.
del layer.input_scale_ub del layer.input_scale_ub
self.fp8_linear.process_weights_after_loading(layer)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
...@@ -397,8 +397,6 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -397,8 +397,6 @@ class Fp8LinearMethod(LinearMethodBase):
if self.block_quant: if self.block_quant:
assert not self.act_q_static assert not self.act_q_static
self.fp8_linear.process_weights_after_loading(layer)
# If checkpoint not serialized fp8, quantize the weights. # If checkpoint not serialized fp8, quantize the weights.
else: else:
# If checkpoint is fp8 per-tensor, handle that there are N scales for N # If checkpoint is fp8 per-tensor, handle that there are N scales for N
...@@ -428,6 +426,8 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -428,6 +426,8 @@ class Fp8LinearMethod(LinearMethodBase):
else: else:
layer.input_scale = None layer.input_scale = None
self.fp8_linear.process_weights_after_loading(layer)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
...@@ -517,6 +517,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): ...@@ -517,6 +517,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
self.fp8_linear.process_weights_after_loading(layer)
def apply( def apply(
self, self,
...@@ -597,6 +598,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase): ...@@ -597,6 +598,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = Parameter(layer.weight.t(), requires_grad=False) layer.weight = Parameter(layer.weight.t(), requires_grad=False)
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
self.fp8_linear.process_weights_after_loading(layer)
def apply( def apply(
self, self,
......
...@@ -120,6 +120,8 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -120,6 +120,8 @@ class QuarkW8A8Fp8(QuarkScheme):
if self.is_static_input_scheme: if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
self.fp8_linear.process_weights_after_loading(layer)
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment