Unverified commit 83d09d36, authored by rasmith, committed by GitHub
Browse files

[CI][Bugfix][AMD] Ensure weights are created when emulating OCP MXFP4 (#36993)


Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent 92b9afee
...@@ -267,20 +267,26 @@ class QuarkOCP_MX(QuarkScheme): ...@@ -267,20 +267,26 @@ class QuarkOCP_MX(QuarkScheme):
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 70 return 70
def process_dynamic_mxfp4_weights_after_loading(
    self, layer: torch.nn.Module
) -> None:
    """Quantize ``layer.weight`` via ``dynamic_mxfp4_quant`` and store the result.

    The quantized tensor replaces ``layer.weight``, and the returned scales,
    transposed and made contiguous, replace ``layer.weight_scale``. Both are
    stored as parameters with ``requires_grad=False``.

    Args:
        layer: Module whose ``weight`` attribute is quantized in place.
    """
    quantized, scales = dynamic_mxfp4_quant(layer.weight)
    # The scales are stored transposed, in contiguous memory.
    transposed_scales = scales.T.contiguous()
    layer.weight_scale = torch.nn.Parameter(transposed_scales, requires_grad=False)
    layer.weight = torch.nn.Parameter(quantized, requires_grad=False)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
if self.emulate: if self.emulate:
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False
)
else:
if self.dynamic_mxfp4_quant: if self.dynamic_mxfp4_quant:
w_q, w_s = dynamic_mxfp4_quant(layer.weight) self.process_dynamic_mxfp4_weights_after_loading(layer)
else:
layer.weight_scale = torch.nn.Parameter( layer.weight_scale = torch.nn.Parameter(
w_s.T.contiguous(), requires_grad=False layer.weight_scale.data, requires_grad=False
) )
layer.weight = torch.nn.Parameter(w_q, requires_grad=False) else:
if self.dynamic_mxfp4_quant:
self.process_dynamic_mxfp4_weights_after_loading(layer)
elif self.rocm_use_aiter_fp4_asm_gemm: elif self.rocm_use_aiter_fp4_asm_gemm:
# shuffle weight scale # shuffle weight scale
weight_scale_shuffle = layer.weight_scale.data weight_scale_shuffle = layer.weight_scale.data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment