Unverified Commit 83d09d36 authored by rasmith's avatar rasmith Committed by GitHub
Browse files

[CI][Bugfix][AMD] Ensure weights are created when emulating OCP MXFP4 (#36993)


Signed-off-by: Randall Smith <Randall.Smith@amd.com>
parent 92b9afee
......@@ -267,20 +267,26 @@ class QuarkOCP_MX(QuarkScheme):
def get_min_capability(cls) -> int:
    """Return the minimum capability value this scheme reports.

    The value is the constant ``70``.
    NOTE(review): presumably this encodes a device compute capability
    consumed by the caller -- confirm against the base scheme's contract.
    """
    return 70
def process_dynamic_mxfp4_weights_after_loading(
    self, layer: torch.nn.Module
) -> None:
    """Quantize ``layer.weight`` via ``dynamic_mxfp4_quant`` and store the result.

    ``dynamic_mxfp4_quant`` returns a pair. The first element replaces
    ``layer.weight``; the second is transposed, made contiguous and stored
    as ``layer.weight_scale``. Both are wrapped in ``torch.nn.Parameter``
    with ``requires_grad=False``.

    Args:
        layer: Module whose ``weight`` attribute is quantized in place.
    """
    quantized_weight, scales = dynamic_mxfp4_quant(layer.weight)

    # Build both replacement parameters first, then attach them to the layer.
    scale_param = torch.nn.Parameter(scales.T.contiguous(), requires_grad=False)
    weight_param = torch.nn.Parameter(quantized_weight, requires_grad=False)

    layer.weight_scale = scale_param
    layer.weight = weight_param
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
if self.emulate:
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False
)
else:
if self.dynamic_mxfp4_quant:
w_q, w_s = dynamic_mxfp4_quant(layer.weight)
self.process_dynamic_mxfp4_weights_after_loading(layer)
else:
layer.weight_scale = torch.nn.Parameter(
w_s.T.contiguous(), requires_grad=False
layer.weight_scale.data, requires_grad=False
)
layer.weight = torch.nn.Parameter(w_q, requires_grad=False)
else:
if self.dynamic_mxfp4_quant:
self.process_dynamic_mxfp4_weights_after_loading(layer)
elif self.rocm_use_aiter_fp4_asm_gemm:
# shuffle weight scale
weight_scale_shuffle = layer.weight_scale.data
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment