Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
d80c38f9
Unverified
Commit
d80c38f9
authored
Sep 22, 2020
by
Benjamin Lefaudeux
Committed by
GitHub
Sep 22, 2020
Browse files
[chore] OSS doc (#101)
* Doc extensions to some APIs * Fix the benchmark and tutorial
parent
63f7796a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
80 additions
and
38 deletions
+80
-38
benchmarks/oss.py
benchmarks/oss.py
+13
-9
docs/source/tutorials/oss.rst
docs/source/tutorials/oss.rst
+15
-10
fairscale/optim/oss.py
fairscale/optim/oss.py
+52
-19
No files found.
benchmarks/oss.py
View file @
d80c38f9
...
...
@@ -3,9 +3,8 @@
import
argparse
import
math
import
os
import
time
from
typing
import
Any
,
List
,
cast
from
typing
import
Any
,
List
,
Optional
,
cast
import
torch
import
torch.distributed
as
dist
...
...
@@ -24,9 +23,9 @@ OPTIM = torch.optim.RMSprop
def
dist_init
(
rank
,
world_size
):
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
dist
.
init_process_group
(
backend
=
BACKEND
,
rank
=
rank
,
world_size
=
world_size
)
dist
.
init_process_group
(
backend
=
BACKEND
,
init_method
=
"tcp://localhost:29501"
,
rank
=
rank
,
world_size
=
world_size
,
store
=
None
)
def
get_problem
(
rank
,
data_size
,
batch_size
):
...
...
@@ -81,9 +80,11 @@ def train_oss_ddp(
model
.
zero_grad
()
outputs
=
model
(
batch
[
"inputs"
])
loss
=
loss_fn
(
outputs
,
batch
[
"label"
])
dist
.
all_reduce
(
loss
,
op
=
dist
.
ReduceOp
.
SUM
)
loss
/=
world_size
loss
.
backward
()
dist
.
all_reduce
(
loss
,
op
=
dist
.
ReduceOp
.
SUM
)
if
dist
.
get_rank
()
==
0
:
print
(
f
"Loss:
{
loss
.
item
()
}
"
)
...
...
@@ -146,6 +147,7 @@ def train(
model
.
train
()
measurements
=
[]
final_loss
:
Optional
[
float
]
=
-
1.0
for
epoch
in
range
(
num_epochs
):
epoch_start
=
time
.
monotonic
()
...
...
@@ -156,12 +158,14 @@ def train(
model
.
zero_grad
()
outputs
=
model
(
batch
[
"inputs"
])
loss
=
loss_fn
(
outputs
,
batch
[
"label"
])
dist
.
all_reduce
(
loss
,
op
=
dist
.
ReduceOp
.
SUM
)
loss
/=
world_size
loss
.
backward
()
dist
.
all_reduce
(
loss
,
op
=
dist
.
ReduceOp
.
SUM
)
return
loss
optimizer
.
step
(
closure
)
final_loss
=
optimizer
.
step
(
closure
)
epoch_end
=
time
.
monotonic
()
...
...
@@ -176,7 +180,7 @@ def train(
measurements
.
append
(
data_size
/
(
epoch_end
-
epoch_start
))
if
dist
.
get_rank
()
==
0
:
print
(
f
"Epoch
{
epoch
}
- processed
{
measurements
[
-
1
]:.
2
f
}
img per sec"
)
print
(
f
"Epoch
{
epoch
}
- processed
{
measurements
[
-
1
]:.
2
f
}
img per sec
. Loss
{
final_loss
}
"
)
torch
.
cuda
.
synchronize
(
rank
)
training_stop
=
time
.
monotonic
()
...
...
docs/source/tutorials/oss.rst
View file @
d80c38f9
...
...
@@ -2,9 +2,9 @@ Optimizer state sharding
========================
Using
torch
.
nn
.
parallel
.
DistributedDataParallel
leads
to
some
wasted
communications
,
but
it
is
possible
and
makes
OSS
a
drop
in
solution
in
your
existing
torch
distributed
code
.
Let
's suppose that your trainer looks like
make html
Let
's suppose that your trainer looks like
.. code-block::
default
.. code-block::
python
import torch
...
...
@@ -23,7 +23,9 @@ Let's suppose that your trainer looks likemake html
loss = myVeryRelevantLoss()
base_optimizer_arguments = {} # any optimizer specific arguments, LR, momentum, etc...
optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)
optimizer = torch.optim.SGD(
params=model.parameters(),
**base_optimizer_arguments)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
...
...
@@ -33,18 +35,17 @@ Let's suppose that your trainer looks likemake html
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
loss /= world_size
loss.backward()
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
optimizer.step()
Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows
.. code-block::
default
.. code-block::
python
:emphasize-lines: 49, 65, 66
import torch
from fairscale.optim.oss import OSS
...
...
@@ -61,9 +62,14 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
dataloader = mySuperFastDataloader()
loss = myVeryRelevantLoss()
base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS
base_optimizer_arguments = {} # any optimizer specific arguments, LR, momentum, etc...
# ** NEW ** Wrap a base optimizer into OSS
base_optimizer = torch.optim.SGD # any pytorch compliant optimizer
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
optimizer = OSS(
params=model.parameters(),
optim=base_optimizer,
**base_optimizer_arguments)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
...
...
@@ -73,8 +79,7 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
loss /= world_size
loss.backward()
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
optimizer.step()
fairscale/optim/oss.py
View file @
d80c38f9
...
...
@@ -25,8 +25,7 @@ else:
class
OSS
(
Optimizer
):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_.
::
opt = OSS(params, optim=torch.optim.Adam, lr=0.01)
:: opt = OSS(params, optim=torch.optim.Adam, lr=0.01)
.. _ZeRO: https://arxiv.org/abs/1910.02054
...
...
@@ -142,6 +141,14 @@ class OSS(Optimizer):
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
# For example, the apex library contains fused optimizers with a step that supports extra kwargs.
def
step
(
self
,
closure
:
Optional
[
Callable
[[],
float
]]
=
None
,
**
kwargs
:
Any
)
->
Optional
[
float
]:
"""Performs a single optimization step (parameter update).
Arguments:
closure (callable): A closure that reevaluates the model and
returns the loss. Optional for most optimizers.
.. note:: Any extra parameter is passed to the base optimizer as-is"""
# Sync oss param_groups attributes in case they've been updated by a scheduler.
self
.
_sync_param_groups
()
...
...
@@ -162,13 +169,22 @@ class OSS(Optimizer):
return
loss
def
local_state_dict
(
self
)
->
dict
:
""" Gets this rank's state_dict. """
"""Gets this rank's state_dict.
Returns:
The state of the optimizer as a :class:`dict`.
It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a list containing all parameter groups, where each parameter group is a dict
"""
return
self
.
optim
.
state_dict
()
def
consolidate_state_dict
(
self
,
recipient_rank
:
int
=
0
)
->
None
:
"""
Update the consolidated state_dict list, one per rank.
"""Update the consolidated state_dict list, one per rank.
This needs to be called on all replicas
"""
.. warning::
This needs to be called on all replicas"""
# Sync lr and other attributes in case they have been updated
self
.
_sync_param_groups
()
...
...
@@ -183,13 +199,14 @@ class OSS(Optimizer):
self
.
_broadcast_state_dict
()
def
state_dict
(
self
)
->
Dict
[
str
,
Any
]:
"""
Return the last known global optimizer state, which consists of a list of the shards.
"""Return the last known global optimizer state, which consists of a list of the shards.
.. warning::
If the state has not been consolidated, this returns a shard's worth, not the global state.
NOTE:
- If the state has not been consolidated, this returns a shard's worth, not the global state.
- Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
.. warning::
Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
"""
if
len
(
self
.
_all_states
)
==
0
:
...
...
@@ -218,7 +235,10 @@ class OSS(Optimizer):
}
def
load_local_state_dict
(
self
,
state_dict
:
dict
)
->
None
:
""" Loads this rank's state_dict. """
"""Loads this rank's state_dict.
.. warning:: This is not meant to load the global state dict.
"""
self
.
optim
.
load_state_dict
(
state_dict
)
...
...
@@ -242,7 +262,12 @@ class OSS(Optimizer):
global_group
[
k
]
=
v
def
load_state_dict
(
self
,
state_dict
:
Dict
[
str
,
Any
])
->
None
:
""" Restore the global parameter groups as well as the shard """
"""Restore the global parameter groups as well as the shard.
Arguments:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`
"""
# Check whether we got a local or global dict
if
state_dict
[
"local_state_dict"
]:
...
...
@@ -256,6 +281,18 @@ class OSS(Optimizer):
self
.
load_local_state_dict
({
"state"
:
state_dict
[
"state"
][
self
.
rank
],
"param_groups"
:
param_groups
})
def
add_param_group
(
self
,
param_group
:
dict
)
->
None
:
"""Add a param group to the :class:`Optimizer` s `param_groups`.
This can be useful when fine tuning a pre-trained network as frozen layers can be made
trainable and added to the :class:`Optimizer` as training progresses.
Arguments:
param_group (dict): Specifies what Tensors should be optimized along with group
specific optimization options
.. warning:: This handles updating the shards on all partitions, but needs to be called on all ranks.
"""
super
().
add_param_group
(
param_group
)
if
not
self
.
in_super_constructor
:
self
.
_partition_parameters
.
clear
()
# Force a re-partitioning
...
...
@@ -273,9 +310,7 @@ class OSS(Optimizer):
local_group
[
k
]
=
global_group
[
k
]
def
_collect_sharded_states
(
self
)
->
List
[
Dict
[
str
,
Any
]]:
"""
Collect all the state shards, in CPU memory.
"""
"""Collect all the state shards, in CPU memory."""
empty_buffer
=
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
)
all_states
:
List
[
Dict
[
str
,
Any
]]
=
[]
...
...
@@ -304,9 +339,7 @@ class OSS(Optimizer):
return
all_states
def
_broadcast_state_dict
(
self
)
->
None
:
"""
Broadcast this rank's state shard, discard others
"""
"""Broadcast this rank's state shard, discard others"""
empty_buffer
=
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
)
for
rank
in
range
(
dist
.
get_world_size
(
group
=
self
.
group
)):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment