OpenDAS / fairscale · commit 3fb8aa2b (unverified)
Authored Sep 10, 2021 by Min Xu; committed via GitHub on Sep 10, 2021
Parent commit: e1f36346

[doc]: updating FSDP example (#788)

Co-authored-by: Min Xu <min.xu.public@gmail.com>

1 changed file with 18 additions and 7 deletions:
  fairscale/nn/data_parallel/fully_sharded_data_parallel.py (+18, -7)
fairscale/nn/data_parallel/fully_sharded_data_parallel.py @ 3fb8aa2b

@@ -93,10 +93,11 @@ class FullyShardedDataParallel(nn.Module):
     .. _`Xu et al.`: https://arxiv.org/abs/2004.13336
     .. _DeepSpeed: https://www.deepspeed.ai/
 
-    Usage::
+    Pseudo-code usage::
 
         import torch
         from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
         torch.cuda.set_device(device_id)
         sharded_module = FSDP(my_module)
         optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)
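The usage block above stops at optimizer construction and leaves the process-group setup, the model, and the input batch implicit. A minimal single-process sketch of how the wrapped module might be driven end to end; the NCCL process-group initialization, the toy nn.Linear model, and the dummy batch are illustrative assumptions, not part of this commit:

    import os

    import torch
    import torch.distributed as dist

    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    # Assumption: one CUDA device and the NCCL backend are available.
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="nccl", rank=0, world_size=1)

    device_id = 0
    torch.cuda.set_device(device_id)

    my_module = torch.nn.Linear(16, 4).cuda()   # stand-in for a real model
    sharded_module = FSDP(my_module)
    optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001)

    x = torch.randn(8, 16, device="cuda")       # dummy input batch
    loss = sharded_module(x).sum()
    loss.backward()    # FSDP reduces and re-shards the gradients here
    optim.step()

    dist.destroy_process_group()

With world_size=1 nothing is actually sharded, but the same calls are what a multi-process launch would run per rank.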
@@ -112,17 +113,27 @@ class FullyShardedDataParallel(nn.Module):
     across the forward pass. For example::
 
         import torch
-        from fairscale.nn.auto_wrap import enable_wrap, auto_wrap, wrap
+        from fairscale.nn.wrap import wrap, enable_wrap, auto_wrap
         from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
+        from fairscale.utils.testing import dist_init, teardown, rmf
+
+        result = dist_init(0, 1, "/tmp/t1", "/tmp/t2")
+        assert result
         fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
         with enable_wrap(**fsdp_params):
+            l1 = wrap(torch.nn.Linear(5, 5))
+            assert isinstance(l1, FSDP)
             # Wraps layer in FSDP by default if within context
-            self.l1 = wrap(torch.nn.Linear(5, 5))
-            assert isinstance(self.l1, FSDP)
             # Separately Wraps children modules with more than 1e8 params
-            large_tfmr = torch.nn.Transformer(d_model=2048, encoder_layers=12, decoder_layers=12)
-            self.l2 = auto_wrap(large_tfmr, min_num_params=1e8)
-            assert isinstance(self.l2, FSDP)
+            large_tfmr = torch.nn.Transformer(d_model=2048, num_encoder_layers=12,
+                                              num_decoder_layers=12)
+            l2 = auto_wrap(large_tfmr)
+            assert isinstance(l2.encoder, FSDP)
+            assert isinstance(l2.decoder, FSDP)
+            print(l2)  # You can print the model to examine FSDP wrapping.
+
+        teardown()
+        rmf("/tmp/t1")
+        rmf("/tmp/t2")
 
     .. warning::
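The updated example leans on fairscale's testing helpers (dist_init, teardown, rmf) to spin up a throwaway process group, which keeps the docstring self-contained but is not how a training script is usually launched. A rough sketch of the same nested-wrapping pattern under a regular torch.distributed setup; the spawn-based launcher, the two-GPU world size, and the port number are assumptions for illustration, not part of the commit:

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.nn.wrap import auto_wrap, enable_wrap, wrap


    def worker(rank: int, world_size: int) -> None:
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29501"
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(rank)

        fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
        with enable_wrap(**fsdp_params):
            # Explicitly wrapped layer.
            l1 = wrap(torch.nn.Linear(5, 5))
            # auto_wrap wraps the large children (encoder and decoder) separately.
            large_tfmr = torch.nn.Transformer(d_model=2048, num_encoder_layers=12,
                                              num_decoder_layers=12)
            l2 = auto_wrap(large_tfmr)

        assert isinstance(l1, FSDP)
        assert isinstance(l2.encoder, FSDP) and isinstance(l2.decoder, FSDP)

        dist.destroy_process_group()


    if __name__ == "__main__":
        world_size = 2  # assumption: two visible GPUs
        mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)

The wrapping calls and the enable_wrap parameters are taken directly from the docstring example; only the launch scaffolding around them is new.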