Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
3399e97c
Unverified
Commit
3399e97c
authored
Jan 08, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Jan 08, 2021
Browse files
[refactor][OSS] Removing ad-hoc object broadcast, use pytorch's (#297)
parent
9faad392
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
54 additions
and
82 deletions
+54
-82
fairscale/optim/oss.py
fairscale/optim/oss.py
+50
-48
fairscale/optim/utils.py
fairscale/optim/utils.py
+0
-30
stubs/torch/distributed/__init__.pyi
stubs/torch/distributed/__init__.pyi
+1
-0
tests/optim/test_oss.py
tests/optim/test_oss.py
+3
-4
No files found.
fairscale/optim/oss.py
View file @
3399e97c
...
@@ -16,7 +16,7 @@ import torch.distributed as dist
...
@@ -16,7 +16,7 @@ import torch.distributed as dist
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
torch.optim
import
SGD
,
Optimizer
from
torch.optim
import
SGD
,
Optimizer
from
.utils
import
Bucket
,
Workhandle
,
broadcast_object
,
recursive_copy_to_device
from
.utils
import
Bucket
,
Workhandle
,
recursive_copy_to_device
__all__
=
[
"OSS"
]
__all__
=
[
"OSS"
]
...
@@ -320,6 +320,55 @@ class OSS(Optimizer):
...
@@ -320,6 +320,55 @@ class OSS(Optimizer):
# Acknowledge broadcasts, and send this rank's shard when needed
# Acknowledge broadcasts, and send this rank's shard when needed
self
.
_broadcast_state_dict
()
self
.
_broadcast_state_dict
()
def
_broadcast_state_dict
(
self
)
->
None
:
"""Broadcast this rank's state shard, discard others"""
# Default to CPU space to gain some memory headroom
local_cpu_state
=
recursive_copy_to_device
(
self
.
local_state_dict
(),
non_blocking
=
True
,
device
=
torch
.
device
(
"cpu"
)
)
for
rank
in
range
(
self
.
world_size
):
if
rank
==
self
.
rank
:
# Send the state to the reference replica
logging
.
debug
(
"Sending the sharded optimizer state to the reference replica from rank %s"
,
rank
,
)
dist
.
broadcast_object_list
([
local_cpu_state
],
src
=
self
.
global_rank
,
group
=
self
.
group
)
else
:
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
# Discard this tensor/rank, broadcast necessary for syncing and because NCCL does not support gather
dist
.
broadcast_object_list
([
0
],
src
=
global_rank
,
group
=
self
.
group
)
def
_collect_sharded_states
(
self
)
->
List
[
Dict
[
str
,
Any
]]:
"""Collect all the state shards, in CPU memory."""
all_states
=
[]
for
rank
in
range
(
self
.
world_size
):
if
rank
==
self
.
rank
:
logging
.
debug
(
"Saving self state"
)
all_states
.
append
(
recursive_copy_to_device
(
self
.
local_state_dict
(),
non_blocking
=
True
,
device
=
torch
.
device
(
"cpu"
))
)
# Sync with other replicas
dist
.
broadcast_object_list
([
0
],
src
=
self
.
global_rank
,
group
=
self
.
group
)
else
:
# Fetch the optim state from the other replicas
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
replica_state
=
[
0
]
dist
.
broadcast_object_list
(
replica_state
,
src
=
global_rank
,
group
=
self
.
group
)
all_states
.
append
(
recursive_copy_to_device
(
replica_state
[
0
],
non_blocking
=
True
,
device
=
torch
.
device
(
"cpu"
))
)
logging
.
debug
(
"State from rank %s received"
,
rank
)
return
all_states
def
state_dict
(
self
)
->
Dict
[
str
,
Any
]:
def
state_dict
(
self
)
->
Dict
[
str
,
Any
]:
"""Return the last known global optimizer state, which consist of a list of the shards.
"""Return the last known global optimizer state, which consist of a list of the shards.
...
@@ -466,53 +515,6 @@ class OSS(Optimizer):
...
@@ -466,53 +515,6 @@ class OSS(Optimizer):
elif
k
in
global_group
.
keys
():
elif
k
in
global_group
.
keys
():
local_group
[
k
]
=
global_group
[
k
]
local_group
[
k
]
=
global_group
[
k
]
def
_collect_sharded_states
(
self
)
->
List
[
Dict
[
str
,
Any
]]:
"""Collect all the state shards, in CPU memory."""
empty_buffer
=
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
)
all_states
:
List
[
Dict
[
str
,
Any
]]
=
[]
for
rank
in
range
(
self
.
world_size
):
if
rank
==
self
.
rank
:
logging
.
debug
(
"Saving self state"
)
all_states
.
append
(
recursive_copy_to_device
(
self
.
local_state_dict
(),
non_blocking
=
True
,
device
=
torch
.
device
(
"cpu"
))
)
# Sync with other replicas
broadcast_object
(
empty_buffer
,
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
)
else
:
# Fetch the optim state from the other replicas
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
replica_state
=
broadcast_object
(
empty_buffer
,
src_rank
=
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
)
all_states
.
append
(
recursive_copy_to_device
(
replica_state
,
non_blocking
=
True
,
device
=
torch
.
device
(
"cpu"
))
)
logging
.
debug
(
"State from rank %s received"
,
rank
)
return
all_states
def
_broadcast_state_dict
(
self
)
->
None
:
"""Broadcast this rank's state shard, discard others"""
empty_buffer
=
torch
.
tensor
([
0
],
dtype
=
torch
.
uint8
,
device
=
self
.
_device
)
for
rank
in
range
(
self
.
world_size
):
if
rank
==
self
.
rank
:
# Send the state to the reference replica
logging
.
debug
(
"Sending the sharded optimizer state to the reference replica from rank %s"
,
rank
,
)
broadcast_object
(
self
.
local_state_dict
(),
src_rank
=
self
.
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
)
else
:
global_rank
=
self
.
get_global_rank
(
self
.
group
,
rank
)
# Discard this tensor/rank, broadcast necessary for syncing
broadcast_object
(
empty_buffer
,
src_rank
=
global_rank
,
group
=
self
.
group
,
dist_device
=
self
.
_device
)
def
_broadcast_params
(
self
)
->
None
:
def
_broadcast_params
(
self
)
->
None
:
"""Helper function to broadcast all the parameters from a given device"""
"""Helper function to broadcast all the parameters from a given device"""
...
...
fairscale/optim/utils.py
View file @
3399e97c
...
@@ -3,12 +3,10 @@
...
@@ -3,12 +3,10 @@
# This source code is licensed under the BSD license found in the
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
import
io
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
import
torch
import
torch
from
torch._six
import
container_abcs
from
torch._six
import
container_abcs
import
torch.distributed
as
dist
class
Workhandle
:
class
Workhandle
:
...
@@ -128,31 +126,3 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
...
@@ -128,31 +126,3 @@ def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.devic
return
device_val
return
device_val
return
value
return
value
def
broadcast_object
(
obj
:
Any
,
src_rank
:
int
,
group
:
object
=
dist
.
group
.
WORLD
,
dist_device
:
torch
.
device
=
torch
.
device
(
"cpu"
)
)
->
Any
:
"""
Either broadcast from master to the fleet (default),
or use the src setting as the original rank.
"""
if
dist
.
get_rank
()
==
src_rank
:
# Emit data
buffer
=
io
.
BytesIO
()
torch
.
save
(
obj
,
buffer
)
data
=
bytearray
(
buffer
.
getbuffer
())
length_tensor
=
torch
.
LongTensor
([
len
(
data
)]).
to
(
dist_device
)
data_send_tensor
=
torch
.
ByteTensor
(
data
).
to
(
dist_device
)
dist
.
broadcast
(
length_tensor
,
src
=
src_rank
,
group
=
group
,
async_op
=
False
)
dist
.
broadcast
(
data_send_tensor
,
src
=
src_rank
,
group
=
group
,
async_op
=
False
)
else
:
# Fetch from the source
length_tensor
=
torch
.
LongTensor
([
0
]).
to
(
dist_device
)
dist
.
broadcast
(
length_tensor
,
src
=
src_rank
,
group
=
group
,
async_op
=
False
)
data_recv_tensor
=
torch
.
empty
([
int
(
length_tensor
.
item
())],
dtype
=
torch
.
uint8
,
device
=
dist_device
)
dist
.
broadcast
(
data_recv_tensor
,
src
=
src_rank
,
group
=
group
,
async_op
=
False
)
buffer
=
io
.
BytesIO
(
data_recv_tensor
.
cpu
().
numpy
())
obj
=
torch
.
load
(
buffer
,
map_location
=
dist_device
)
return
obj
stubs/torch/distributed/__init__.pyi
View file @
3399e97c
...
@@ -32,6 +32,7 @@ def get_backend(group: Optional[Any] = None) -> Any: ...
...
@@ -32,6 +32,7 @@ def get_backend(group: Optional[Any] = None) -> Any: ...
def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
def gather(tensor: Tensor, gather_list: Optional[List[Tensor]], dst: Any, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
def gather(tensor: Tensor, gather_list: Optional[List[Tensor]], dst: Any, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
def broadcast_object_list(object_list: List[Any], src: int, group:Optional[ProcessGroup] = None): ...
def is_initialized() -> bool: ...
def is_initialized() -> bool: ...
...
...
tests/optim/test_oss.py
View file @
3399e97c
...
@@ -392,12 +392,11 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
...
@@ -392,12 +392,11 @@ def run_test_collect_shards(rank, world_size, reference_rank, tempfile_name):
else
:
else
:
optimizer_state_dict
=
{}
optimizer_state_dict
=
{}
optimizer_state_dict
=
optim
.
utils
.
broadcast_object
(
optim_state
=
[
optimizer_state_dict
]
optimizer_state_dict
,
src_rank
=
reference_rank
,
group
=
dist
.
group
.
WORLD
,
dist_device
=
device
dist
.
broadcast_object_list
(
optim_state
,
src
=
reference_rank
,
group
=
dist
.
group
.
WORLD
)
)
# Load the optimizer state dict
# Load the optimizer state dict
optimizer
.
load_state_dict
(
optim
izer
_state
_dict
)
optimizer
.
load_state_dict
(
optim_state
[
0
]
)
dist
.
destroy_process_group
()
dist
.
destroy_process_group
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment