OpenDAS / fairscale / Commits / 379c6bf0

Commit 379c6bf0 (unverified), authored Oct 01, 2020 by Joshua Meier, committed by GitHub on Oct 01, 2020
Parent: 1c2a6f6b

Support optimizer state sharding for megatron (#121)

support optimizer state sharding for megatron

Changes: 1 changed file with 25 additions and 11 deletions

fairscale/optim/oss.py  (+25, -11)
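Background: OSS can shard optimizer state across a process group that covers only a subset of the launched processes, which is the Megatron-style setting (model parallelism combined with data parallelism) this commit targets. Inside such a sub-group, rank indices are group-local, while torch.distributed collectives address processes by global rank, hence the translation added below. The following minimal sketch of that setting is illustrative only; the 2x2 model/data-parallel layout and the OSS constructor keywords (optim, group, lr) are assumptions, not taken from this commit.

# Hypothetical sketch: OSS sharding optimizer state over a data-parallel
# sub-group, as in a Megatron-style 2-way model-parallel x 2-way data-parallel
# job launched with 4 processes. Layout and constructor keywords are assumed.
import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS


def main() -> None:
    dist.init_process_group(backend="gloo", init_method="env://")  # 4 processes assumed
    global_rank = dist.get_rank()

    # Ranks 0 and 2 hold one model shard, ranks 1 and 3 the other; each pair
    # forms a data-parallel group. Every process must create every group.
    dp_groups = [dist.new_group([0, 2]), dist.new_group([1, 3])]
    my_dp_group = dp_groups[global_rank % 2]

    model = torch.nn.Linear(8, 8)
    # Inside OSS, dist.get_rank(my_dp_group) is 0 or 1 (group-local), while
    # broadcast sources must be global ranks (0..3); that is the gap the
    # get_global_rank() helper added in this commit closes.
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, group=my_dp_group, lr=0.1)

    model(torch.randn(2, 8)).sum().backward()
    optimizer.step()


if __name__ == "__main__":
    main()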
fairscale/optim/oss.py
@@ -72,6 +72,8 @@ class OSS(Optimizer):
         self.world_size = dist.get_world_size(self.group)
         self.rank = dist.get_rank(self.group)
+        self.global_rank = self.get_global_rank(self.group, self.rank)
         self.optim = optim(self.partition_parameters()[self.rank], **default)

         # Optional consolidated optimizer state
@@ -88,7 +90,7 @@ class OSS(Optimizer):
     # Partition helpers
     def partition_parameters(self) -> List[List[dict]]:
-        """Partitions parameters across distributed ranks.
+        """Partitions parameters across distributed data parallel ranks.

         Returns a list of param_groups (which is a list of dict) where each
         element of the list contains the param_groups for a rank. Element 0
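As a quick illustration of the structure documented in this docstring, here is a hedged inspection snippet. It assumes a process group is already initialized and that `optimizer` is an OSS instance such as the one in the sketch above; `world_size` is the attribute visible in the first hunk.

# Hypothetical sketch: inspecting the per-rank partition returned by
# OSS.partition_parameters(). One entry per rank in the group; each entry is a
# list of param_group dicts, mirroring torch.optim.Optimizer.param_groups.
partitions = optimizer.partition_parameters()
assert len(partitions) == optimizer.world_size
for rank, param_groups in enumerate(partitions):
    numel = sum(p.numel() for group in param_groups for p in group["params"])
    print(f"rank {rank}: {len(param_groups)} param group(s), {numel} elements")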
@@ -135,6 +137,7 @@ class OSS(Optimizer):
     @property
     def param_to_rank(self) -> Dict[torch.Tensor, int]:
+        '''param to data parallel rank'''
         if len(self._param_rank) == 0:
             for rank, param_groups in enumerate(self.partition_parameters()):
                 for param_group in param_groups:
@@ -142,6 +145,13 @@ class OSS(Optimizer):
                         self._param_rank[param] = rank
         return self._param_rank

+    def get_global_rank(self, group, rank):
+        if group is dist.group.WORLD:
+            return rank
+        else:
+            global_rank = dist.distributed_c10d._get_global_rank(group, rank)
+        return global_rank
+
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
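The helper added in this hunk exists because the surrounding code enumerates group-local indices (0 .. group_size - 1), whereas dist.broadcast and friends expect global ranks. A standalone sketch of the same translation follows; it reuses the private dist.distributed_c10d._get_global_rank call that the commit itself relies on, and the 4-process layout with a sub-group of global ranks [1, 3] is an assumption for illustration.

# Hypothetical sketch: group-local rank vs. global rank. With 4 processes and a
# sub-group of global ranks [1, 3], dist.get_rank(subgroup) yields 0 or 1, but
# collectives such as dist.broadcast(src=...) must be given 1 or 3.
import torch.distributed as dist


def to_global_rank(group, local_rank):
    # Same logic as the get_global_rank helper added in this commit.
    if group is dist.group.WORLD:
        return local_rank
    return dist.distributed_c10d._get_global_rank(group, local_rank)


def demo() -> None:
    dist.init_process_group(backend="gloo", init_method="env://")  # 4 processes assumed
    subgroup = dist.new_group(ranks=[1, 3])  # every process must create the group
    if dist.get_rank() in (1, 3):
        local = dist.get_rank(subgroup)  # 0 or 1
        print(local, "->", to_global_rank(subgroup, local))  # 0 -> 1 or 1 -> 3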
@@ -174,14 +184,15 @@ class OSS(Optimizer):
                     # Gloo will rightly assert on this operation for any tensor that requires grad.
                     # We save and restore the grad requirement state to work around that, in our case
                     # the grad is only useful on the source rank.
+                    global_rank = self.get_global_rank(self.group, rank)
                     requires_grad.append((param, param.requires_grad))
                     param.requires_grad = False
-                    requests.append(dist.broadcast(tensor=param, src=rank, group=self.group, async_op=True))
+                    requests.append(dist.broadcast(tensor=param, src=global_rank, group=self.group, async_op=True))

         for fut, req_grad in zip(requests, requires_grad):
             fut.wait()
             req_grad[0].requires_grad = req_grad[1]

         return loss

     def local_state_dict(self) -> dict:
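The comment block in this hunk describes a general workaround: the Gloo backend asserts when asked to broadcast a tensor that requires grad, so OSS saves the flag, clears it for the collective, and restores it once the async request completes. Below is a hedged, standalone sketch of that save/clear/restore pattern, outside of OSS and assuming an initialized process group.

# Hypothetical sketch of the requires_grad save/clear/restore pattern used in
# step() above. `params` is any iterable of leaf tensors (e.g. model.parameters()).
import torch.distributed as dist


def broadcast_params(params, src, group=None):
    requests, saved = [], []
    for p in params:
        saved.append((p, p.requires_grad))
        p.requires_grad = False  # avoid the backend assert mentioned in the diff
        requests.append(dist.broadcast(tensor=p, src=src, group=group, async_op=True))
    for req, (p, flag) in zip(requests, saved):
        req.wait()               # complete the async broadcast
        p.requires_grad = flag   # restore the original flag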
@@ -330,7 +341,7 @@ class OSS(Optimizer):
         empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
         all_states: List[Dict[str, Any]] = []

-        for rank in range(dist.get_world_size(group=self.group)):
+        for rank in range(self.world_size):
             if rank == self.rank:
                 logging.debug("Saving self state")
                 all_states.append(
@@ -338,12 +349,13 @@ class OSS(Optimizer):
                 )

                 # Sync with other replicas
-                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
+                broadcast_object(empty_buffer, src_rank=self.global_rank, group=self.group, dist_device=self._device)
             else:
                 # Fetch the optim state from the other replicas
-                logging.debug("Receiving state from rank %s ", rank)
+                global_rank = self.get_global_rank(self.group, rank)
+                logging.debug("Receiving state from rank %s ", global_rank)
                 replica_state = broadcast_object(
-                    empty_buffer, src_rank=rank, group=self.group, dist_device=self._device
+                    empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device
                 )

                 all_states.append(
@@ -358,17 +370,18 @@ class OSS(Optimizer):
         """Broadcast this rank's state shard, discard others"""
         empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)

-        for rank in range(dist.get_world_size(group=self.group)):
+        for rank in range(self.world_size):
             if rank == self.rank:
                 # Send the state to the reference replica
                 logging.debug(
                     "Sending the sharded optimizer state to the reference replica from rank %s", rank,
                 )
-                broadcast_object(self.local_state_dict(), src_rank=rank, group=self.group, dist_device=self._device)
+                broadcast_object(self.local_state_dict(), src_rank=self.global_rank, group=self.group, dist_device=self._device)
             else:
+                global_rank = self.get_global_rank(self.group, rank)
                 # Discard this tensor/rank, broadcast necessary for syncing
-                logging.debug("Discarding broadcast from rank %s", rank)
-                broadcast_object(empty_buffer, src_rank=rank, group=self.group, dist_device=self._device)
+                logging.debug("Discarding broadcast from rank %s", global_rank)
+                broadcast_object(empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device)

     def _free_other_grads(self) -> None:
         """Free all the gradients only useful for the other ranks
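The two helpers in the hunks above make up the gather side of checkpointing: each rank either broadcasts its own shard (using its global rank as the source) or receives and possibly discards the others. A hedged usage sketch follows; the consolidate_state_dict()/state_dict() pairing reflects the fairscale OSS API as later documented and may not match the exact surface at this commit.

# Hypothetical usage sketch: checkpointing a sharded optimizer. Method names
# are assumed from the public fairscale OSS API, not taken from this diff.
import torch
import torch.distributed as dist


def save_checkpoint(model, optimizer, path, recipient_rank=0):
    # Every rank must participate in the broadcasts, even though only one saves.
    optimizer.consolidate_state_dict()
    if dist.get_rank() == recipient_rank:
        torch.save({"model": model.state_dict(), "optim": optimizer.state_dict()}, path)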
@@ -380,3 +393,4 @@ class OSS(Optimizer):
             for p in partition:
                 for t in p["params"]:
                     t.grad = None