Commit 40e78153 authored by Anthony Chen's avatar Anthony Chen Committed by Facebook GitHub Bot
Browse files

Migrate transformer_auto_wrap_policy to ModuleWrapPolicy

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/633

transformer_auto_wrap_policy is buggy and causes issues when wrapping wrapped module. Migrate to ModuleWrapPolicy

Reviewed By: tglik

Differential Revision: D51124721

fbshipit-source-id: 61c4f5f810ead3c3776a7310926b2181121162ac
parent f2a0c52c
......@@ -6,8 +6,8 @@ import torch
from detectron2.utils.registry import Registry
from torch.distributed.fsdp.wrap import (
always_wrap_policy as _always_wrap_policy,
ModuleWrapPolicy,
size_based_auto_wrap_policy as _size_based_auto_wrap_policy,
transformer_auto_wrap_policy as _layer_based_auto_wrap_policy,
)
......@@ -109,7 +109,7 @@ def layer_based_auto_wrap_policy(
model, layer_names: Iterable[str], **kwargs
) -> Optional[Callable]:
"""
Wrapper for transformer_auto_wrap_policy() from torch.distributed.fsdp.wrap
Wrapper for ModuleWrapPolicy() from torch.distributed.fsdp.wrap
Args:
layer_names: a list of layer names
"""
......@@ -117,7 +117,4 @@ def layer_based_auto_wrap_policy(
len(layer_names) > 0
), "layer_names should be a nonempty list of layer names contained in the model"
layer_cls = get_layer_cls_from_names(model, layer_names)
return partial(
_layer_based_auto_wrap_policy,
transformer_layer_cls=layer_cls,
)
return ModuleWrapPolicy(module_classes=layer_cls)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment