OpenDAS / fairscale · commit a6549be7 (unverified)
Authored Apr 08, 2021 by Sam Shleifer; committed by GitHub Apr 08, 2021
Parent: ce1f2cea

[fix] [FSDP] optim state dict should be completely on CPU (#590)

Showing 3 changed files with 34 additions and 6 deletions (+34, -6):
fairscale/nn/data_parallel/fully_sharded_data_parallel.py   +16 -3
tests/nn/data_parallel/test_fsdp.py                          +3 -2
tests/nn/data_parallel/test_fsdp_optimizer_utils.py          +15 -1
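The fix is easy to state as an invariant: after gather_full_optim_state_dict returns on the recipient rank, no tensor anywhere in the returned dict should still live on a CUDA device. A minimal standalone checker in that spirit (the helper name is illustrative, not part of this commit; it assumes a torch.optim-style state dict):

import torch

def assert_optim_state_on_cpu(osd: dict) -> None:
    # Walk the {param_id: {buffer_name: tensor_or_scalar}} structure of a
    # torch.optim-style state dict and fail on any tensor not on the CPU.
    for param_id, buffers in osd["state"].items():
        for buffer_name, value in buffers.items():
            if torch.is_tensor(value):
                assert value.device == torch.device("cpu"), (
                    f"{param_id}/{buffer_name} is on {value.device}, expected cpu"
                )

The new test at the bottom of this commit performs exactly this walk inline.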
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -8,6 +8,7 @@ import copy
 from enum import Enum, auto
 import functools
 from math import inf
+import time
 import traceback
 from typing import TYPE_CHECKING, Any, Callable, Dict, Generator, List, NamedTuple, Optional, Set, Tuple, Union
@@ -208,6 +209,7 @@ class FullyShardedDataParallel(nn.Module):
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
         self.numel_padded_per_param: List[int] = []
+        self._tstart = time.time()

         if self.fp32_reduce_scatter and not self.mixed_precision:
             raise ValueError("fp32_reduce_scatter requires mixed_precision=True")
@@ -1414,7 +1416,6 @@ class FullyShardedDataParallel(nn.Module):
             if should_collect_state:
                 assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
                 all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))
         return all_states

     def gather_full_optim_state_dict(
@@ -1459,8 +1460,12 @@ class FullyShardedDataParallel(nn.Module):
         uncollected_ids = [i for i, m in enumerate(self._fsdp_instances) if m.no_broadcast_optim_state]
         new_dct = {"state": {k: v for k, v in osd["state"].items() if k not in uncollected_ids}}
         if self.rank == 0:
-            # Save placeholders for uncollected opt state to keep the same unflat OSD format.
-            self.uncollected_opt_state = {k: v for k, v in osd["state"].items() if k in uncollected_ids}
+            # Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
+            self.uncollected_opt_state = {
+                k: recursive_copy_to_device(v, non_blocking=False, device=torch.device("cpu"))
+                for k, v in osd["state"].items()
+                if k in uncollected_ids
+            }

         pg = copy.deepcopy(osd["param_groups"])
         new_dct["param_groups"] = pg
@@ -1500,6 +1505,14 @@ class FullyShardedDataParallel(nn.Module):
         return full_optim_state_dict

+    def _print_r0(self, msg: str) -> None:
+        """Debugging utility to print memory usage stats nicely on rank 0"""
+        if self.rank == 0:
+            gb_denom = 1024 ** 3
+            print(
+                f"{msg} cur={torch.cuda.memory_allocated() / gb_denom:.4f} GB, max={torch.cuda.max_memory_allocated() / gb_denom:.4f} GB, t={time.time() - self._tstart:.1f}"
+            )
+

 def _get_default_cuda_device(module: nn.Module) -> torch.device:
     """Try to infer CUDA device from module parameters."""
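Both fixed call sites lean on recursive_copy_to_device from fairscale's utilities. As a rough sketch of what such a helper must do (an illustrative re-implementation assuming only tensor, list, tuple, and dict containers; fairscale's actual version lives in its utils module and may differ):

import torch
from typing import Any

def copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
    # Tensors are moved; containers are rebuilt with moved leaves;
    # everything else (ints, floats, None, ...) passes through unchanged.
    if torch.is_tensor(value):
        return value.to(device, non_blocking=non_blocking)
    if isinstance(value, list):
        return [copy_to_device(v, non_blocking, device) for v in value]
    if isinstance(value, tuple):
        return tuple(copy_to_device(v, non_blocking, device) for v in value)
    if isinstance(value, dict):
        return {k: copy_to_device(v, non_blocking, device) for k, v in value.items()}
    return value

The new _print_r0 helper is a pure debugging aid: on rank 0 it prints current and peak CUDA memory in GB plus seconds elapsed since construction, which is handy when chasing exactly this kind of stray-GPU-copy bug.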
tests/nn/data_parallel/test_fsdp.py
@@ -627,14 +627,15 @@ class MixtureOfExperts(NestedWrappedModule):
         # "expert" params are different on each rank
         torch.manual_seed(42 + group.rank())
-        expert = nn.Linear(16, 4)
+        d_expert = 16
+        expert = nn.Linear(d_expert, 4)
         self.num_expert_params = sum([p.numel() for p in expert.parameters()])
         for p in expert.parameters():
             p.expert = True

         # everything else is shared
         torch.manual_seed(0)
-        shared = nn.Linear(4, 16)
+        shared = nn.Linear(4, d_expert)

         if checkpoint_act:
             expert = checkpoint_wrapper(expert)
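This test change is a pure readability refactor: the shared layer's output width must equal the expert layer's input width, so the magic number 16 becomes a single d_expert constant used in both places. In isolation:

import torch.nn as nn

d_expert = 16
shared = nn.Linear(4, d_expert)   # emits d_expert-dim features...
expert = nn.Linear(d_expert, 4)   # ...which the expert consumes, so the widths must stay in sync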
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
@@ -86,16 +86,30 @@ class TestOptimizerUtils(DistributedTest):
         no_broadcast_children = [x for x in fsdp._fsdp_instances if x.no_broadcast_optim_state]
         assert len(no_broadcast_children) == 1
         assert fsdp._fsdp_instances[-1].no_broadcast_optim_state
         torch.cuda.empty_cache()
+        cuda_gb_before = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
         tstart = time()
         sd = fsdp.gather_full_optim_state_dict(fsdp_optim, recipient_rank=0)
         duration = time() - tstart
         # Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
         assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"

+        cuda_gb_after = torch.cuda.memory_stats(fsdp.rank)["allocated_bytes.all.current"] / 1024 ** 3
+        mem_usg_gb = cuda_gb_after - cuda_gb_before
+        assert mem_usg_gb == 0, f"gather_full_optim_state_dict used {mem_usg_gb:.2f} CUDA GB, max allowed is 0"
+        assert cuda_gb_after > 0, "got 0 memory usage, logging is broken"
+
         if fsdp.rank > 0:
             assert sd is None
             return
+
+        # assert whole state dict on CPU
+        for k, v in sd["state"].items():
+            for buffer_name, t in v.items():
+                if torch.is_tensor(t):
+                    msg = f"got device {t.device} for {k}: {buffer_name}. expected CPU"
+                    assert t.device == torch.device("cpu"), msg
+
         unflat_state = sd["state"]
         assert "uncollected_local_ids" in sd
         shard_sd = fsdp.get_shard_from_optim_state_dict(sd)
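The regression guard added here reads PyTorch's caching-allocator counters directly, which is a cheap way to assert that an operation performs no net GPU allocation. The same measurement pattern in isolation (the run_operation placeholder is illustrative; this requires a CUDA build of PyTorch):

import torch

def cuda_allocated_gb(device=None) -> float:
    # Bytes currently held by the CUDA caching allocator, converted to GB.
    return torch.cuda.memory_stats(device)["allocated_bytes.all.current"] / 1024 ** 3

before = cuda_allocated_gb()
run_operation()  # e.g. gathering the full optimizer state dict
after = cuda_allocated_gb()
assert after - before == 0, f"operation held on to {after - before:.2f} GB of GPU memory"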