OpenDAS / fairscale / Commits

Commit 1fcbd624 (unverified), authored Apr 04, 2021 by Sam Shleifer, committed by GitHub on Apr 04, 2021
Parent: 54a97ee5

[FSDP] add no_broadcast_optim_state option (#560)

Showing 4 changed files with 99 additions and 42 deletions (+99 / -42)
Changed files:
  fairscale/nn/data_parallel/fsdp_optim_utils.py              +21 -11
  fairscale/nn/data_parallel/fully_sharded_data_parallel.py   +38 -8
  tests/nn/data_parallel/test_fsdp.py                         +2  -1
  tests/nn/data_parallel/test_fsdp_optimizer_utils.py         +38 -22
fairscale/nn/data_parallel/fsdp_optim_utils.py

@@ -21,22 +21,22 @@ def flatten_optim_state_dict(sd: Dict) -> Dict:
     non_tensor_state = {}
 
     # Populate `new_state["state"]`. (Assuming sd is sorted)
-    for expanded_pid, buffers in sd["state"].items():
-        consolidated_pid = param_id_map[expanded_pid]
+    for global_id, buffers in sd["state"].items():
+        local_id = param_id_map[global_id]
         for buffer_name, p in buffers.items():
             if torch.is_tensor(p):
-                if buffer_name not in new_state[consolidated_pid]:
-                    new_state[consolidated_pid][buffer_name] = []
-                new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
+                if buffer_name not in new_state[local_id]:
+                    new_state[local_id][buffer_name] = []
+                new_state[local_id][buffer_name].append(p.reshape(-1))
             else:
                 non_tensor_state[buffer_name] = p
 
     # Now combine all tensors in each buffer using torch.cat().
-    for consolidated_pid, state in new_state.items():
+    for local_id, state in new_state.items():
         for buffer_name, tensors in state.items():
-            new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
-        new_state[consolidated_pid].update(non_tensor_state)
-    new_sd = {"state": new_state, "param_groups": sd["param_groups"]}
+            new_state[local_id][buffer_name] = torch.cat(tensors)
+        new_state[local_id].update(non_tensor_state)
+    new_sd = {"state": new_state, "param_groups": copy.deepcopy(sd["param_groups"])}
 
     # add pointers from the `params` dict.
     for pg_id, _ in enumerate(sd["param_groups"]):
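For orientation, here is a minimal standalone sketch of the remapping that `flatten_optim_state_dict` performs: global (unflattened) param ids are grouped into their local (flattened) FSDP parameter via `param_id_map`, tensor buffers are reshaped to 1-D and concatenated, and non-tensor entries such as Adam's `step` are copied through. The toy `param_id_map`, buffer names, and shapes below are illustrative only, not taken from the diff.

import torch

# Two unflattened ("global") params 0 and 1 that were flattened into a single
# local FSDP parameter 0.
param_id_map = {0: 0, 1: 0}

# Consolidated, unflattened optimizer state keyed by global id, plus a shared
# non-tensor entry ("step") that is copied rather than concatenated.
sd = {
    "state": {
        0: {"exp_avg": torch.zeros(2, 3), "step": 10},
        1: {"exp_avg": torch.zeros(4), "step": 10},
    },
    "param_groups": [{"params": [0, 1]}],
}

new_state = {local_id: {} for local_id in set(param_id_map.values())}
non_tensor_state = {}
for global_id, buffers in sd["state"].items():
    local_id = param_id_map[global_id]
    for buffer_name, p in buffers.items():
        if torch.is_tensor(p):
            new_state[local_id].setdefault(buffer_name, []).append(p.reshape(-1))
        else:
            non_tensor_state[buffer_name] = p
for local_id, state in new_state.items():
    for buffer_name, tensors in state.items():
        new_state[local_id][buffer_name] = torch.cat(tensors)
    new_state[local_id].update(non_tensor_state)

assert new_state[0]["exp_avg"].shape == (10,)  # 2*3 + 4 elements, flattened then concatenated
assert new_state[0]["step"] == 10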
@@ -109,6 +109,7 @@ def _unflatten_optim_state(
     # If the constant state is the same as the combined state, copy it N times, no unflattening needed.
     unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))}
     if non_tensor_state[0].keys() == combined_state[0].keys():
         return unflat_state, global_to_local_id
@@ -134,24 +135,33 @@ def _unflatten_optim_state(
     return unflat_state, global_to_local_id
 
 
-def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
+def build_unflat_state_dict(
+    instance_list: List[torch.nn.Module], world_optim_states: List[Dict], uncollected_opt_state: Dict[int, Dict]
+) -> Dict:
     """Build an unflattened optimizer state dict given a list of flattened optimizer state dicts from each rank."""
     world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
     assert all(len(s) == len(instance_list) for s in world_pad_info)
     assert all(len(s[0]) == 1 for s in world_pad_info)
 
     # Since there are no tensors in param_groups, deepcopy is fine
     param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
     assert len(param_groups) == 1
 
     # Aggregate from a list of dictionaries to a dictionary of lists
     combined_state = _combine_state([x["state"] for x in world_optim_states])
+    for local_id, v in uncollected_opt_state.items():
+        assert local_id not in combined_state
+        combined_state[local_id] = {}
+        for buffer_name, tensor in v.items():
+            combined_state[local_id][buffer_name] = [tensor]
     del world_optim_states
 
     # local ids are in the current state, global_ids will be in returned state.
     unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
     num_params = sum([len(m._param_numels) for m in instance_list])  # type: ignore
-    param_groups[0]["params"] = list(range(num_params))  # This could be a large list. #TODO: is it essential
+    param_groups[0]["params"] = list(range(num_params))
     return {
         "state": dict(sorted(unflat_state.items())),  # NOTE: this is probably already sorted
         "param_id_map": global_to_local_id,
         "param_groups": param_groups,
+        "uncollected_local_ids": list(uncollected_opt_state.keys()),
     }
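For reference, a sketch of the shape of the consolidated dict returned by `build_unflat_state_dict`. The keys follow the return statement above; the concrete values (three global params across two FSDP instances, one of which was not collected) are invented for illustration.

# Illustrative only: local (flattened) param 1 was an uncollected expert.
consolidated_osd = {
    "state": {0: {"exp_avg": ...}, 1: {"exp_avg": ...}, 2: {"exp_avg": ...}},  # keyed by global id
    "param_id_map": {0: 0, 1: 0, 2: 1},   # global id -> local (flattened) id
    "param_groups": [{"params": [0, 1, 2]}],
    "uncollected_local_ids": [1],         # local ids whose entries are rank-0 placeholders
}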
fairscale/nn/data_parallel/fully_sharded_data_parallel.py

@@ -157,6 +157,12 @@ class FullyShardedDataParallel(nn.Module):
             device, the param's device will be used. If not given and module
             params are on CPU, then the current CUDA device (as indicated by
             ``torch.cuda.current_device()``) will be used.
+        no_broadcast_optim_state: (bool, Optional)
+            do not broadcast this module's optimizer state when ``gather_full_optim_state_dict``
+            is called. If you set this to true, you are expected to overwrite the relevant state
+            entries of the returned optimizer state dict with the proper state at each rank.
+            This is useful for situations, like Mixture of Experts, where all but a few
+            parameters can fit on one node.
+            Default: False
     """
 
     def __init__(
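A hedged usage sketch of the new flag (assumes `torch.distributed` is already initialized; the module name and layer sizes are hypothetical, not part of this diff): parameters that differ per rank, such as Mixture-of-Experts experts, are wrapped in their own FSDP instance with `no_broadcast_optim_state=True` so their optimizer state is left out of the gathered dict and restored per rank instead.

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

class MoEBlock(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(16, 16)
        # A size-1 process group (as in the MixtureOfExperts test fixture below)
        # means the expert is not sharded; its optimizer state stays local.
        expert_group = torch.distributed.new_group([torch.distributed.get_rank()])
        self.expert = FSDP(nn.Linear(16, 4), expert_group, no_broadcast_optim_state=True)

    def forward(self, x):
        return self.expert(self.shared(x))

model = FSDP(MoEBlock())  # outer wrapper uses the default (global) process group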
@@ -173,6 +179,7 @@ class FullyShardedDataParallel(nn.Module):
         move_grads_to_cpu: Optional[bool] = None,
         bucket_cap_mb: int = 25,
         compute_device: Optional[torch.device] = None,
+        no_broadcast_optim_state: Optional[bool] = False,
     ):
         super().__init__()
         self.process_group = process_group or dist.new_group()
@@ -187,6 +194,8 @@ class FullyShardedDataParallel(nn.Module):
         self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
+        self.uncollected_opt_state: Dict[int, Dict] = {}
+        self.no_broadcast_optim_state = no_broadcast_optim_state
         self.gradient_predivide_factor: int = self.get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
@@ -849,6 +858,12 @@ class FullyShardedDataParallel(nn.Module):
                 if m.process_group != self.process_group:
                     self.children_share_process_group = False
 
+                # If a child instance is in its own (smaller) world, that was probably an attempt to avoid OOM.
+                # Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
+                m.no_broadcast_optim_state = m.no_broadcast_optim_state or (
+                    (m.world_size == 1) and (m.world_size < self.world_size) and (m.process_group != self.process_group)
+                )
+
     def _setup_streams(self) -> None:
         """Create streams to overlap data transfer and computation."""
         if len(self._streams) > 0 or not self._is_root:
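Read in isolation, the auto-detection above amounts to the following predicate (a standalone sketch with hypothetical argument names): a child that was given its own process group of size 1 inside a larger world is assumed to be deliberately unsharded, so its optimizer state is not gathered.

def infer_no_broadcast(explicit_flag: bool, child_world_size: int, parent_world_size: int,
                       same_process_group: bool) -> bool:
    # Mirrors the condition above: keep an explicit opt-in, or opt in automatically
    # for world-size-1 children living in their own (smaller) process group.
    return explicit_flag or (
        (child_world_size == 1) and (child_world_size < parent_world_size) and not same_process_group
    )

assert infer_no_broadcast(False, 1, 8, same_process_group=False)      # expert wrapped in its own group
assert not infer_no_broadcast(False, 8, 8, same_process_group=True)   # normally sharded child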
@@ -1391,7 +1406,7 @@ class FullyShardedDataParallel(nn.Module):
         dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
         for rank in range(self.world_size):
             if rank == self.rank:
-                sd = optim.state_dict()
+                sd = self._remove_uncollectable_params_from_optim_state_dict(optim.state_dict())
                 sd["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
             else:
                 sd = dummy_tensor  # type: ignore
@@ -1428,8 +1443,11 @@ class FullyShardedDataParallel(nn.Module):
         if self.rank != recipient_rank and recipient_rank is not None:
             return None
         # Unify the shard states by concatenating tensors and unflattening params
-        new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states)
+        # TODO: check if this code supports nested instances with different world size
+        new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states, self.uncollected_opt_state)
+        self.uncollected_opt_state = {}
+        assert "uncollected_local_ids" in new_state_dict
         return new_state_dict
 
     @property
@@ -1437,6 +1455,17 @@ class FullyShardedDataParallel(nn.Module):
         """Returns all fsdp modules in self.modules() including self."""
         return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
 
+    def _remove_uncollectable_params_from_optim_state_dict(self, osd: Dict) -> Dict:
+        uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
+        new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
+        if self.rank == 0:
+            # Save placeholders for uncollected opt state to keep the same unflat OSD format.
+            self.uncollected_opt_state = {k: v for k, v in osd["state"].items() if k in uncollected_ids}
+
+        pg = copy.deepcopy(osd["param_groups"])
+        new_dct["param_groups"] = pg
+        return new_dct
+
     def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
         """Get the portion of the optimizer state dict associated with the shard
@@ -1451,18 +1480,19 @@ class FullyShardedDataParallel(nn.Module):
         """
         # Assert nesting is the same as it was at save time
         instance_list = self._fsdp_instances
         assert all(
             x.world_size == self.world_size for x in instance_list
         ), "all nested instances must have same world size"
         ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
+        ids_not_to_shard = copy.deepcopy(full_optim_state_dict["uncollected_local_ids"])
         if self.flatten_parameters:
             full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
-            assert len(full_optim_state_dict["state"]) in (0, len(instance_list))
+            assert len(full_optim_state_dict["state"]) in (
+                0,
+                len(instance_list),
+            ), f'{len(full_optim_state_dict["state"])}, {len(instance_list)}'
 
         # get the portion of dict associated with the shard, in place
         for id, s in full_optim_state_dict["state"].items():
             for k, v in s.items():
-                if torch.is_tensor(v):
+                if torch.is_tensor(v) and id not in ids_not_to_shard:
                     v_shard, _ = self._get_shard(v)
                 else:
                     v_shard = v  # don't shard entries that are not tensors
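Putting the two halves together, a hedged end-to-end sketch of the intended round trip (assumes `torch.distributed` is initialized, `model` is a root FSDP instance, and `optimizer` was built over `model.parameters()`; the per-rank expert patching is only indicated in comments):

import torch

# Consolidate on rank 0. For children flagged with no_broadcast_optim_state=True,
# the gathered dict only contains rank-0 placeholder entries, and their local ids
# are listed under "uncollected_local_ids". Non-recipient ranks receive None.
full_osd = model.gather_full_optim_state_dict(optimizer, recipient_rank=0)
if torch.distributed.get_rank() == 0:
    torch.save(full_osd, "consolidated_optim_state.pt")

# Later, on every rank: reload the consolidated dict. Before sharding, the caller
# is expected to overwrite the placeholder entries (those mapping to
# full_osd["uncollected_local_ids"]) with this rank's own expert state; that step
# is application specific and omitted here.
full_osd = torch.load("consolidated_optim_state.pt")
shard_osd = model.get_shard_from_optim_state_dict(full_osd)
optimizer.load_state_dict(shard_osd)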
tests/nn/data_parallel/test_fsdp.py
@@ -782,6 +782,7 @@ class MixtureOfExperts(NestedWrappedModule):
         # "expert" params are different on each rank
         torch.manual_seed(42 + group.rank())
         expert = nn.Linear(16, 4)
+        self.num_expert_params = sum([p.numel() for p in expert.parameters()])
         for p in expert.parameters():
             p.expert = True
@@ -795,7 +796,7 @@ class MixtureOfExperts(NestedWrappedModule):
         if wrapper_config is not None:
             # we create a process group of size 1 for the expert params
-            expert_group = torch.distributed.new_group([group.rank()])
+            expert_group = torch.distributed.new_group([group.rank()])  # world size 1 means no shard
 
             expert = FullyShardedDataParallel(expert, expert_group, **wrapper_config)
             shared = FullyShardedDataParallel(shared, group, **wrapper_config)
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
@@ -16,7 +16,7 @@ from fairscale.utils.testing import objects_are_equal
 from .test_fsdp import (
     DistributedTest,
     DummyProcessGroup,
     NestedWrappedModule,
+    MixtureOfExperts,
     TransformerWithSharedParams,
     rename_test,
     spawn_and_init,
@@ -36,11 +36,12 @@ def assert_equal(a, b):
 class TestOptimizerUtils(DistributedTest):
     @parameterized.expand(
-        [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True]],
+        [[functools.partial(SGD, momentum=0.9), True], [SGD, False], [Adam, False], [Adadelta, True], [Adam, True]],
         name_func=rename_test,
     )
     def test_consolidate_optimizer(self, optim_fn, transformer):
         config = {"mixed_precision": True, "flatten_parameters": True}
+        config["compute_dtype"] = torch.float32
         test_fn = functools.partial(
             self._test_consolidated_optimizer, config, optim_fn=optim_fn, transformer=transformer
         )
@@ -53,11 +54,11 @@ class TestOptimizerUtils(DistributedTest):
         # Establish reference behavior.
         if transformer:
-            unwrapped_model = TransformerWithSharedParams(group, wrapper_config=config).cuda()
             fsdp = self.get_wrapped_model(group, config=config).cuda()
+            unwrapped_model = TransformerWithSharedParams(group).cuda()
         else:
-            fsdp = FullyShardedDataParallel(NestedWrappedModule(group, wrapper_config=config), group, **config).cuda()
-            unwrapped_model = NestedWrappedModule(group, wrapper_config=None).cuda()
+            unwrapped_model = MixtureOfExperts(group, wrapper_config=None).cuda()
+            fsdp = FullyShardedDataParallel(MixtureOfExperts(group, wrapper_config=config)).cuda()
 
         try:
             fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,)
@@ -68,19 +69,24 @@ class TestOptimizerUtils(DistributedTest):
         fsdp_optim.zero_grad()
         optim_unwrapped.zero_grad()
-        x = fsdp.module.get_input(torch.device("cuda"))
-        output = fsdp(*x)
-        loss = fsdp.module.get_loss(x, output).to("cuda")
-        fsdp.module.run_backward(loss)
-        fsdp_optim.step()
-        output = unwrapped_model(*x)
-        loss = unwrapped_model.get_loss(x, output)
-        unwrapped_model.run_backward(loss)
-        optim_unwrapped.step()
+        with torch.cuda.amp.autocast(enabled=True):
+            x = fsdp.module.get_input(torch.device("cuda"))
+            output = fsdp(*x)
+            loss = fsdp.module.get_loss(x, output).to("cuda")
+            fsdp.module.run_backward(loss)
+            fsdp_optim.step()
+
+            output = unwrapped_model(*x)
+            loss = unwrapped_model.get_loss(x, output)
+            unwrapped_model.run_backward(loss)
+            optim_unwrapped.step()
         unwrapped_sd = optim_unwrapped.state_dict()
 
+        if not transformer:
+            no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
+            assert len(no_broadcast_children) == 1
+            assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
+
         tstart = time()
         sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
         duration = time() - tstart
@@ -88,7 +94,14 @@ class TestOptimizerUtils(DistributedTest):
         assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"
 
         if fsdp.rank > 0:
             assert sd is None
             return
 
+        unflat_state = sd["state"]
+        assert "uncollected_local_ids" in sd
+        shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
+        shard_sd = recursive_copy_to_device(shard_sd, non_blocking=False, device="cpu")
+        state_after_get_shard = sd["state"]
+        assert objects_are_equal(unflat_state, state_after_get_shard)  # no side effects.
+
         assert_equal(len(sd["state"]), len(unwrapped_sd["state"]))
         assert_equal(len(sd["param_groups"][0]["params"]), len(unwrapped_sd["param_groups"][0]["params"]))
@@ -97,18 +110,21 @@ class TestOptimizerUtils(DistributedTest):
sum
([
first_tensor_numel
(
v
)
for
k
,
v
in
unwrapped_sd
[
"state"
].
items
()]),
)
shard_sd
=
fsdp
.
get_shard_from_optim_state_dict
(
sd
)
original_shard_sd
=
fsdp_optim
.
state_dict
()
assert_equal
(
len
(
shard_sd
[
"state"
]),
len
(
original_shard_sd
[
"state"
]))
assert_equal
(
shard_sd
.
keys
(),
original_shard_sd
.
keys
())
original_shard_sd
=
recursive_copy_to_device
(
original_shard_sd
,
non_blocking
=
False
,
device
=
"cpu"
)
# Before asserting that the dicts are equal, we check keys individually to allow nice tracebacks.
assert_equal
(
[
first_tensor_numel
(
v
)
for
k
,
v
in
shard_sd
[
"state"
].
items
()],
[
first_tensor_numel
(
v
)
for
k
,
v
in
original_shard_sd
[
"state"
].
items
()],
)
assert_equal
(
sum
([
first_tensor_numel
(
v
)
for
k
,
v
in
shard_sd
[
"
state"
].
items
()]
)
,
sum
([
first_tensor_numel
(
v
)
for
k
,
v
in
original_shard_sd
[
"
state"
].
items
()]
)
,
[
v
for
k
,
v
in
shard_sd
[
"
param_groups"
][
0
].
items
()],
[
v
for
k
,
v
in
original_shard_sd
[
"
param_groups"
][
0
].
items
()],
)
assert
objects_are_equal
(
shard_sd
,
original_shard_sd
)
assert
objects_are_equal
(
shard_sd
[
"state"
],
original_shard_sd
[
"state"
])
assert
objects_are_equal
({
k
:
shard_sd
[
k
]
for
k
in
original_shard_sd
},
original_shard_sd
)
def
test_named_params_ordering
(
self
):
"""Test assumption of consolidate_optimizer_state_dict"""
...
...