fairscale, commit 3399e97c (unverified)
Authored Jan 08, 2021 by Benjamin Lefaudeux; committed via GitHub on Jan 08, 2021.

[refactor][OSS] Removing ad-hoc object broadcast, use pytorch's (#297)
Parent: 9faad392

Showing 4 changed files with 54 additions and 82 deletions (+54 -82):

  fairscale/optim/oss.py                 +50 -48
  fairscale/optim/utils.py                +0 -30
  stubs/torch/distributed/__init__.pyi    +1  -0
  tests/optim/test_oss.py                 +3  -4
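What the change does: the removed fairscale helper broadcast_object serialized a Python object with torch.save into a byte tensor and issued two dist.broadcast calls (length, then payload), while torch.distributed.broadcast_object_list performs the pickling and the two-phase broadcast internally. A minimal sketch of the call-site difference, not taken from the diff; it assumes an initialized process group and placeholder names state, src, group, and device:

    # Before: the removed fairscale helper returned the broadcast object.
    state = broadcast_object(state, src_rank=src, group=group, dist_device=device)

    # After: the PyTorch built-in mutates the single-element list in place;
    # on non-source ranks holder[0] is overwritten with the received object.
    holder = [state]
    torch.distributed.broadcast_object_list(holder, src=src, group=group)
    state = holder[0]

That in-place contract is why the test below reads optim_state[0] after the collective, why fairscale/optim/utils.py no longer needs its io and torch.distributed imports, and why the type stub gains a broadcast_object_list entry.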
fairscale/optim/oss.py

@@ -16,7 +16,7 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer

-from .utils import Bucket, Workhandle, broadcast_object, recursive_copy_to_device
+from .utils import Bucket, Workhandle, recursive_copy_to_device

 __all__ = ["OSS"]
@@ -320,6 +320,55 @@ class OSS(Optimizer):
             # Acknowledge broadcasts, and send this rank's shard when needed
             self._broadcast_state_dict()

+    def _broadcast_state_dict(self) -> None:
+        """Broadcast this rank's state shard, discard others"""
+        # Default to CPU space to gain some memory headroom
+        local_cpu_state = recursive_copy_to_device(
+            self.local_state_dict(), non_blocking=True, device=torch.device("cpu")
+        )
+
+        for rank in range(self.world_size):
+            if rank == self.rank:
+                # Send the state to the reference replica
+                logging.debug(
+                    "Sending the sharded optimizer state to the reference replica from rank %s", rank,
+                )
+                dist.broadcast_object_list([local_cpu_state], src=self.global_rank, group=self.group)
+            else:
+                global_rank = self.get_global_rank(self.group, rank)
+                # Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
+                dist.broadcast_object_list([0], src=global_rank, group=self.group)
+
+    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
+        """Collect all the state shards, in CPU memory."""
+        all_states = []
+        for rank in range(self.world_size):
+            if rank == self.rank:
+                logging.debug("Saving self state")
+                all_states.append(
+                    recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
+                )
+
+                # Sync with other replicas
+                dist.broadcast_object_list([0], src=self.global_rank, group=self.group)
+            else:
+                # Fetch the optim state from the other replicas
+                global_rank = self.get_global_rank(self.group, rank)
+                replica_state = [0]
+                dist.broadcast_object_list(replica_state, src=global_rank, group=self.group)
+
+                all_states.append(
+                    recursive_copy_to_device(replica_state[0], non_blocking=True, device=torch.device("cpu"))
+                )
+
+                logging.debug("State from rank %s received", rank)
+
+        return all_states
+
     def state_dict(self) -> Dict[str, Any]:
         """Return the last known global optimizer state, which consist of a list of the shards.
@@ -466,53 +515,6 @@ class OSS(Optimizer):
             elif k in global_group.keys():
                 local_group[k] = global_group[k]

-    def _collect_sharded_states(self) -> List[Dict[str, Any]]:
-        """Collect all the state shards, in CPU memory."""
-        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
-        all_states: List[Dict[str, Any]] = []
-
-        for rank in range(self.world_size):
-            if rank == self.rank:
-                logging.debug("Saving self state")
-                all_states.append(
-                    recursive_copy_to_device(self.local_state_dict(), non_blocking=True, device=torch.device("cpu"))
-                )
-
-                # Sync with other replicas
-                broadcast_object(empty_buffer, src_rank=self.global_rank, group=self.group, dist_device=self._device)
-            else:
-                # Fetch the optim state from the other replicas
-                global_rank = self.get_global_rank(self.group, rank)
-                replica_state = broadcast_object(
-                    empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device
-                )
-
-                all_states.append(
-                    recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
-                )
-
-                logging.debug("State from rank %s received", rank)
-
-        return all_states
-
-    def _broadcast_state_dict(self) -> None:
-        """Broadcast this rank's state shard, discard others"""
-        empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
-
-        for rank in range(self.world_size):
-            if rank == self.rank:
-                # Send the state to the reference replica
-                logging.debug(
-                    "Sending the sharded optimizer state to the reference replica from rank %s", rank,
-                )
-                broadcast_object(
-                    self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device
-                )
-            else:
-                global_rank = self.get_global_rank(self.group, rank)
-                # Discard this tensor/rank, broadcast necessary for syncing
-                broadcast_object(empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device)
-
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
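Note on the loop structure the rewrite keeps: _collect_sharded_states and _broadcast_state_dict still walk every rank and join exactly one collective per iteration, so the broadcasts stay matched across the group even when a rank only contributes a dummy [0]; this is also why the comment points out that NCCL does not support gather. Below is a hedged, self-contained sketch of that round-robin pattern outside fairscale, in a simplified symmetric variant where every rank keeps every shard; collect_all and run_worker are illustrative names, not fairscale APIs, and the gloo backend and the TCP address are assumptions for a local two-process run:

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def collect_all(local_shard, rank, world_size):
        """Each rank takes one turn as broadcast source; every rank must join every collective."""
        all_shards = []
        for src in range(world_size):
            holder = [local_shard if src == rank else None]  # dummy slot on receiving ranks
            dist.broadcast_object_list(holder, src=src)      # mutates holder in place on receivers
            all_shards.append(holder[0])
        return all_shards

    def run_worker(rank, world_size):
        dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size)
        shards = collect_all({"rank": rank, "value": torch.full((2,), float(rank))}, rank, world_size)
        assert len(shards) == world_size and shards[rank]["rank"] == rank
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(run_worker, args=(2,), nprocs=2)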
fairscale/optim/utils.py

@@ -3,12 +3,10 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-import io
 from typing import Any, Callable, Dict, List, Optional

 import torch
 from torch._six import container_abcs
-import torch.distributed as dist


 class Workhandle:
@@ -128,31 +126,3 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
         return device_val

     return value
-
-
-def broadcast_object(
-    obj: Any, src_rank: int, group: object = dist.group.WORLD, dist_device: torch.device = torch.device("cpu")
-) -> Any:
-    """
-    Either broadcast from master to the fleet (default),
-    or use the src setting as the original rank.
-    """
-    if dist.get_rank() == src_rank:
-        # Emit data
-        buffer = io.BytesIO()
-        torch.save(obj, buffer)
-        data = bytearray(buffer.getbuffer())
-        length_tensor = torch.LongTensor([len(data)]).to(dist_device)
-        data_send_tensor = torch.ByteTensor(data).to(dist_device)
-        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
-        dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
-    else:
-        # Fetch from the source
-        length_tensor = torch.LongTensor([0]).to(dist_device)
-        dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
-        data_recv_tensor = torch.empty([int(length_tensor.item())], dtype=torch.uint8, device=dist_device)
-        dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
-        buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
-        obj = torch.load(buffer, map_location=dist_device)
-    return obj
stubs/torch/distributed/__init__.pyi

@@ -32,6 +32,7 @@ def get_backend(group: Optional[Any] = None) -> Any: ...
 def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
 def gather(tensor: Tensor, gather_list: Optional[List[Tensor]], dst: Any, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
 def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
+def broadcast_object_list(object_list: List[Any], src: int, group:Optional[ProcessGroup] = None): ...
 def is_initialized() -> bool: ...
tests/optim/test_oss.py

@@ -392,12 +392,11 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
     else:
         optimizer_state_dict = {}

-    optimizer_state_dict = optim.utils.broadcast_object(
-        optimizer_state_dict, src_rank=reference_rank, group=dist.group.WORLD, dist_device=device
-    )
+    optim_state = [optimizer_state_dict]
+    dist.broadcast_object_list(optim_state, src=reference_rank, group=dist.group.WORLD)

     # Load the optimizer state dict
-    optimizer.load_state_dict(optimizer_state_dict)
+    optimizer.load_state_dict(optim_state[0])

     dist.destroy_process_group()
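For context, a hedged sketch of the updated test flow from the caller's side; it assumes an initialized process group, an OSS-wrapped optimizer whose shards have already been consolidated onto reference_rank earlier in the test, and the if/else framing is illustrative rather than copied from the file:

    if rank == reference_rank:
        optimizer_state_dict = optimizer.state_dict()  # consolidated state lives on the reference rank
    else:
        optimizer_state_dict = {}                      # placeholder payload on the other ranks

    optim_state = [optimizer_state_dict]
    dist.broadcast_object_list(optim_state, src=reference_rank, group=dist.group.WORLD)

    # Every rank now holds the reference replica's state and can restore from it.
    optimizer.load_state_dict(optim_state[0])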