OpenDAS / fairscale · Commit a82825db (unverified), authored Apr 13, 2021 by Sam Shleifer, committed by GitHub on Apr 13, 2021
[FSDP] use all_gather for 10X OSD consolidation speedup (#595)
Parent: 4726d5be
Showing 4 changed files with 127 additions and 85 deletions (+127 −85):
fairscale/nn/data_parallel/fsdp_optim_utils.py            +43 −37
fairscale/nn/data_parallel/fully_sharded_data_parallel.py +68 −44
tests/nn/data_parallel/test_fsdp.py                        +8 −4
tests/nn/data_parallel/test_fsdp_optimizer_utils.py        +8 −0
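The heart of this change: instead of having every rank broadcast its full pickled optimizer state dict in turn (world_size sequential `broadcast_object` calls), each state tensor is now collected with a single `dist.all_gather` collective, and only small metadata (padding info, `param_groups`) stays on the broadcast path; see the comment added in `gather_full_optim_state_dict` below. The following sketch is not fairscale code; it only illustrates the collective being leaned on. It assumes `torch.distributed` is already initialized (for example under `torchrun`), and `gather_shards` is an illustrative helper name.

# Minimal sketch of the communication pattern this commit moves to (illustrative, not fairscale code).
# Every rank holds one equally sized shard of an optimizer buffer; one collective gathers them all.
import torch
import torch.distributed as dist

def gather_shards(shard: torch.Tensor) -> list:
    """Collect one tensor shard from every rank with a single all_gather."""
    world_size = dist.get_world_size()
    # Pre-allocate one receive slot per rank; all_gather fills them in rank order.
    buffers = [torch.zeros_like(shard) for _ in range(world_size)]
    dist.all_gather(buffers, shard)
    return buffers  # rank r's shard ends up in buffers[r]

if __name__ == "__main__":
    dist.init_process_group("gloo")  # assumes MASTER_ADDR etc. set by the launcher
    shard = torch.full((4,), float(dist.get_rank()))
    full = torch.cat(gather_shards(shard))  # concatenate shards into the unsharded buffer
    if dist.get_rank() == 0:
        print(full)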
fairscale/nn/data_parallel/fsdp_optim_utils.py
@@ -4,10 +4,12 @@
 # LICENSE file in the root directory of this source tree.

 """These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
 import copy
-from typing import Dict, Generator, List, Tuple
+from typing import Any, Dict, Generator, List, Tuple

 import torch

+# These return keys are used by fairseq. To change, add @sshleifer as a reviewer.
+UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}

 # This function helps shard a full optimizer state dict
 def flatten_optim_state_dict(sd: Dict) -> Dict:
@@ -16,6 +18,7 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
     num_local_params = len(set(param_id_map.values()))
     if sd["state"]:
         new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
+        singleton_state: Dict = copy.deepcopy(new_state)
     else:
         new_state = {}
     non_tensor_state = {}
@@ -24,19 +27,26 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
     for global_id, buffers in sd["state"].items():
         local_id = param_id_map[global_id]
         for buffer_name, p in buffers.items():
-            if torch.is_tensor(p):
+            if is_singleton_tensor(p):
+                singleton_state[local_id][buffer_name] = p
+            elif torch.is_tensor(p):
                 if buffer_name not in new_state[local_id]:
                     new_state[local_id][buffer_name] = []
                 new_state[local_id][buffer_name].append(p.reshape(-1))
+            elif isinstance(p, list):
+                singleton_state[local_id][buffer_name] = p
             else:
                 non_tensor_state[buffer_name] = p

     # Now combine all tensors in each buffer using torch.cat().
     for local_id, state in new_state.items():
         for buffer_name, tensors in state.items():
             new_state[local_id][buffer_name] = torch.cat(tensors)
         new_state[local_id].update(non_tensor_state)
+        new_state[local_id].update(singleton_state[local_id])
     new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
+    for k in sd.keys():  # if there are extra keys, like loss_scale, don't delete them
+        if k not in UNFLAT_RETURN_KEYS:
+            new_sd[k] = copy.deepcopy(sd[k])

     # add pointers from the `params` dict.
     for pg_id, _ in enumerate(sd["param_groups"]):
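In `flatten_optim_state_dict` above, buffers belonging to the same flat (local) parameter are regrouped: sharded tensor entries are flattened and concatenated with `torch.cat`, while dimensionless ("singleton") tensors are now kept aside and copied over unchanged rather than concatenated. The toy sketch below mirrors that loop structure with made-up ids and buffers; it is illustrative only and does not touch FSDP.

# Toy illustration of the regrouping step (made-up ids and buffers, no FSDP involved).
import torch

param_id_map = {0: 0, 1: 0}  # two unflat (global) params packed into one flat (local) param
full_state = {
    0: {"exp_avg": torch.ones(3), "step": torch.tensor(10.0)},
    1: {"exp_avg": torch.ones(2), "step": torch.tensor(10.0)},
}

new_state = {0: {}}
singleton_state = {0: {}}
for global_id, buffers in full_state.items():
    local_id = param_id_map[global_id]
    for name, p in buffers.items():
        if torch.is_tensor(p) and p.dim() == 0:   # singleton: keep as-is, do not concatenate
            singleton_state[local_id][name] = p
        elif torch.is_tensor(p):                  # sharded view: collect, then torch.cat below
            new_state[local_id].setdefault(name, []).append(p.reshape(-1))

for local_id, state in new_state.items():
    for name, tensors in state.items():
        new_state[local_id][name] = torch.cat(tensors)  # one flat buffer per flat param
    new_state[local_id].update(singleton_state[local_id])

print(new_state[0]["exp_avg"].shape)  # torch.Size([5]): concatenated across the two unflat params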
@@ -70,22 +80,11 @@ def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_
     return non_tensor_state


-def _combine_state(states: List[Dict]) -> Dict[int, Dict]:
-    combined_state = states[0]
-    for param_id in combined_state:
-        combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
-    if len(states) == 1:
-        return combined_state
-
-    for rank, s in enumerate(states[1:]):
-        for param_id, param_state in s.items():
-            for k, tensor in param_state.items():
-                combined_state[param_id][k].append(tensor)
-    return combined_state
-
-
 def _unflatten_optim_state(
-    combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]],
+    combined_state: Dict[int, Dict],
+    instance_list: List[torch.nn.Module],
+    world_pad_info: List[List[List[int]]],
+    singleton_state: Dict[int, Dict],
 ) -> Tuple[Dict[int, Dict], Dict[int, int]]:
     # local ids are the keys in the current state (combined_state), (usually fewer)
     # global ids will be the keys in the unflattened state
@@ -98,17 +97,17 @@ def _unflatten_optim_state(
     non_tensor_state = [_extract_non_tensor_state(combined_state, id) for id in combined_state]

     # local corresponds to flattened, global corresponds to unflattened
-    num_unflat_params = [len(m._param_numels) for m in instance_list]  # type: ignore
+    num_global_params = [len(m._param_numels) for m in instance_list]  # type: ignore
     global_to_local_id = {}
-    for local_id, num_unflat in enumerate(num_unflat_params):
+    for local_id, num_unflat in enumerate(num_global_params):
        for _ in range(num_unflat):
            global_to_local_id[next_global_id] = local_id
            next_global_id += 1
    if not combined_state:
        return {}, global_to_local_id

-    # If the constant state is the same as the combined state, copy it N times, no unflattening needed.
-    unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))}
+    # copy non tensor state to all global entries
+    unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_global_params))}

    if non_tensor_state[0].keys() == combined_state[0].keys():
        return unflat_state, global_to_local_id
@@ -131,37 +130,44 @@ def _unflatten_optim_state(
             for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
                 assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
                 unflat_state[global_id][k] = param_view
+        unflat_state[global_id].update(singleton_state[local_id])

     return unflat_state, global_to_local_id


 def build_unflat_state_dict(
-    instance_list: List[torch.nn.Module], world_optim_states: List[Dict], uncollected_opt_state: Dict[int, Dict]
+    instance_list: List[torch.nn.Module],
+    world_pad_info: List[List[List[int]]],
+    state: Dict[int, Dict[str, List[torch.Tensor]]],
+    singleton_state: Dict[int, Dict[str, List[torch.Tensor]]],
+    uncollected_opt_state: Dict[int, Dict],
+    param_groups: List[Dict],
 ) -> Dict:
     """Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
-    world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
     assert all(len(s) == len(instance_list) for s in world_pad_info)
     assert all(len(s[0]) == 1 for s in world_pad_info)
-    # Since there are no tensors in param_groups, deepcopy is fine
-    param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
     assert len(param_groups) == 1
-    # Aggregate from a list of dictionaries to a dictionary of lists
-    combined_state = _combine_state([x["state"] for x in world_optim_states])
+    # Use uncollected_opt_state to update tensor_state, singleton_state
     for local_id, v in uncollected_opt_state.items():
-        assert local_id not in combined_state
-        combined_state[local_id] = {}
-        for buffer_name, tensor in v.items():
-            combined_state[local_id][buffer_name] = [tensor]
-    del world_optim_states
+        assert local_id not in state
+        state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if not is_singleton_tensor(x)}
+        singleton_state[local_id] = {buffer_name: [x] for buffer_name, x in v.items() if is_singleton_tensor(x)}
     # local ids are in the current state, global_ids will be in returned state.
-    unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
+    unflat_state, global_to_local_id = _unflatten_optim_state(state, instance_list, world_pad_info, singleton_state)
+    # Since there are no tensors in param_groups, deepcopy is fine
+    param_groups = copy.deepcopy(param_groups)
     num_params = sum([len(m._param_numels) for m in instance_list])  # type: ignore
     param_groups[0]["params"] = list(range(num_params))
-    return {
+    unflat_optim_state_dict = {
         "state": dict(sorted(unflat_state.items())),  # NOTE: this is probably already sorted
         "param_id_map": global_to_local_id,
         "param_groups": param_groups,
         "uncollected_local_ids": list(uncollected_opt_state.keys()),
     }
+    assert set(unflat_optim_state_dict.keys()) == UNFLAT_RETURN_KEYS
+    return unflat_optim_state_dict


+def is_singleton_tensor(x: Any) -> bool:
+    """Is x a dimensionless tensor?"""
+    return torch.is_tensor(x) and x.dim() == 0
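To make the return contract added above concrete, here is a toy example (all values made up) of a dict that satisfies the new `UNFLAT_RETURN_KEYS` assertion; `param_id_map` records which flat (local) parameter each unflat (global) parameter id was packed into.

# Toy illustration of the four keys asserted above; values are placeholders, not real state.
UNFLAT_RETURN_KEYS = {"state", "param_groups", "uncollected_local_ids", "param_id_map"}

example_osd = {
    "state": {0: {"exp_avg": ...}, 1: {"exp_avg": ...}, 2: {"exp_avg": ...}},
    "param_groups": [{"lr": 0.1, "params": [0, 1, 2]}],
    "param_id_map": {0: 0, 1: 0, 2: 1},  # unflat params 0 and 1 came from flat param 0; 2 from flat param 1
    "uncollected_local_ids": [1],         # flat params whose state was kept local and not gathered
}
assert set(example_osd.keys()) == UNFLAT_RETURN_KEYS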
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -1382,70 +1382,88 @@ class FullyShardedDataParallel(nn.Module):
             traceback.print_stack()
             raise ValueError(msg)

-    def _consolidate_optim_state_dict(
-        self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
-    ) -> List[Dict]:
-        """Update the consolidated state_dict list, one per rank.
-
-        Args:
-            optim (Optimizer): an optimizer instance for this FSDP rank. Its state is
-                used in the consolidation. However, its state is not modified.
-            recipient_rank (int): on which rank to materialize the full state dict.
-                None is a special value, which means that all ranks should have the state
-
-        Returns:
-            all_states (list[dict]) the optimizer state from each rank
-
-        .. warning: This needs to be called on all replicas"""
-        self._lazy_init()
-        # NOTE(SS): we do not support param groups yet, as they seem to break FSDP
-        # Pull the sharded state from all the other replicas
-        # Store all the states in order, rank by rank
-        should_collect_state = recipient_rank is None or (self.rank == recipient_rank)
-        all_states: List[Dict[str, Any]] = []
+    def _broadcast_pad_info_to_r0(self) -> List[List[List[int]]]:
+        """Collect [x.numel_padded_per_param for x in self._fsdp_instances] from teach rank."""
         dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
+        world_pad_info: List[List[List[int]]] = []  # this will contain values from the whole world.
         for rank in range(self.world_size):
             if rank == self.rank:
-                sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
-                sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
+                pad_info = [m.numel_padded_per_param for m in self._fsdp_instances]
             else:
-                sd = dummy_tensor  # type: ignore
-            sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)
-            if should_collect_state:
-                assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
-                all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
-        return all_states
-
-    def gather_full_optim_state_dict(
-        self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = 0
-    ) -> Optional[Dict[str, Any]]:
+                pad_info = dummy_tensor  # type: ignore
+            pad_info = broadcast_object(pad_info, src_rank=rank, group=self.process_group, dist_device=self.compute_device)
+            if self.rank == 0:
+                world_pad_info.append(pad_info)  # type: ignore
+        return world_pad_info
+
+    def _gather_optim_state(
+        self, sd_state: Dict[int, Dict[str, Any]]
+    ) -> Tuple[Dict[int, Dict[str, List]], Dict[int, Dict[str, List]]]:
+        """For each value in state[i], if the value is a tensor, collect it from the world. Else use rank 0's entry."""
+        gathered_state: Dict[int, Dict[str, List[Any]]] = {}
+        singleton_state: Dict[int, Dict[str, List[Any]]] = {}  # Dimensionless tensor
+        for k, v in sd_state.items():
+            gathered_state[k] = {}
+            singleton_state[k] = {}
+            desired_buffer_size = self._fsdp_instances[k].flat_param._full_param_padded.size()  # type: ignore
+            buffer = None  # for sharded tensors
+            singleton_buffer = None  # for singleton tensors
+            for buffer_name, t in v.items():
+                if ou.is_singleton_tensor(t):
+                    if singleton_buffer is None:
+                        singleton_buffer = list(t.new_zeros(self.world_size).chunk(self.world_size))
+                    dist.all_gather(singleton_buffer, t, group=self.process_group)
+                    if self.rank == 0:
+                        singleton_state[k][buffer_name] = [x.cpu().squeeze() for x in singleton_buffer]
+                        assert ou.is_singleton_tensor(singleton_state[k][buffer_name][0])
+                elif torch.is_tensor(t):
+                    if buffer is None:
+                        buffer = list(t.new_zeros(*desired_buffer_size).chunk(self.world_size))
+                    dist.all_gather(buffer, t, group=self.process_group)
+                    if self.rank == 0:
+                        gathered_state[k][buffer_name] = [x.cpu() for x in buffer]
+                elif self.rank == 0:  # Add non tensor state
+                    gathered_state[k][buffer_name] = [t]
+        return gathered_state, singleton_state
+
+    def gather_full_optim_state_dict(self, optim: torch.optim.Optimizer, **ignored: Dict) -> Optional[Dict[str, Any]]:
         """Return the last known global optimizer state. The returned state is compatible with Pytorch, in that the
         sharded properties are not exposed. Multiple parameter groups are not yet supported.

         This should be called only on the root FSDP instance.
         Nested FSDP instances are supported as long as they have the same world_size as the parent or world_size=1.
+        Different world_size groups in nested FSDP instances is not supported.

         Args:
-            optim (Optimizer): an optimizer instance for this FSDP rank. Its state is
-                used in the consolidation. However, its state is not modified.
-            recipient_rank (int): on which rank to materialize the full state dict.
+            optim (Optimizer): an optimizer instance for this FSDP rank. Its state_dict is
+                used in the consolidation. However, its state is not modified.

         Returns:
-            a dict with two entries
+            * A dict with four entries (On rank zero, other workers return ``None``)
                 * state - a dict holding gathered optimization state, 1 entry per unflat parameter
                 * param_groups - a dict containing the 1 parameter group
                 * param_id_map - global (unflat) to local (flat) id mapping
                 * uncollected_local_ids - keys in the state dict that were not broadcast

         """
         if not self.flatten_parameters:
             raise NotImplementedError("optim state dict requires flatten_parameters=True")
-        world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank)
-        if self.rank != recipient_rank and recipient_rank is not None:
+        self._lazy_init()
+        sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
+        assert set(sd.keys()) == {"param_groups", "state"}, f'{set(sd.keys())} != {"param_groups", "state"}'
+        assert len(sd["param_groups"]) == 1, "Param groups are not supported"
+        # We use all_gather to consolidate OSD['state'] and broadcast to consolidate the other keys (like param_groups)
+        state, singleton_state = self._gather_optim_state(sd.pop("state"))
+        pad_info = self._broadcast_pad_info_to_r0()
+        if self.rank != 0:
             return None
         # Unify the shard states by concatenating tensors and unflattening params
         new_state_dict = ou.build_unflat_state_dict(
-            self._fsdp_instances, world_optim_states, self.uncollected_opt_state
+            self._fsdp_instances, pad_info, state, singleton_state, self.uncollected_opt_state, sd["param_groups"]
         )
         self.uncollected_opt_state = {}
         assert "uncollected_local_ids" in new_state_dict
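The singleton branch of `_gather_optim_state` above builds its receive slots with `t.new_zeros(self.world_size).chunk(self.world_size)` and squeezes each gathered one-element chunk back to a dimensionless tensor on rank 0. The standalone sketch below mirrors that pattern outside of FSDP; it assumes `torch.distributed` is initialized (for example with the gloo backend under `torchrun`), `gather_singleton` is an illustrative name, and the input is reshaped to `(1,)` before the collective for an explicit shape match, whereas the code above passes the dimensionless tensor directly.

# Standalone sketch of the singleton-tensor gathering pattern (illustrative, not FSDP code).
import torch
import torch.distributed as dist

def gather_singleton(t: torch.Tensor) -> list:
    """Gather a dimensionless tensor from every rank and return a list of 0-d tensors."""
    world_size = dist.get_world_size()
    slots = list(t.new_zeros(world_size).chunk(world_size))  # world_size receive slots of shape (1,)
    dist.all_gather(slots, t.reshape(1))  # reshape to (1,) here for an explicit shape match
    return [x.cpu().squeeze() for x in slots]  # squeeze() restores 0-d tensors, one per rank

if __name__ == "__main__":
    dist.init_process_group("gloo")
    step = torch.tensor(float(dist.get_rank()))  # a dimensionless, "step"-like value
    gathered = gather_singleton(step)  # collectives must run on every rank
    if dist.get_rank() == 0:
        print(gathered)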
@@ -1499,14 +1517,20 @@ class FullyShardedDataParallel(nn.Module):
             for k, v in s.items():
                 if torch.is_tensor(v) and id not in ids_not_to_shard:
                     v_shard, _ = self._get_shard(v)
+                elif isinstance(v, list) and ou.is_singleton_tensor(v[0]):
+                    # if we are resuming on larger world size, take first entry
+                    v_shard = v[0] if self.rank >= len(v) else v[self.rank]
+                    assert ou.is_singleton_tensor(v_shard)
                 else:
                     v_shard = v  # dont shard entries that are not tensors
                 full_optim_state_dict["state"][id][k] = v_shard
         return full_optim_state_dict

-    def _print_r0(self, msg: str) -> None:
+    def _print_r0(self, msg: str, restart: bool = False) -> None:
         """Debugging utility to print memory usage stats nicely on rank 0"""
+        if restart:
+            self._tstart = time.time()
         if self.rank == 0:
             gb_denom = 1024 ** 3
             print(
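For context, a sketch of how the gather / re-shard pair is typically wired into optimizer checkpointing. `gather_full_optim_state_dict` appears in this diff; the re-sharding method edited by the hunk above is assumed here to be `get_shard_from_optim_state_dict` (its name is not visible in this excerpt), and the surrounding scaffolding is illustrative only.

# Hedged usage sketch: consolidating optimizer state on rank 0, then re-sharding it on load.
import torch

def save_optim_checkpoint(model, optim, path):
    """model is the root FSDP-wrapped module; optim is this rank's optimizer."""
    full_osd = model.gather_full_optim_state_dict(optim)  # collective: every rank must call it
    if full_osd is not None:                              # after this commit, only rank 0 gets a dict
        torch.save(full_osd, path)

def load_optim_checkpoint(model, optim, path):
    full_osd = torch.load(path, map_location="cpu")
    # Assumed helper: slice the consolidated dict back down to this rank's shard before loading.
    optim.load_state_dict(model.get_shard_from_optim_state_dict(full_osd))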
tests/nn/data_parallel/test_fsdp.py
@@ -627,15 +627,19 @@ class MixtureOfExperts(NestedWrappedModule):
         # "expert" params are different on each rank
         torch.manual_seed(42 + group.rank())
-        d_expert = 16
-        expert = nn.Linear(d_expert, 4)
+        d_expert = 23
+        d_shared = 12
+        d_input = 8
+        expert = nn.Linear(d_expert, d_shared)

         self.num_expert_params = sum([p.numel() for p in expert.parameters()])
         for p in expert.parameters():
             p.expert = True

         # everything else is shared
         torch.manual_seed(0)
-        shared = nn.Linear(4, d_expert)
+        shared = nn.Linear(d_shared, d_expert)

         if checkpoint_act:
             expert = checkpoint_wrapper(expert)
@@ -648,7 +652,7 @@ class MixtureOfExperts(NestedWrappedModule):
             shared = FullyShardedDataParallel(shared, group, **wrapper_config)

-        self.module = nn.Sequential(nn.Linear(8, 4), shared, expert, nn.Linear(4, 8))
+        self.module = nn.Sequential(nn.Linear(d_input, d_shared), shared, expert, nn.Linear(d_shared, d_input))

     def forward(self, x):
         if self.delay_before_free_ms > 0:
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
@@ -10,6 +10,7 @@ import torch
 from torch.optim import SGD, Adadelta, Adam  # type: ignore

 from fairscale.nn import FullyShardedDataParallel
+from fairscale.nn.data_parallel.fsdp_optim_utils import is_singleton_tensor
 from fairscale.optim.utils import recursive_copy_to_device
 from fairscale.utils.testing import objects_are_equal
@@ -147,3 +148,10 @@ class TestOptimizerUtils(DistributedTest):
         named_pars = [p for n, p in model.named_parameters()]
         for i, p in enumerate(model.parameters()):
             assert objects_are_equal(p, named_pars[i])
+
+    def test_is_singleton_tensor(self):
+        assert is_singleton_tensor(torch.tensor(4.0))
+        assert not is_singleton_tensor(torch.tensor([4.0]))
+        assert not is_singleton_tensor(torch.tensor([4.0, 5.0]))
+        assert not is_singleton_tensor([4.0])
+        assert not is_singleton_tensor(4.0)
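A quick illustration of the "dimensionless tensor" notion that the new test pins down, and of why `_gather_optim_state` calls `.squeeze()` on each gathered one-element chunk: `squeeze()` drops the size-1 dimension and restores a 0-d tensor.

# 0-d ("singleton") tensors versus 1-element 1-d tensors, and how squeeze() converts between them.
import torch

t = torch.tensor(4.0)          # 0-d: dim() == 0, shape ()
assert t.dim() == 0 and t.shape == torch.Size([])

chunk = torch.tensor([4.0])    # 1-d with one element: dim() == 1, shape (1,)
assert chunk.dim() == 1

assert chunk.squeeze().dim() == 0  # squeeze() drops the size-1 dim, restoring a 0-d tensor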