Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
a6549be7
Unverified
Commit
a6549be7
authored
Apr 08, 2021
by
Sam Shleifer
Committed by
GitHub
Apr 08, 2021
Browse files
[fix] [FSDP] optim state dict should be completely on CPU (#590)
parent
ce1f2cea
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
6 deletions
+34
-6
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+16
-3
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp.py
+3
-2
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
+15
-1
No files found.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
View file @
a6549be7
...
@@ -8,6 +8,7 @@ import copy
...
@@ -8,6 +8,7 @@ import copy
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
import
functools
import
functools
from
math
import
inf
from
math
import
inf
import
time
import
traceback
import
traceback
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Generator
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Generator
,
List
,
NamedTuple
,
Optional
,
Set
,
Tuple
,
Union
...
@@ -208,6 +209,7 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -208,6 +209,7 @@ class FullyShardedDataParallel(nn.Module):
self
.
gradient_postdivide_factor
:
float
=
self
.
world_size
/
self
.
gradient_predivide_factor
self
.
gradient_postdivide_factor
:
float
=
self
.
world_size
/
self
.
gradient_predivide_factor
self
.
numel_padded_per_param
:
List
[
int
]
=
[]
self
.
numel_padded_per_param
:
List
[
int
]
=
[]
self
.
_tstart
=
time
.
time
()
if
self
.
fp32_reduce_scatter
and
not
self
.
mixed_precision
:
if
self
.
fp32_reduce_scatter
and
not
self
.
mixed_precision
:
raise
ValueError
(
"fp32_reduce_scatter requires mixed_precision=True"
)
raise
ValueError
(
"fp32_reduce_scatter requires mixed_precision=True"
)
...
@@ -1414,7 +1416,6 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -1414,7 +1416,6 @@ class FullyShardedDataParallel(nn.Module):
if
should_collect_state
:
if
should_collect_state
:
assert
isinstance
(
sd
,
dict
),
f
"
{
self
.
rank
}
received
{
type
(
sd
)
}
from
{
rank
}
, expected dict"
assert
isinstance
(
sd
,
dict
),
f
"
{
self
.
rank
}
received
{
type
(
sd
)
}
from
{
rank
}
, expected dict"
all_states
.
append
(
recursive_copy_to_device
(
sd
,
non_blocking
=
False
,
device
=
torch
.
device
(
"cpu"
)))
all_states
.
append
(
recursive_copy_to_device
(
sd
,
non_blocking
=
False
,
device
=
torch
.
device
(
"cpu"
)))
return
all_states
return
all_states
def
gather_full_optim_state_dict
(
def
gather_full_optim_state_dict
(
...
@@ -1459,8 +1460,12 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -1459,8 +1460,12 @@ class FullyShardedDataParallel(nn.Module):
uncollected_ids
=
[
i
for
i
,
m
in
enumerate
(
self
.
_fsdp_instances
)
if
m
.
no_broadcast_optim_state
]
uncollected_ids
=
[
i
for
i
,
m
in
enumerate
(
self
.
_fsdp_instances
)
if
m
.
no_broadcast_optim_state
]
new_dct
=
{
"state"
:
{
k
:
v
for
k
,
v
in
osd
[
"state"
].
items
()
if
k
not
in
uncollected_ids
}}
new_dct
=
{
"state"
:
{
k
:
v
for
k
,
v
in
osd
[
"state"
].
items
()
if
k
not
in
uncollected_ids
}}
if
self
.
rank
==
0
:
if
self
.
rank
==
0
:
# Save placeholders for uncollected opt state to keep the same unflat OSD format.
# Save placeholders for uncollected opt state to keep the same unflat OSD format, and move them to CPU.
self
.
uncollected_opt_state
=
{
k
:
v
for
k
,
v
in
osd
[
"state"
].
items
()
if
k
in
uncollected_ids
}
self
.
uncollected_opt_state
=
{
k
:
recursive_copy_to_device
(
v
,
non_blocking
=
False
,
device
=
torch
.
device
(
"cpu"
))
for
k
,
v
in
osd
[
"state"
].
items
()
if
k
in
uncollected_ids
}
pg
=
copy
.
deepcopy
(
osd
[
"param_groups"
])
pg
=
copy
.
deepcopy
(
osd
[
"param_groups"
])
new_dct
[
"param_groups"
]
=
pg
new_dct
[
"param_groups"
]
=
pg
...
@@ -1500,6 +1505,14 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -1500,6 +1505,14 @@ class FullyShardedDataParallel(nn.Module):
return
full_optim_state_dict
return
full_optim_state_dict
def
_print_r0
(
self
,
msg
:
str
)
->
None
:
"""Debugging utility to print memory usage stats nicely on rank 0"""
if
self
.
rank
==
0
:
gb_denom
=
1024
**
3
print
(
f
"
{
msg
}
cur=
{
torch
.
cuda
.
memory_allocated
()
/
gb_denom
:
.
4
f
}
GB, max=
{
torch
.
cuda
.
max_memory_allocated
()
/
gb_denom
:
.
4
f
}
GB, t=
{
time
.
time
()
-
self
.
_tstart
:
.
1
f
}
"
)
def
_get_default_cuda_device
(
module
:
nn
.
Module
)
->
torch
.
device
:
def
_get_default_cuda_device
(
module
:
nn
.
Module
)
->
torch
.
device
:
"""Try to infer CUDA device from module parameters."""
"""Try to infer CUDA device from module parameters."""
...
...
tests/nn/data_parallel/test_fsdp.py
View file @
a6549be7
...
@@ -627,14 +627,15 @@ class MixtureOfExperts(NestedWrappedModule):
...
@@ -627,14 +627,15 @@ class MixtureOfExperts(NestedWrappedModule):
# "expert" params are different on each rank
# "expert" params are different on each rank
torch
.
manual_seed
(
42
+
group
.
rank
())
torch
.
manual_seed
(
42
+
group
.
rank
())
expert
=
nn
.
Linear
(
16
,
4
)
d_expert
=
16
expert
=
nn
.
Linear
(
d_expert
,
4
)
self
.
num_expert_params
=
sum
([
p
.
numel
()
for
p
in
expert
.
parameters
()])
self
.
num_expert_params
=
sum
([
p
.
numel
()
for
p
in
expert
.
parameters
()])
for
p
in
expert
.
parameters
():
for
p
in
expert
.
parameters
():
p
.
expert
=
True
p
.
expert
=
True
# everything else is shared
# everything else is shared
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
shared
=
nn
.
Linear
(
4
,
16
)
shared
=
nn
.
Linear
(
4
,
d_expert
)
if
checkpoint_act
:
if
checkpoint_act
:
expert
=
checkpoint_wrapper
(
expert
)
expert
=
checkpoint_wrapper
(
expert
)
...
...
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
View file @
a6549be7
...
@@ -86,16 +86,30 @@ class TestOptimizerUtils(DistributedTest):
...
@@ -86,16 +86,30 @@ class TestOptimizerUtils(DistributedTest):
no_broadcast_children
=
[
x
for
x
in
fsdp
.
_fsdp_instances
if
x
.
no_broadcast_optim_state
]
no_broadcast_children
=
[
x
for
x
in
fsdp
.
_fsdp_instances
if
x
.
no_broadcast_optim_state
]
assert
len
(
no_broadcast_children
)
==
1
assert
len
(
no_broadcast_children
)
==
1
assert
fsdp
.
_fsdp_instances
[
-
1
].
no_broadcast_optim_state
assert
fsdp
.
_fsdp_instances
[
-
1
].
no_broadcast_optim_state
torch
.
cuda
.
empty_cache
()
cuda_gb_before
=
torch
.
cuda
.
memory_stats
(
fsdp
.
rank
)[
"allocated_bytes.all.current"
]
/
1024
**
3
tstart
=
time
()
tstart
=
time
()
sd
=
fsdp
.
gather_full_optim_state_dict
(
fsdp_optim
,
recipient_rank
=
0
)
sd
=
fsdp
.
gather_full_optim_state_dict
(
fsdp_optim
,
recipient_rank
=
0
)
duration
=
time
()
-
tstart
duration
=
time
()
-
tstart
# Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
# Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
assert
duration
<
fsdp
.
world_size
,
f
"gather optim state took
{
duration
}
seconds, suspect change in _consolidate"
assert
duration
<
fsdp
.
world_size
,
f
"gather optim state took
{
duration
}
seconds, suspect change in _consolidate"
cuda_gb_after
=
torch
.
cuda
.
memory_stats
(
fsdp
.
rank
)[
"allocated_bytes.all.current"
]
/
1024
**
3
mem_usg_gb
=
cuda_gb_after
-
cuda_gb_before
assert
mem_usg_gb
==
0
,
f
"gather_full_optim_state_dict used
{
mem_usg_gb
:.
2
f
}
CUDA GB, max allowed is 0"
assert
cuda_gb_after
>
0
,
"got 0 memory usage, logging is broken"
if
fsdp
.
rank
>
0
:
if
fsdp
.
rank
>
0
:
assert
sd
is
None
assert
sd
is
None
return
return
# assert whole state dict on CPU
for
k
,
v
in
sd
[
"state"
].
items
():
for
buffer_name
,
t
in
v
.
items
():
if
torch
.
is_tensor
(
t
):
msg
=
f
"got device
{
t
.
device
}
for
{
k
}
:
{
buffer_name
}
. expected CPU"
assert
t
.
device
==
torch
.
device
(
"cpu"
),
msg
unflat_state
=
sd
[
"state"
]
unflat_state
=
sd
[
"state"
]
assert
"uncollected_local_ids"
in
sd
assert
"uncollected_local_ids"
in
sd
shard_sd
=
fsdp
.
get_shard_from_optim_state_dict
(
sd
)
shard_sd
=
fsdp
.
get_shard_from_optim_state_dict
(
sd
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment