OpenDAS / fairscale / Commits / 8e363567

[feat] Implement OSS save and load of the sharded state from a single replica (#16)

Authored Jul 31, 2020 by Benjamin Lefaudeux
Committed by Mandeep Singh Baines, Jul 31, 2020
Parent commit: bfba68d8

Showing 3 changed files, with 261 additions and 18 deletions:

    fairscale/optim/oss.py      +105   -9
    fairscale/optim/utils.py     +70   -0
    tests/optim/test_oss.py      +86   -9
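In short: OSS previously exposed only this rank's optimizer shard through state_dict()/load_state_dict(). This commit renames that per-shard path to local_state_dict()/load_local_state_dict() and adds a consolidated path: consolidate_state_dict(recipient_rank) gathers every rank's shard onto one replica, state_dict() then returns the shards as an ordered list under a "states" key, and load_state_dict() restores the calling rank's own shard from that structure. The exchange is built on two new helpers in fairscale/optim/utils.py, broadcast_object() and recursive_copy_to_device(), and is exercised by a new test_collect_shards test. A minimal sketch of the save side, assuming an initialized process group; model and CHECKPOINT_PATH are placeholders for illustration, not part of the commit:

    import torch
    import torch.distributed as dist
    import fairscale.optim as optim

    RECIPIENT_RANK = 0
    CHECKPOINT_PATH = "oss_checkpoint.pt"  # hypothetical path

    optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99)  # `model` is a placeholder
    # ... training steps ...

    # Collective exchange: every replica must make this call with the same recipient_rank.
    optimizer.consolidate_state_dict(recipient_rank=RECIPIENT_RANK)

    if dist.get_rank() == RECIPIENT_RANK:
        # Only the recipient holds the consolidated state: {"states": [shard_0, ..., shard_{n-1}]},
        # with each shard already copied to CPU memory by _collect_sharded_states().
        torch.save(optimizer.state_dict(), CHECKPOINT_PATH)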
fairscale/optim/oss.py
@@ -4,11 +4,15 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Type
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type
 
 import torch
 import torch.distributed as dist
 from torch.optim import SGD, Optimizer
 
+from .utils import broadcast_object, recursive_copy_to_device
+
 if TYPE_CHECKING:
     from torch.optim.optimizer import _params_t
 else:
@@ -17,7 +21,7 @@ else:
 class OSS(Optimizer):
     """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
-    optimizer and shards its state as describe by ZeRO_.
+    optimizer and shards its state as described by ZeRO_.
 
     ::
 
         opt = OSS(params, optim=torch.optim.Adam, lr=0.01)
@@ -54,6 +58,12 @@ class OSS(Optimizer):
         param_groups = self.partition_parameters()
         self.optim = optim(param_groups[self.rank], **defaults)
 
+        # Optional consolidated optimizer state
+        self._all_states: List[Dict[str, Any]] = []
+
+        # Current device is set by the parameters allocated to this rank
+        self._device = self.partition_parameters()[self.rank][0]["params"][0].device
+
     def partition_parameters(self) -> List[List[dict]]:
         """Partitions parameters across distributed ranks.
@@ -73,10 +83,10 @@ class OSS(Optimizer):
                 param_lists[rank].append(param)
                 sizes[rank] += param.numel()
         for rank, params in enumerate(param_lists):
-            if len(params):
-                pg = copy.copy(param_group)
-                pg["params"] = params
-                param_groups[rank].append(pg)
+            if len(params) > 0:
+                param_group_rank = copy.copy(param_group)
+                param_group_rank["params"] = params
+                param_groups[rank].append(param_group_rank)
         return param_groups
 
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
@@ -87,13 +97,50 @@ class OSS(Optimizer):
                 dist.broadcast(param, rank, group=self.group)
         return loss
 
-    def state_dict(self) -> dict:
+    def local_state_dict(self) -> dict:
         """ Gets this rank's state_dict. """
         return self.optim.state_dict()
 
-    def load_state_dict(self, state_dict: dict) -> None:
+    def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
+        """ Update the consolidated state_dict list, one per rank.
+        This needs to be called on all replicas """
+        if self.rank == recipient_rank:
+            # Pull the sharded state from all the other replicas
+            # Store all the states in order, rank by rank
+            logging.debug("Pulling the sharded SGD state from all replicas")
+            self._all_states = self._collect_sharded_states()
+        else:
+            # Acknowledge broadcasts, and send this rank's shard when needed
+            self._broadcast_state_dict()
+
+    def state_dict(self) -> Dict[str, Any]:
+        """
+        Return the last known global optimizer state, which consist of a list of the shards.
+        NOTE: This is limited to the replica which was responsible for the consolidation.
+        The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
+        """
+        assert (
+            len(self._all_states) > 0
+        ), "The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"
+
+        return {"states": self._all_states}
+
+    def load_local_state_dict(self, state_dict: dict) -> None:
         """ Loads this rank's state_dict. """
-        self.optim.load_state_dict(state_dict)
+        # Make sure that the state is on the appropriate device
+        state_dict_ondevice = recursive_copy_to_device(state_dict, non_blocking=False, device=self._device)
+        self.optim.load_state_dict(state_dict_ondevice)
+
+    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+        """ Loads this rank's optimizer state_dict, given the global optimizer state. """
+        # Dispatch this rank's state dictionary to the local load
+        self.load_local_state_dict(state_dict["states"][self.rank])
 
     def add_param_group(self, param_group: dict) -> None:
         super().add_param_group(param_group)
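The loading side is symmetric: load_state_dict() expects the consolidated structure returned by state_dict() and simply dispatches states[self.rank] to load_local_state_dict(), which copies the shard back onto this rank's device before handing it to the wrapped optimizer. A restore sketch, continuing the save-side example after the file list, under the assumption that every rank can read the same checkpoint file and that optimizer is an OSS instance re-created with the same world size and partitioning (CHECKPOINT_PATH is illustrative):

    import torch

    CHECKPOINT_PATH = "oss_checkpoint.pt"  # hypothetical path, the file written by the recipient rank

    # Sketch only: runs on every rank.
    full_state = torch.load(CHECKPOINT_PATH, map_location="cpu")  # {"states": [shard_0, ..., shard_{n-1}]}
    optimizer.load_state_dict(full_state)  # picks states[self.rank], recursive_copy_to_device() moves it to self._device

Because the shards are stored as a list indexed by rank, this sketch only makes sense when the checkpoint is reloaded with the same world size and parameter partition that produced it.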
@@ -101,3 +148,52 @@ class OSS(Optimizer):
         param_groups = self.partition_parameters()[self.rank]
         if len(param_groups) == len(self.optim.param_groups) + 1:
             self.optim.add_param_group(param_groups[-1])
+
+    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
+        """
+        Collect all the state shards, in CPU memory.
+        """
+        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
+        all_states: List[Dict[str, Any]] = []
+
+        for rank in range(dist.get_world_size(group=self.group)):
+            if rank == self.rank:
+                logging.debug("Saving self state")
+                all_states.append(
+                    recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
+                )
+
+                # Sync with other replicas
+                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
+            else:
+                # Fetch the optim state from the other replicas
+                logging.debug("Receiving state from rank %s ", rank)
+                replica_state = broadcast_object(
+                    empty_buffer, src_rank=rank, group=self.group, dist_device=self._device
+                )
+
+                all_states.append(
+                    recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
+                )
+
+                logging.debug("State from rank %s received", rank)
+
+        return all_states
+
+    def _broadcast_state_dict(self) -> None:
+        """
+        Broadcast this rank's state shard, discard others
+        """
+        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
+
+        for rank in range(dist.get_world_size(group=self.group)):
+            if rank == self.rank:
+                # Send the state to the reference replica
+                logging.debug(
+                    "Sending the sharded SGD state to the reference replica from rank %s", rank,
+                )
+                broadcast_object(self.local_state_dict(), src_rank=rank, group=self.group, dist_device=self._device)
+            else:
+                # Discard this tensor/rank, broadcast necessary for syncing
+                logging.debug("Discarding broadcast from rank %s", rank)
+                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
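The two private helpers above form a rank-by-rank rendezvous: every replica walks the ranks in the same order and takes part in exactly one broadcast_object() exchange per rank, which is why consolidate_state_dict() has to be called on every replica with the same recipient_rank. If no single consolidated file is needed, the per-shard path requires no communication at all: each rank can save and reload its own slice of the state independently. A minimal sketch of that variant, assuming optimizer is an OSS instance; the shard_{rank}.pt naming is an arbitrary choice for illustration:

    import torch
    import torch.distributed as dist

    # Save this rank's shard only; no collective calls are involved.
    torch.save(optimizer.local_state_dict(), f"shard_{dist.get_rank()}.pt")

    # ... later, on the same rank ...
    shard = torch.load(f"shard_{dist.get_rank()}.pt", map_location="cpu")
    optimizer.load_local_state_dict(shard)  # copied onto this rank's device before loading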
fairscale/optim/utils.py  (new file, 0 → 100644)
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import io
from typing import Any, Dict

import torch
from torch._six import container_abcs
import torch.distributed as dist


# Credits: classy_vision/generic/distributed_util.py
def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
    """
    Recursively searches lists, tuples, dicts and copies tensors to device if
    possible. Non-tensor values are passed as-is in the result.

    NOTE: These are all copies, so if there are two objects that reference
    the same object, then after this call, there will be two different objects
    referenced on the device.
    """
    if isinstance(value, torch.Tensor):
        return value.to(device, non_blocking=non_blocking)

    if isinstance(value, (list, tuple)):
        values = []
        for val in value:
            values.append(recursive_copy_to_device(val, non_blocking=non_blocking, device=device))

        return values if isinstance(value, list) else tuple(values)

    if isinstance(value, container_abcs.Mapping):
        device_val: Dict[str, Any] = {}
        for key, val in value.items():
            device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device)

        return device_val

    return value


def broadcast_object(
    obj: Any, src_rank: int, group: object = dist.group.WORLD, dist_device: torch.device = torch.device("cpu")
) -> Any:
    """
    Either broadcast from master to the fleet (default),
    or use the src setting as the original rank.
    """
    if dist.get_rank() == src_rank:
        # Emit data
        buffer = io.BytesIO()
        torch.save(obj, buffer)  # type: ignore
        data = bytearray(buffer.getbuffer())
        length_tensor = torch.LongTensor([len(data)]).to(dist_device)
        data_send_tensor = torch.ByteTensor(data).to(dist_device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
    else:
        # Fetch from the source
        length_tensor = torch.LongTensor([0]).to(dist_device)
        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
        data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
        obj = torch.load(buffer, map_location=dist_device)  # type: ignore

    return obj
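broadcast_object() serializes an arbitrary picklable object with torch.save, ships its length and byte payload through two dist.broadcast calls, and deserializes it on the receiving ranks. A usage sketch, assuming the default process group is already initialized with the Gloo backend so CPU tensors can be used for the transfer (with NCCL, dist_device would need to be a CUDA device); the metadata dict is a placeholder:

    import torch
    import torch.distributed as dist

    from fairscale.optim.utils import broadcast_object

    # Assumes dist.init_process_group(...) has already been called.
    if dist.get_rank() == 0:
        payload = {"step": 100, "best_loss": 0.123}  # hypothetical metadata to share
    else:
        payload = {}  # placeholder, overwritten by the broadcast below

    payload = broadcast_object(payload, src_rank=0, group=dist.group.WORLD, dist_device=torch.device("cpu"))
    # Every rank now holds the same dict that rank 0 sent.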
tests/optim/test_oss.py
@@ -3,6 +3,10 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pylint: disable=missing-module-docstring
+# pylint: disable=missing-class-docstring
+# pylint: disable=missing-function-docstring
+
 import os
 
 import pytest
@@ -14,17 +18,20 @@ import fairscale.optim as optim
 skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
 
+BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO  # type: ignore
+DEVICE = "cuda" if torch.cuda.is_available() else torch.device("cpu")
+
 
 def setup_module(module):
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "29500"
-    dist.init_process_group(backend="nccl", rank=0, world_size=1)
+    dist.init_process_group(backend=BACKEND, rank=0, world_size=1)
 
 
 def dist_init(rank, world_size):
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = "29501"
-    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
+    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
 
 
 def test_create():
@@ -32,17 +39,29 @@ def test_create():
     o = optim.OSS(params, lr=0.01)
 
 
-@skip_if_no_cuda
 def test_state_dict():
-    x = torch.tensor([1.0], device="cuda", requires_grad=True)
+    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
     o = optim.OSS([x], lr=0.1)
+    o.consolidate_state_dict()  # Sync state dict in between replicas - even if there are none
     state_dict = o.state_dict()
     o = optim.OSS([x], lr=0.01)
     o.load_state_dict(state_dict)
     # We should now be using a lr of 0.1.
     x.backward()
     o.step()
-    assert x == torch.tensor([0.9], device="cuda")
+    assert x == torch.tensor([0.9], device=DEVICE)
+
+
+def test_local_state_dict():
+    x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
+    o = optim.OSS([x], lr=0.1)
+    local_state_dict = o.local_state_dict()
+    o = optim.OSS([x], lr=0.01)
+    o.load_local_state_dict(local_state_dict)
+    # We should now be using a lr of 0.1.
+    x.backward()
+    o.step()
+    assert x == torch.tensor([0.9], device=DEVICE)
 
 
 def run_test_add_param_group(rank, world_size):
@@ -57,9 +76,9 @@ def run_test_add_param_group(rank, world_size):
     # Verify that added group is added to the correct partition making all have 8 elements.
     assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
     if rank == 1:
-        len(o.optim.param_groups) == 2
+        assert len(o.optim.param_groups) == 2
     else:
-        len(o.optim.param_groups) == 1
+        assert len(o.optim.param_groups) == 1
 
 
 def test_add_param_group():
@@ -81,7 +100,6 @@ def run_test_zero_grad(rank, world_size):
     assert not m.bias.grad
 
 
-@skip_if_no_cuda
 def test_zero_grad():
     world_size = 2
     mp.spawn(run_test_zero_grad, args=(world_size,), nprocs=world_size, join=True)
@@ -111,8 +129,9 @@ def test_step():
     mp.spawn(run_test_step, args=(world_size,), nprocs=world_size, join=True)
 
 
-def run_test_step_with_closure(rank, world_size):
+def run_test_step_with_closure(rank, world_size, optimizer=None):
     dist_init(rank, world_size)
     x_val = rank + 1
     weight = 1.0
     bias = 2.0
@@ -125,7 +144,9 @@ def run_test_step_with_closure(rank, world_size):
     m.weight.data = torch.tensor([[weight]])
     m.bias.data = torch.tensor([bias])
     m.to(rank)
+
     o = optim.OSS(m.parameters(), lr=0.1)
+
     y = m(x)
     y.backward(x)
     for p in m.parameters():
@@ -164,3 +185,59 @@ def run_test_sharding(rank, world_size):
 def test_sharding():
     world_size = 3
     mp.spawn(run_test_sharding, args=(world_size,), nprocs=world_size, join=True)
+
+
+def run_test_collect_shards(rank, world_size, reference_rank):
+    dist_init(rank, world_size)
+    device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
+
+    # Run a dummy step so that the optimizer state dict exists
+    batch, input_width, hidden, target_width = 3, 20, 10, 5
+    target = torch.rand((batch, target_width), device=device)
+    inputs = torch.rand((batch, input_width), device=device)
+
+    model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width))
+    model.to(device)
+
+    loss_fn = torch.nn.L1Loss()
+    loss_fn.to(device)
+
+    # With SGD, Momentum is required to get a state to shard
+    optimizer = optim.OSS(model.parameters(), lr=0.1, momentum=0.99)
+
+    def closure():
+        optimizer.zero_grad()
+        output = model(inputs)
+        loss = loss_fn(output, target)
+        loss.backward()
+        return loss
+
+    _ = optimizer.step(closure=closure)
+
+    # Update the optimizer state on the reference rank
+    optimizer.consolidate_state_dict(recipient_rank=reference_rank)
+
+    # Fetch the state on the reference rank
+    # - check that it has the correct size
+    # - load it again
+    if rank == reference_rank:
+        optimizer_state_dict = optimizer.state_dict()
+        assert len(optimizer_state_dict["states"]) == world_size
+    else:
+        optimizer_state_dict = {}
+
+    optimizer_state_dict = optim.utils.broadcast_object(
+        optimizer_state_dict, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
+    )
+
+    # Load the optimizer state dict
+    optimizer.load_state_dict(optimizer_state_dict)
+
+
+def test_collect_shards():
+    world_size = 3
+    reference_rank = 0
+    mp.spawn(
+        run_test_collect_shards, args=(world_size, reference_rank), nprocs=world_size, join=True,
+    )
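The new test spawns one process per rank, takes a dummy momentum-SGD step so that every shard actually carries state, consolidates onto reference_rank, checks that the consolidated dict holds world_size shards, then broadcasts it back with optim.utils.broadcast_object and reloads it on every rank. To run just this test locally, something like the following should work (NCCL is picked when CUDA is available, Gloo otherwise, per the BACKEND constant above):

    python -m pytest tests/optim/test_oss.py -k collect_shards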