OpenDAS / fairscale · Commit d80c38f9 (unverified)

Authored Sep 22, 2020 by Benjamin Lefaudeux; committed by GitHub on Sep 22, 2020
Parent: 63f7796a

[chore] OSS doc (#101)

* Doc extensions to some APIs
* Fix the benchmark and tutorial
Changes: 3 files changed, 80 additions and 38 deletions (+80 −38)

  benchmarks/oss.py              +13  −9
  docs/source/tutorials/oss.rst  +15 −10
  fairscale/optim/oss.py         +52 −19
benchmarks/oss.py

@@ -3,9 +3,8 @@
 import argparse
 import math
-import os
 import time
 
-from typing import Any, List, cast
+from typing import Any, List, Optional, cast
 
 import torch
 import torch.distributed as dist
@@ -24,9 +23,9 @@ OPTIM = torch.optim.RMSprop
 def dist_init(rank, world_size):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
-    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
+    dist.init_process_group(
+        backend=BACKEND, init_method="tcp://localhost:29501", rank=rank, world_size=world_size, store=None
+    )
 
 
 def get_problem(rank, data_size, batch_size):
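The hunk above moves rendezvous from MASTER_ADDR/MASTER_PORT environment variables to an explicit tcp:// init_method. For context, a minimal sketch of how such a dist_init helper is typically driven from a launcher follows; the worker body, backend choice and world size are illustrative assumptions, not part of this commit.

# Hypothetical launcher sketch: spawn one process per worker and let each one
# call a dist_init(rank, world_size) helper like the benchmark's.
import torch.distributed as dist
import torch.multiprocessing as mp

BACKEND = "gloo"  # assumption: "nccl" would be used when every rank owns a GPU

def dist_init(rank, world_size):
    # Mirrors the updated benchmark: a single explicit TCP rendezvous point.
    dist.init_process_group(
        backend=BACKEND, init_method="tcp://localhost:29501", rank=rank, world_size=world_size
    )

def worker(rank, world_size):
    dist_init(rank, world_size)
    # ... build the model and optimizer, then run the training loop here ...
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)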
@@ -81,9 +80,11 @@ def train_oss_ddp(
            model.zero_grad()
            outputs = model(batch["inputs"])
            loss = loss_fn(outputs, batch["label"])
+            dist.all_reduce(loss, op=dist.ReduceOp.SUM)
            loss /= world_size
            loss.backward()
-            dist.all_reduce(loss, op=dist.ReduceOp.SUM)
 
            if dist.get_rank() == 0:
                print(f"Loss: {loss.item()}")
@@ -146,6 +147,7 @@ def train(
    model.train()
    measurements = []
+    final_loss: Optional[float] = -1.0
    for epoch in range(num_epochs):
        epoch_start = time.monotonic()
@@ -156,12 +158,14 @@ def train(
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
+                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
-                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                return loss
 
-            optimizer.step(closure)
+            final_loss = optimizer.step(closure)
 
        epoch_end = time.monotonic()
@@ -176,7 +180,7 @@ def train(
        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
-            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
+            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss}")
 
    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
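The benchmark hunks above all-reduce the loss before scaling it by world_size and keep the value returned by optimizer.step(closure). A minimal single-process sketch of that closure pattern follows; the toy model, data and optimizer are placeholders standing in for the benchmark's, and the distributed all-reduce is only noted in a comment.

# Minimal single-process sketch of the closure pattern used by the benchmark.
# The tiny model, data and plain SGD optimizer here are placeholders.
import torch

model = torch.nn.Linear(8, 1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # OSS would wrap this in the real benchmark
inputs, labels = torch.randn(16, 8), torch.randn(16, 1)

final_loss = None
for epoch in range(3):
    def closure():
        model.zero_grad()
        loss = loss_fn(model(inputs), labels)
        # In the distributed benchmark the loss is all-reduced and divided by
        # world_size here, before backward(); single process, so nothing to do.
        loss.backward()
        return loss

    # step(closure) returns the closure's loss, which the benchmark now records.
    final_loss = optimizer.step(closure)
    print(f"Epoch {epoch} - loss {final_loss.item():.4f}")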
docs/source/tutorials/oss.rst

@@ -2,9 +2,9 @@ Optimizer state sharding
 ========================
 
 Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications, but it is possible and makes OSS a drop in solution in your existing torch distributed code.
 
-Let's suppose that your trainer looks likemake html
+Let's suppose that your trainer looks like
 
-.. code-block:: default
+.. code-block:: python
 
     import torch
@@ -23,7 +23,9 @@ Let's suppose that your trainer looks likemake html
    loss = myVeryRelevantLoss()
 
    base_optimizer_arguments = {} # any optimizer specific arguments, LR, momentum, etc...
-    optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)
+    optimizer = torch.optim.SGD(
+        params=model.parameters(),
+        **base_optimizer_arguments)
 
    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
@@ -33,18 +35,17 @@ Let's suppose that your trainer looks likemake html
        model.zero_grad()
        outputs = model(batch["inputs"])
        loss = loss_fn(outputs, batch["label"])
+        torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
        loss /= world_size
        loss.backward()
-        torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
        optimizer.step()
 
 Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows
 
-.. code-block:: default
-    :emphasize-lines: 49, 65, 66
+.. code-block:: python
 
    import torch
    from fairscale.optim.oss import OSS
@@ -61,9 +62,14 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
    dataloader = mySuperFastDataloader()
    loss = myVeryRelevantLoss()
 
-    base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS
+    base_optimizer_arguments = {} # any optimizer specific arguments, LR, momentum, etc...
+
+    # ** NEW ** Wrap a base optimizer into OSS
    base_optimizer = torch.optim.SGD # any pytorch compliant optimizer
-    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
+    optimizer = OSS(
+        params=model.parameters(),
+        optim=base_optimizer,
+        **base_optimizer_arguments)
 
    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
@@ -73,8 +79,7 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
        model.zero_grad()
        outputs = model(batch["inputs"])
        loss = loss_fn(outputs, batch["label"])
+        torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
        loss /= world_size
        loss.backward()
-        torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
        optimizer.step()
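Pulling the tutorial fragments together, here is a hedged, self-contained sketch of the sharded setup; the toy model, dataset, learning rate and two-process launch are illustrative assumptions rather than tutorial content.

# Sketch assembling the tutorial steps: init the process group, wrap a base
# optimizer in fairscale.optim.OSS, and run a toy training loop.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.optim.oss import OSS

def train(rank, world_size):
    dist.init_process_group(
        backend="gloo", init_method="tcp://localhost:29501", rank=rank, world_size=world_size
    )
    model = torch.nn.Linear(16, 4)
    loss_fn = torch.nn.MSELoss()

    base_optimizer = torch.optim.SGD          # any pytorch compliant optimizer
    base_optimizer_arguments = {"lr": 1e-2}   # any optimizer specific arguments
    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)

    model.train()
    for _ in range(5):
        inputs, labels = torch.randn(8, 16), torch.randn(8, 4)
        model.zero_grad()
        loss = loss_fn(model(inputs), labels)
        # Mirror the tutorial loop: average the reported loss across ranks.
        dist.all_reduce(loss, op=dist.ReduceOp.SUM)
        loss /= world_size
        loss.backward()
        optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)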
fairscale/optim/oss.py

@@ -25,8 +25,7 @@ else:
 class OSS(Optimizer):
     """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
     optimizer and shards its state as described by ZeRO_.
     ::
-
         opt = OSS(params, optim=torch.optim.Adam, lr=0.01)
 
     .. _ZeRO: https://arxiv.org/abs/1910.02054
@@ -142,6 +141,14 @@ class OSS(Optimizer):
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
+        """Performs a single optimization step (parameter update).
+
+        Arguments:
+            closure (callable): A closure that reevaluates the model and
+                returns the loss. Optional for most optimizers.
+
+        .. note: Any extra parameter is passed to the base optimizer as-is"""
+
         # Sync oss param_groups attributes in case they've been updated by a scheduler.
         self._sync_param_groups()
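The synchronization comment above exists because a learning-rate scheduler may mutate the wrapper's param_groups between steps. As a hedged illustration (not taken from this commit), a stock PyTorch scheduler can be attached to the OSS wrapper like to any optimizer; the model, data and schedule below are placeholders, and an initialized process group is assumed.

# Sketch: driving an OSS-wrapped optimizer with a stock LR scheduler.
# Assumes a process group is already initialized; model/data are placeholders.
import torch
from fairscale.optim.oss import OSS

model = torch.nn.Linear(16, 4)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
# The scheduler mutates optimizer.param_groups; OSS.step() then syncs those
# attributes down to the wrapped (sharded) optimizer before updating.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for _ in range(30):
    loss = torch.nn.functional.mse_loss(model(torch.randn(8, 16)), torch.randn(8, 4))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()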
@@ -162,13 +169,22 @@ class OSS(Optimizer):
         return loss
 
     def local_state_dict(self) -> dict:
-        """ Gets this rank's state_dict. """
+        """Gets this rank's state_dict.
+
+        Returns:
+            The state of the optimizer as a :class:`dict`.
+            It contains two entries:
+
+            * state - a dict holding current optimization state. Its content
+                differs between optimizer classes.
+            * param_groups - a dict containing all parameter groups
+        """
         return self.optim.state_dict()
 
     def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
-        """
-        Update the consolidated state_dict list, one per rank.
-        This needs to be called on all replicas
-        """
+        """Update the consolidated state_dict list, one per rank.
+
+        .. warning:
+            This needs to be called on all replicas"""
         # Sync lr and other attributes in case its been updated
         self._sync_param_groups()
@@ -183,13 +199,14 @@ class OSS(Optimizer):
             self._broadcast_state_dict()
 
     def state_dict(self) -> Dict[str, Any]:
-        """
-        Return the last known global optimizer state, which consist of a list of the shards.
-        NOTE:
-        - If the state has not been consolidated, this returns a shard's worth, not the global state.
-        - Returning the global state is limited to the replica which was responsible for the consolidation.
-        The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
-        """
+        """Return the last known global optimizer state, which consist of a list of the shards.
+
+        .. warning:
+            If the state has not been consolidated, this returns a shard's worth, not the global state.
+
+        .. warning:
+            Returning the global state is limited to the replica which was responsible for the consolidation.
+            The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
+        """
         if len(self._all_states) == 0:
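Taken together, consolidate_state_dict and state_dict suggest a simple checkpointing pattern: every rank contributes its shard, and only the recipient rank serializes the result. A hedged sketch, assuming an existing OSS instance named optimizer and a placeholder file name:

# Sketch: saving a sharded-optimizer checkpoint with the consolidation API.
# Assumes an initialized process group and an OSS instance named `optimizer`.
import torch
import torch.distributed as dist

RECIPIENT_RANK = 0

# Every rank must participate: each one sends its shard to the recipient.
optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)

# Only the recipient holds the full, consolidated state at this point.
if dist.get_rank() == RECIPIENT_RANK:
    torch.save(optimizer.state_dict(), "oss_checkpoint.pt")  # placeholder path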
@@ -218,7 +235,10 @@ class OSS(Optimizer):
         }
 
     def load_local_state_dict(self, state_dict: dict) -> None:
-        """ Loads this rank's state_dict. """
+        """Loads this rank's state_dict.
+
+        .. warning: This is not meant to load the global state dict.
+        """
         self.optim.load_state_dict(state_dict)
@@ -242,7 +262,12 @@ class OSS(Optimizer):
                 global_group[k] = v
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        """ Restore the global parameter groups as well as the shard """
+        """Restore the global parameter groups as well as the shard.
+
+        Arguments:
+            state_dict (dict): optimizer state. Should be an object returned
+                from a call to :meth:`state_dict`
+        """
         # Check whether we got a local or global dict
         if state_dict["local_state_dict"]:
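To restore, one simple pattern (an assumption, not prescribed by this commit) is to let every rank read the same consolidated checkpoint and hand it to load_state_dict, which extracts that rank's shard:

# Sketch: restoring from the checkpoint saved above; `optimizer` is assumed to
# be an OSS instance and the path is a placeholder.
import torch

state = torch.load("oss_checkpoint.pt", map_location="cpu")
optimizer.load_state_dict(state)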
@@ -256,6 +281,18 @@ class OSS(Optimizer):
             self.load_local_state_dict({"state": state_dict["state"][self.rank], "param_groups": param_groups})
 
     def add_param_group(self, param_group: dict) -> None:
+        """Add a param group to the :class:`Optimizer` s `param_groups`.
+
+        This can be useful when fine tuning a pre-trained network as frozen layers can be made
+        trainable and added to the :class:`Optimizer` as training progresses.
+
+        Arguments:
+            param_group (dict): Specifies what Tensors should be optimized along with group
+            specific optimization options
+
+        .. warning: This handles updating the shards on all partitions, but needs to be called on all ranks.
+        """
+
         super().add_param_group(param_group)
         if not self.in_super_constructor:
             self._partition_parameters.clear()  # Force a re-partitioning
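The warning in the new add_param_group docstring points at the typical use case: unfreezing layers mid-training. A hedged sketch follows; the model, layer names and learning rates are placeholders, and the call must run on every rank so the shards can be re-partitioned consistently.

# Sketch: progressively unfreezing a layer with a sharded optimizer.
# Assumes `model` has a frozen `backbone` and a trainable `head`; names are
# placeholders, and the same calls must be made on every rank.
for p in model.backbone.parameters():
    p.requires_grad = False

optimizer = OSS(params=model.head.parameters(), optim=torch.optim.SGD, lr=0.1)

# Later: make the backbone trainable and register it as a new param group.
# OSS re-partitions its shards, so every rank must execute this.
for p in model.backbone.parameters():
    p.requires_grad = True
optimizer.add_param_group({"params": list(model.backbone.parameters()), "lr": 0.01})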
@@ -273,9 +310,7 @@ class OSS(Optimizer):
                 local_group[k] = global_group[k]
 
     def _collect_sharded_states(self) -> List[Dict[str, Any]]:
-        """
-        Collect all the state shards, in CPU memory.
-        """
+        """Collect all the state shards, in CPU memory."""
         empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
         all_states: List[Dict[str, Any]] = []
@@ -304,9 +339,7 @@ class OSS(Optimizer):
         return all_states
 
     def _broadcast_state_dict(self) -> None:
-        """
-        Broadcast this rank's state shard, discard others
-        """
+        """Broadcast this rank's state shard, discard others"""
         empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
 
         for rank in range(dist.get_world_size(group=self.group)):