Unverified commit e6bf9f15, authored by BadrBasowid and committed by GitHub
Browse files

[Bugfix][CI] Fix Marlin FP8 Linear Kernel for Compressed Tensors Format (#38092)


Signed-off-by: BadrBasowid <Badr.Basowid@gmail.com>
Signed-off-by: BadrBasowid <61441185+BadrBasowid@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 144030c8
...@@ -76,8 +76,25 @@ class MarlinFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): ...@@ -76,8 +76,25 @@ class MarlinFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
replace_parameter(layer, "weight", weight.data) replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data) replace_parameter(layer, "weight_scale_inv", weight_scale_inv.data)
else: else:
weight = layer.weight.t() w_q, *_ = self._get_layer_params(layer)
replace_parameter(layer, "weight", weight.data) # Compressed tensors transposes the weight to (K, N)
# for channel and tensor quant strategies.
# So we can skip the transpose if the layout is
# already (K, N).
# TODO: Remove this check once the layouts have been
# canonicalized to a standard (N, K) dimension. See issue
# #33314 for more details.
if w_q.shape != (
layer.input_size_per_partition,
layer.output_size_per_partition,
):
# transpose the weights to (K,N)
replace_parameter(
layer,
"weight",
w_q.t(),
)
layer.input_scale = None layer.input_scale = None
prepare_fp8_layer_for_marlin( prepare_fp8_layer_for_marlin(
layer, self.size_k_first, input_dtype=self.marlin_input_dtype layer, self.size_k_first, input_dtype=self.marlin_input_dtype
......
...@@ -188,6 +188,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -188,6 +188,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
if self.strategy == QuantizationStrategy.BLOCK: if self.strategy == QuantizationStrategy.BLOCK:
maybe_post_process_fp8_weight_block(layer) maybe_post_process_fp8_weight_block(layer)
if hasattr(self, "fp8_linear"):
self.fp8_linear.process_weights_after_loading(layer)
def apply_weights( def apply_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
...@@ -705,6 +705,9 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase): ...@@ -705,6 +705,9 @@ class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False) layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False)
if hasattr(self, "fp8_linear"):
self.fp8_linear.process_weights_after_loading(layer)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment