Unverified Commit 4cf256ae authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[misc][distributed] fix pp missing layer condition (#6446)

parent 64fdc08c
......@@ -83,7 +83,10 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
missing_layer_names = []
for name, module in model.named_modules():
if isinstance(module, PPMissingLayer):
missing_layer_names.append(name)
# NOTE: the trailing dot is used to match the prefix of the layer.
# without the dot, we could match a layer that is not missing,
# e.g., 'encoder.layer.1' would match 'encoder.layer.11'
missing_layer_names.append(name + '.')
_model_to_pp_missing_layer_names[model_id] = missing_layer_names
return missing_layer_names
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment