Unverified commit e9feb488, authored by Zilin Zhu, committed by GitHub
Browse files

[RL] Remove the w13 weight_scale and input_scale for UnquantizedEPMoEMethod (#6308)

parent fc992a09
...@@ -497,7 +497,8 @@ class EPMoE(torch.nn.Module): ...@@ -497,7 +497,8 @@ class EPMoE(torch.nn.Module):
# Input scales can be loaded directly and should be equal. # Input scales can be loaded directly and should be equal.
if "input_scale" in weight_name: if "input_scale" in weight_name:
if ( if (
param_data[expert_id] != 1 (shard_id == "w1" or shard_id == "w3")
and param_data[expert_id] != 1
and (param_data[expert_id] - loaded_weight).abs() > 1e-5 and (param_data[expert_id] - loaded_weight).abs() > 1e-5
): ):
raise ValueError( raise ValueError(
...@@ -571,13 +572,10 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -571,13 +572,10 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
# scale # scale
layer.register_parameter("w13_input_scale", None)
layer.register_parameter("w13_weight_scale", None)
ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32) ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
w13_input_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter( w2_input_scale = torch.nn.Parameter(
ones_tensor, ones_tensor,
...@@ -586,13 +584,6 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -586,13 +584,6 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_input_scale", w2_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs) set_weight_attrs(w2_input_scale, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
ones_tensor,
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter( w2_weight_scale = torch.nn.Parameter(
ones_tensor, ones_tensor,
requires_grad=False, requires_grad=False,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment