"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "d7f8db8e21fe63d4279afafadc6ed4663952cba8"
Commit 7e4de520 authored by Hongxin Liu's avatar Hongxin Liu
Browse files

[shardformer] fix base policy (#4229)

parent 208ac8f2
...@@ -156,7 +156,10 @@ class Policy(ABC): ...@@ -156,7 +156,10 @@ class Policy(ABC):
# append or create a new description # append or create a new description
if target_key in policy: if target_key in policy:
policy[target_key].sub_module_replacement.extend(description) if policy[target_key].sub_module_replacement is None:
policy[target_key].sub_module_replacement = description
else:
policy[target_key].sub_module_replacement.extend(description)
else: else:
policy[target_key] = ModulePolicyDescription(sub_module_replacement=description) policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
...@@ -174,7 +177,10 @@ class Policy(ABC): ...@@ -174,7 +177,10 @@ class Policy(ABC):
target_key (Union[str, nn.Module]): the key of the policy to be updated target_key (Union[str, nn.Module]): the key of the policy to be updated
""" """
if target_key in policy: if target_key in policy:
policy[target_key].method_replacement.update(description) if policy[target_key].method_replacement is None:
policy[target_key].method_replacement = description
else:
policy[target_key].method_replacement.update(description)
else: else:
policy[target_key] = ModulePolicyDescription(method_replacement=description) policy[target_key] = ModulePolicyDescription(method_replacement=description)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment