Unverified Commit 53fc1664 authored by Ilya Markov's avatar Ilya Markov Committed by GitHub
Browse files

[BugFix] Fix EPLB fail for MoeFP4 model with Marlin backend (#33262)


Signed-off-by: default avatarilmarkov <markovilya197@gmail.com>
parent 31b25f65
......@@ -44,7 +44,9 @@ def set_weight_attrs(
setattr(weight, key, value)
def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.Tensor):
def replace_parameter(
layer: torch.nn.Module, param_name: str, new_data: torch.Tensor | None
):
"""
Replace a parameter of a layer while maintaining the ability to reload the weight.
Called within implementations of the `process_weights_after_loading` method.
......@@ -54,9 +56,15 @@ def replace_parameter(layer: torch.nn.Module, param_name: str, new_data: torch.T
Args:
layer: Layer containing parameter to replace
param_name: Name of parameter to replace
new_data: New data of the new parameter
new_data: New data of the new parameter, or None to set the parameter to None
"""
# should not be used on a tied/shared param
# If new_data is None, set the parameter to None
if new_data is None:
setattr(layer, param_name, None)
return
if isinstance(new_data, torch.nn.Parameter):
new_data = new_data.data
new_param = torch.nn.Parameter(new_data, requires_grad=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment