Unverified Commit 8e2363dc authored by Alex Sun's avatar Alex Sun Committed by GitHub
Browse files

fix amd EP MoE FP8 issue (#7125)

parent f9dc9dd2
...@@ -33,10 +33,12 @@ from sglang.srt.layers.quantization.base_config import ( ...@@ -33,10 +33,12 @@ from sglang.srt.layers.quantization.base_config import (
) )
from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
scaled_fp8_quant, scaled_fp8_quant,
sglang_per_token_group_quant_fp8, sglang_per_token_group_quant_fp8,
sglang_per_token_quant_fp8, sglang_per_token_quant_fp8,
) )
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.managers.expert_location import get_global_expert_location_metadata from sglang.srt.managers.expert_location import get_global_expert_location_metadata
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
...@@ -50,6 +52,7 @@ from sglang.srt.utils import ( ...@@ -50,6 +52,7 @@ from sglang.srt.utils import (
) )
_is_hip = is_hip() _is_hip = is_hip()
_is_fp8_fnuz = is_fp8_fnuz()
if _is_hip: if _is_hip:
from vllm._custom_ops import scaled_fp8_quant from vllm._custom_ops import scaled_fp8_quant
...@@ -843,6 +846,33 @@ class Fp8EPMoEMethod(Fp8MoEMethod): ...@@ -843,6 +846,33 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
torch.max(layer.w13_weight_scale, dim=1).values, torch.max(layer.w13_weight_scale, dim=1).values,
requires_grad=False, requires_grad=False,
) )
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if _is_fp8_fnuz:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
weight_scale=layer.w13_weight_scale_inv,
input_scale=None,
)
w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w2_weight,
weight_scale=layer.w2_weight_scale_inv,
input_scale=None,
)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(
w13_weight, requires_grad=False
)
layer.w13_weight_scale_inv = torch.nn.Parameter(
w13_weight_scale, requires_grad=False
)
layer.w13_input_scale = None
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_weight_scale_inv = torch.nn.Parameter(
w2_weight_scale, requires_grad=False
)
layer.w2_input_scale = None
return return
def apply( def apply(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment