Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
a6549be7
Unverified
Commit
a6549be7
authored
Apr 08, 2021
by
Sam Shleifer
Committed by
GitHub
Apr 08, 2021
Browse files
[fix] [FSDP] optim state dict should be completely on CPU (#590)
parent
ce1f2cea
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
6 deletions
+34
-6
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+16
-3
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp.py
+3
-2
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
+15
-1
No files found.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
View file @
a6549be7
...
...
@@ -8,6 +8,7 @@ import copy
from
enum
import
Enum
,
auto
import
functools
from
math
import
inf
import
time
import
traceback
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Generator
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
,
Union
...
...
@@ -208,6 +209,7 @@ class FullyShardedDataParallel(nn.Module):
self
.
gradient_postdivide_factor
:
float
=
self
.
world_size
/
self
.
gradient_predivide_factor
self
.
numel_padded_per_param
:
List
[
int
]
=
[]
self
.
_tstart
=
time
.
time
()
if
self
.
fp32_reduce_scatter
and
not
self
.
mixed_precision
:
raise
ValueError
(
"fp32_reduce_scatter requires mixed_precision=True"
)
...
...
@@ -1414,7 +1416,6 @@ class FullyShardedDataParallel(nn.Module):
if
should_collect_state
:
assert
isinstance
(
sd
,
dict
),
f
"
{
self
.
rank
}
received
{
type
(
sd
)
}
from
{
rank
}
, expected dict"
all_states
.
append
(
recursive_copy_to_device
(
sd
,
non_blocking
=
False
,
device
=
torch
.
device
(
"cpu"
)))
return
all_states
def
gather_full_optim_state_dict
(
...
...
@@ -1459,8 +1460,12 @@ class FullyShardedDataParallel(nn.Module):
uncollected_ids
=
[
i
for
i
,
m
in
enumerate
(
self
.
_fsdp_instances
)
if
m
.
no_broadcast_optim_state
]
new_dct
=
{
"state"
:
{
k
:
v
for
k
,
v
in
osd
[
"state"
].
items
()
if
k
not
in
uncollected_ids
}}
if
self
.
rank
==
0
:
# Save placeholders for uncollected opt state to keep the same unflat OSD format.
self
.
uncollected_opt_state
=
{
k
:
v
for
k
,
v
in
osd
[
"state"
].
items
()
if
k
in
uncollected_ids
}
# Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
self
.
uncollected_opt_state
=
{
k
:
recursive_copy_to_device
(
v
,
non_blocking
=
False
,
device
=
torch
.
device
(
"cpu"
))
for
k
,
v
in
osd
[
"state"
].
items
()
if
k
in
uncollected_ids
}
pg
=
copy
.
deepcopy
(
osd
[
"param_groups"
])
new_dct
[
"param_groups"
]
=
pg
...
...
@@ -1500,6 +1505,14 @@ class FullyShardedDataParallel(nn.Module):
return
full_optim_state_dict
def _print_r0(self, msg: str) -> None:
    """Print ``msg`` plus CUDA memory and timing stats, on rank 0 only.

    The printed line contains the value of ``torch.cuda.memory_allocated()``
    and ``torch.cuda.max_memory_allocated()``, each divided by ``1024 ** 3``,
    followed by the number of seconds between ``self._tstart`` and now.
    Instances whose ``rank`` is not 0 print nothing.

    Args:
        msg: Text placed at the start of the printed line.
    """
    # Guard clause: every rank except 0 stays silent.
    if self.rank != 0:
        return

    bytes_per_gb = 1024 ** 3
    current_gb = torch.cuda.memory_allocated() / bytes_per_gb
    peak_gb = torch.cuda.max_memory_allocated() / bytes_per_gb
    elapsed = time.time() - self._tstart
    print(f"{msg} cur={current_gb: .4f} GB, max={peak_gb: .4f} GB, t={elapsed: .1f}")
def
_get_default_cuda_device
(
module
:
nn
.
Module
)
->
torch
.
device
:
"""Try to infer CUDA device from module parameters."""
...
...
tests/nn/data_parallel/test_fsdp.py
View file @
a6549be7
...
...
@@ -627,14 +627,15 @@ class MixtureOfExperts(NestedWrappedModule):
# "expert" params are different on each rank
torch
.
manual_seed
(
42
+
group
.
rank
())
expert
=
nn
.
Linear
(
16
,
4
)
d_expert
=
16
expert
=
nn
.
Linear
(
d_expert
,
4
)
self
.
num_expert_params
=
sum
([
p
.
numel
()
for
p
in
expert
.
parameters
()])
for
p
in
expert
.
parameters
():
p
.
expert
=
True
# everything else is shared
torch
.
manual_seed
(
0
)
shared
=
nn
.
Linear
(
4
,
16
)
shared
=
nn
.
Linear
(
4
,
d_expert
)
if
checkpoint_act
:
expert
=
checkpoint_wrapper
(
expert
)
...
...
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
View file @
a6549be7
...
...
@@ -86,16 +86,30 @@ class TestOptimizerUtils(DistributedTest):
no_broadcast_children
=
[
x
for
x
in
fsdp
.
_fsdp_instances
if
x
.
no_broadcast_optim_state
]
assert
len
(
no_broadcast_children
)
==
1
assert
fsdp
.
_fsdp_instances
[
-
1
].
no_broadcast_optim_state
torch
.
cuda
.
empty_cache
()
cuda_gb_before
=
torch
.
cuda
.
memory_stats
(
fsdp
.
rank
)[
"allocated_bytes.all.current"
]
/
1024
**
3
tstart
=
time
()
sd
=
fsdp
.
gather_full_optim_state_dict
(
fsdp_optim
,
recipient_rank
=
0
)
duration
=
time
()
-
tstart
# Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
assert
duration
<
fsdp
.
world_size
,
f
"gather optim state took
{
duration
}
seconds, suspect change in _consolidate"
cuda_gb_after
=
torch
.
cuda
.
memory_stats
(
fsdp
.
rank
)[
"allocated_bytes.all.current"
]
/
1024
**
3
mem_usg_gb
=
cuda_gb_after
-
cuda_gb_before
assert
mem_usg_gb
==
0
,
f
"gather_full_optim_state_dict used
{
mem_usg_gb
:.
2
f
}
CUDA GB, max allowed is 0"
assert
cuda_gb_after
>
0
,
"got 0 memory usage, logging is broken"
if
fsdp
.
rank
>
0
:
assert
sd
is
None
return
# assert whole state dict on CPU
for
k
,
v
in
sd
[
"state"
].
items
():
for
buffer_name
,
t
in
v
.
items
():
if
torch
.
is_tensor
(
t
):
msg
=
f
"got device
{
t
.
device
}
for
{
k
}
:
{
buffer_name
}
. expected CPU"
assert
t
.
device
==
torch
.
device
(
"cpu"
),
msg
unflat_state
=
sd
[
"state"
]
assert
"uncollected_local_ids"
in
sd
shard_sd
=
fsdp
.
get_shard_from_optim_state_dict
(
sd
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment