OpenDAS / fairscale · Commits

Unverified commit 8f7ee69f, authored Apr 14, 2021 by Myle Ott, committed by GitHub on Apr 14, 2021
[fix] [FSDP] Make _get_default_cuda_device more robust to modules without params (#606)
Parent: 82d6997c
Showing 2 changed files, with 13 additions and 8 deletions:

  fairscale/nn/data_parallel/fully_sharded_data_parallel.py  (+8, -5)
  fairscale/nn/wrap/auto_wrap.py  (+5, -3)
fairscale/nn/data_parallel/fully_sharded_data_parallel.py

@@ -1540,11 +1540,14 @@ class FullyShardedDataParallel(nn.Module):
 def _get_default_cuda_device(module: nn.Module) -> torch.device:
     """Try to infer CUDA device from module parameters."""
-    compute_device = next(module.parameters()).device
-    if compute_device.type != "cuda":
-        # Fall back to current CUDA device.
-        compute_device = torch.device("cuda")
-    return compute_device
+    try:
+        compute_device = next(module.parameters()).device
+        if compute_device.type == "cuda":
+            return compute_device
+    except StopIteration:
+        pass
+    # Fall back to current CUDA device
+    return torch.device("cuda")

 @torch.no_grad()
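For context on the fix: next(module.parameters()) raises StopIteration when the module owns no parameters, which crashed the old version; the new version catches the exception and falls back to the current CUDA device. A minimal sketch of the failure mode (the parameter-free nn.ReLU module here is illustrative, not from the commit):

    import torch
    import torch.nn as nn

    # A module with no parameters of its own triggers the failure mode
    # this commit guards against.
    module = nn.ReLU()

    try:
        # Old behavior: next() on an empty parameter iterator raises StopIteration.
        device = next(module.parameters()).device
    except StopIteration:
        # New behavior: fall back to the current CUDA device.
        device = torch.device("cuda")

    print(device)  # cuda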
fairscale/nn/wrap/auto_wrap.py

@@ -88,9 +88,11 @@ def enable_wrap(auto_wrap_policy: Optional[Callable] = None, **wrapper_kwargs: A
         with enable_wrap(**params):
             # Wraps layer in FSDP by default if within context
             self.l1 = wrap(torch.nn.Linear(5, 5))
-            # Wraps children modules based on a different min_num_params
-            my_auto_wrap_policy = functools.partial(auto_wrap_policy, min_num_params=1e7)
-            self.l2 = auto_wrap(TransformerBlock(), shuold_wrap=my_auto_wrap_policy)
+            self.l2 = auto_wrap(
+                TransformerBlock(),
+                # Wraps children modules based on a different min_num_params
+                auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1e7)
+            )

     Args:
         auto_wrap_policy (Callable, Optional):
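For illustration, the corrected docstring example expanded into a self-contained sketch. The import path, the FSDP kwargs in params, and the use of nn.TransformerEncoderLayer in place of TransformerBlock are assumptions for the example, not part of the commit:

    import functools
    import torch.nn as nn
    from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap

    # Assumed kwargs forwarded to the default FSDP wrapper inside the context.
    params = dict(flatten_parameters=True)

    with enable_wrap(**params):
        # Wrapped in FSDP by default within the enable_wrap context.
        l1 = wrap(nn.Linear(5, 5))
        # Recursively wraps child modules holding more than 1e7 parameters;
        # nn.TransformerEncoderLayer stands in for TransformerBlock.
        l2 = auto_wrap(
            nn.TransformerEncoderLayer(d_model=512, nhead=8),
            auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1e7),
        )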