"tests/entrypoints/instrumentator/test_basic.py" did not exist on "e3b318216d13225221ffbf03cc815648104e37c5"
Unverified commit d9e62c03, authored by Bowen Bao, committed by GitHub
Browse files

[Quark] Fix MoE fp8 activation scale handling on mi300 (#34386)


Signed-off-by: Bowen Bao <bowenbao@amd.com>
parent a1a2d794
......@@ -858,7 +858,7 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_input_scale = None
def process_weights_after_loading(self, layer):
-if self.static_input_scales:
+if self.static_input_scales and self.input_dtype == "fp8":
# firstly, process activations if fp8 static input
if layer.w13_input_scale is None or layer.w2_input_scale is None:
raise ValueError(
......@@ -883,14 +883,14 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
if current_platform.is_fp8_fnuz():
# Normalize the weights and scales
_, _, w13_input_scale = normalize_e4m3fn_to_e4m3fnuz(
-torch.empty_like(layer.w13_weight, dtype=torch.float8_e4m3fnuz),
+torch.empty_like(layer.w13_weight, dtype=torch.float8_e4m3fn),
torch.empty_like(
layer.w13_weight_scale, dtype=layer.w13_weight_scale.dtype
),
layer.w13_input_scale,
)
_, _, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
-torch.empty_like(layer.w2_weight, dtype=torch.float8_e4m3fnuz),
+torch.empty_like(layer.w2_weight, dtype=torch.float8_e4m3fn),
torch.empty_like(
layer.w2_weight_scale, dtype=layer.w13_weight_scale.dtype
),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment