OpenDAS / fairscale
Commit 64bbb6e1 (Unverified)
Authored Mar 08, 2021 by Sam Shleifer; committed by GitHub, Mar 08, 2021
[doc] fix enable_wrap syntax in FSDP docs (#497)
Parent: c06efdf6
Showing 2 changed files with 8 additions and 2 deletions (+8 -2)
.gitignore: +1 -0
fairscale/nn/data_parallel/fully_sharded_data_parallel.py: +7 -2
.gitignore
@@ -26,3 +26,4 @@ ENV/
 env.bak/
 venv.bak/
 .vscode/*
+*.DS_Store
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -59,6 +59,8 @@ class FullyShardedDataParallel(nn.Module):
     Usage::
+        import torch
+        from fairscale.nn.data_parallel import FullyShardedDataParallel
         torch.cuda.set_device(device_id)
         sharded_module = FullyShardedDataParallel(my_module)
         optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
@@ -73,15 +75,18 @@ class FullyShardedDataParallel(nn.Module):
     models and to improve training speed by overlapping the all-gather step
     across the forward pass. For example::
+        import torch
         from fairscale.nn.auto_wrap import enable_wrap, auto_wrap
         from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
         fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
-        with enable_wrap(**fsdp_params):
+        with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
             # Wraps layer in FSDP by default if within context
             self.l1 = wrap(torch.nn.Linear(5, 5))
             assert isinstance(self.l1, FSDP)
             # Separately Wraps children modules with more than 1e8 params
-            self.l2 = auto_wrap(TransformerBlock(), min_num_params=1e8)
+            large_tfmr = torch.nn.Transformer(d_model=2048, encoder_layers=12, decoder_layers=12)
+            self.l2 = auto_wrap(large_tfmr, min_num_params=1e8)
+            assert isinstance(self.l2, FSDP)
     .. warning::