Unverified Commit 8f7ee69f authored by Myle Ott's avatar Myle Ott Committed by GitHub
Browse files

[fix] [FSDP] Make _get_default_cuda_device more robust to modules without params (#606)

parent 82d6997c
...@@ -1540,11 +1540,14 @@ class FullyShardedDataParallel(nn.Module): ...@@ -1540,11 +1540,14 @@ class FullyShardedDataParallel(nn.Module):
def _get_default_cuda_device(module: nn.Module) -> torch.device: def _get_default_cuda_device(module: nn.Module) -> torch.device:
"""Try to infer CUDA device from module parameters.""" """Try to infer CUDA device from module parameters."""
compute_device = next(module.parameters()).device try:
if compute_device.type != "cuda": compute_device = next(module.parameters()).device
# Fall back to current CUDA device. if compute_device.type == "cuda":
compute_device = torch.device("cuda") return compute_device
return compute_device except StopIteration:
pass
# Fall back to current CUDA device
return torch.device("cuda")
@torch.no_grad() @torch.no_grad()
......
...@@ -88,9 +88,11 @@ def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: A ...@@ -88,9 +88,11 @@ def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: A
with enable_wrap(**params): with enable_wrap(**params):
# Wraps layer in FSDP by default if within context # Wraps layer in FSDP by default if within context
self.l1 = wrap(torch.nn.Linear(5, 5)) self.l1 = wrap(torch.nn.Linear(5, 5))
# Wraps children modules based on a different min_num_params self.l2 = auto_wrap(
my_auto_wrap_policy = functools.partial(auto_wrap_policy, min_num_params=1e7) TransformerBlock(),
self.l2 = auto_wrap(TransformerBlock(), shuold_wrap=my_auto_wrap_policy) # Wraps children modules based on a different min_num_params
auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1e7)
)
Args: Args:
auto_wrap_policy (Callable, Optional): auto_wrap_policy (Callable, Optional):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment