Unverified Commit b498cd21 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Tiny make fp4 moe method parameters more static (#8520)

parent 0fc54b97
...@@ -812,6 +812,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ...@@ -812,6 +812,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
) )
layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w13_weight_scale", w13_weight_scale)
# Only use `swizzle_blockscale` for shapes, not for real content
layer.w13_blockscale_swizzled = Parameter(
self.swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
)
w2_weight_scale = ModelWeightParameter( w2_weight_scale = ModelWeightParameter(
data=torch.empty( data=torch.empty(
layer.num_local_experts, layer.num_local_experts,
...@@ -826,6 +831,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ...@@ -826,6 +831,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
) )
layer.register_parameter("w2_weight_scale", w2_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale)
layer.w2_blockscale_swizzled = Parameter(
self.swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
extra_weight_attrs.update( extra_weight_attrs.update(
...@@ -1129,16 +1138,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ...@@ -1129,16 +1138,12 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
# Process w13 weights # Process w13 weights
w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale) w13_blockscale_swizzled = self.swizzle_blockscale(layer.w13_weight_scale)
layer.w13_blockscale_swizzled = Parameter( layer.w13_blockscale_swizzled.data.copy_(w13_blockscale_swizzled)
w13_blockscale_swizzled, requires_grad=False
)
layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
# Process w2 weights # Process w2 weights
w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale) w2_blockscale_swizzled = self.swizzle_blockscale(layer.w2_weight_scale)
layer.w2_blockscale_swizzled = Parameter( layer.w2_blockscale_swizzled.data.copy_(w2_blockscale_swizzled)
w2_blockscale_swizzled, requires_grad=False
)
layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
# Both flashinfer cutlass and regular cutlass use same processing for w2 # Both flashinfer cutlass and regular cutlass use same processing for w2
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment