Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
b36e01d5
Unverified
Commit
b36e01d5
authored
Mar 04, 2021
by
Sam Shleifer
Committed by
GitHub
Mar 04, 2021
Browse files
[feat] add buffer_dtype kwarg for more control of batchnorm (#458)
parent
103d33c1
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
69 additions
and
16 deletions
+69
-16
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+7
-3
fairscale/utils/testing.py
fairscale/utils/testing.py
+6
-2
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp.py
+56
-11
No files found.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
View file @
b36e01d5
...
@@ -127,6 +127,8 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -127,6 +127,8 @@ class FullyShardedDataParallel(nn.Module):
dtype for full parameters for computation. This defaults to
dtype for full parameters for computation. This defaults to
``torch.float32`` unless *``mixed_precision``* is set, in which case
``torch.float32`` unless *``mixed_precision``* is set, in which case
it defaults to ``torch.float16``.
it defaults to ``torch.float16``.
buffer_dtype (torch.dtype, Optional):
dtype for buffers for computation. This defaults to ``compute_dtype``.
move_grads_to_cpu (bool, Optional):
move_grads_to_cpu (bool, Optional):
move gradient shard to CPU after reduction. This is useful when
move gradient shard to CPU after reduction. This is useful when
combined with CPU-based optimizers. It defaults to the value of
combined with CPU-based optimizers. It defaults to the value of
...
@@ -150,6 +152,7 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -150,6 +152,7 @@ class FullyShardedDataParallel(nn.Module):
flatten_parameters
:
bool
=
True
,
flatten_parameters
:
bool
=
True
,
cpu_offload
:
bool
=
False
,
cpu_offload
:
bool
=
False
,
compute_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
compute_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
buffer_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
move_grads_to_cpu
:
Optional
[
bool
]
=
None
,
move_grads_to_cpu
:
Optional
[
bool
]
=
None
,
bucket_cap_mb
:
int
=
25
,
bucket_cap_mb
:
int
=
25
,
):
):
...
@@ -163,6 +166,7 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -163,6 +166,7 @@ class FullyShardedDataParallel(nn.Module):
self
.
flatten_parameters
=
flatten_parameters
self
.
flatten_parameters
=
flatten_parameters
self
.
cpu_offload
=
cpu_offload
self
.
cpu_offload
=
cpu_offload
self
.
compute_dtype
=
compute_dtype
or
(
torch
.
float16
if
mixed_precision
else
torch
.
float32
)
self
.
compute_dtype
=
compute_dtype
or
(
torch
.
float16
if
mixed_precision
else
torch
.
float32
)
self
.
buffer_dtype
=
buffer_dtype
or
self
.
compute_dtype
self
.
move_grads_to_cpu
=
cpu_offload
if
move_grads_to_cpu
is
None
else
move_grads_to_cpu
self
.
move_grads_to_cpu
=
cpu_offload
if
move_grads_to_cpu
is
None
else
move_grads_to_cpu
self
.
bucket_cap_mb
=
bucket_cap_mb
self
.
bucket_cap_mb
=
bucket_cap_mb
...
@@ -446,7 +450,7 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -446,7 +450,7 @@ class FullyShardedDataParallel(nn.Module):
if
self
.
mixed_precision
:
if
self
.
mixed_precision
:
# In case we are in mixed precision, restore buffers back to fp16.
# In case we are in mixed precision, restore buffers back to fp16.
self
.
_all_buffers_to
(
dtype
=
self
.
compute
_dtype
)
self
.
_all_buffers_to
(
dtype
=
self
.
buffer
_dtype
)
return
state_dict
return
state_dict
# TODO (Min): figuring out how to do typing for this overloaded function.
# TODO (Min): figuring out how to do typing for this overloaded function.
...
@@ -619,9 +623,9 @@ class FullyShardedDataParallel(nn.Module):
...
@@ -619,9 +623,9 @@ class FullyShardedDataParallel(nn.Module):
self
.
_setup_streams
()
self
.
_setup_streams
()
if
self
.
cpu_offload
:
# Buffers stay on GPU, and don't get sharded
if
self
.
cpu_offload
:
# Buffers stay on GPU, and don't get sharded
self
.
_all_buffers_to
(
device
=
torch
.
device
(
"cuda"
),
dtype
=
self
.
compute
_dtype
)
self
.
_all_buffers_to
(
device
=
torch
.
device
(
"cuda"
),
dtype
=
self
.
buffer
_dtype
)
else
:
else
:
self
.
_all_buffers_to
(
dtype
=
self
.
compute
_dtype
)
self
.
_all_buffers_to
(
dtype
=
self
.
buffer
_dtype
)
if
self
.
_is_root
:
if
self
.
_is_root
:
# Don't free the full params for the outer-most (root) instance,
# Don't free the full params for the outer-most (root) instance,
...
...
fairscale/utils/testing.py
View file @
b36e01d5
...
@@ -471,6 +471,7 @@ class DeviceAndTypeCheckModule(Base):
...
@@ -471,6 +471,7 @@ class DeviceAndTypeCheckModule(Base):
expected_param_device
:
Optional
[
torch
.
device
]
=
None
,
expected_param_device
:
Optional
[
torch
.
device
]
=
None
,
expected_loss_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
expected_loss_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
expected_loss_device
:
Optional
[
torch
.
device
]
=
None
,
expected_loss_device
:
Optional
[
torch
.
device
]
=
None
,
expected_buffer_dtype
:
Optional
[
torch
.
device
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
expected_input_dtype
=
expected_input_dtype
self
.
expected_input_dtype
=
expected_input_dtype
...
@@ -479,8 +480,10 @@ class DeviceAndTypeCheckModule(Base):
...
@@ -479,8 +480,10 @@ class DeviceAndTypeCheckModule(Base):
self
.
expected_param_device
=
expected_param_device
self
.
expected_param_device
=
expected_param_device
self
.
expected_loss_dtype
=
expected_loss_dtype
self
.
expected_loss_dtype
=
expected_loss_dtype
self
.
expected_loss_device
=
expected_loss_device
self
.
expected_loss_device
=
expected_loss_device
self
.
expected_buffer_dtype
=
expected_buffer_dtype
self
.
linear
=
nn
.
Linear
(
5
,
5
)
self
.
linear
=
nn
.
Linear
(
5
,
5
)
self
.
register_buffer
(
"buffer"
,
torch
.
rand
((
5
,)))
def
_check
(
def
_check
(
self
,
self
,
...
@@ -498,8 +501,9 @@ class DeviceAndTypeCheckModule(Base):
...
@@ -498,8 +501,9 @@ class DeviceAndTypeCheckModule(Base):
param
=
self
.
linear
.
weight
param
=
self
.
linear
.
weight
self
.
_check
(
"param.dtype"
,
param
.
dtype
,
self
.
expected_param_dtype
)
self
.
_check
(
"param.dtype"
,
param
.
dtype
,
self
.
expected_param_dtype
)
self
.
_check
(
"param.device"
,
param
.
device
,
self
.
expected_param_device
)
self
.
_check
(
"param.device"
,
param
.
device
,
self
.
expected_param_device
)
self
.
_check
(
"buffer.dtype"
,
self
.
buffer
.
dtype
,
self
.
expected_buffer_dtype
)
# type: ignore
loss
=
self
.
linear
(
x
).
sum
()
x
=
x
+
self
.
buffer
loss
=
(
self
.
linear
(
x
)
+
self
.
buffer
).
sum
()
self
.
_check
(
"loss.dtype"
,
loss
.
dtype
,
self
.
expected_loss_dtype
)
self
.
_check
(
"loss.dtype"
,
loss
.
dtype
,
self
.
expected_loss_dtype
)
self
.
_check
(
"loss.device"
,
loss
.
device
,
self
.
expected_loss_device
)
self
.
_check
(
"loss.device"
,
loss
.
device
,
self
.
expected_loss_device
)
...
...
tests/nn/data_parallel/test_fsdp.py
View file @
b36e01d5
...
@@ -110,6 +110,18 @@ class TestMixedPrecision(DistributedTest):
...
@@ -110,6 +110,18 @@ class TestMixedPrecision(DistributedTest):
torch
.
float16
,
# expected_reduce_dtype
torch
.
float16
,
# expected_reduce_dtype
)
)
def
test_mixed_precision_autocast_buffer_type_fp32
(
self
):
"""If autocast enabled, loss should be fp32."""
self
.
_spawn_test_case
(
{
"mixed_precision"
:
True
,
"buffer_dtype"
:
torch
.
float32
},
True
,
# autocast enabled
torch
.
float16
,
# expected_input_dtype
torch
.
float16
,
# expected_param_dtype
torch
.
float32
,
# expected_loss_dtype
torch
.
float16
,
# expected_reduce_dtype
expected_buffer_type
=
torch
.
float32
,
)
def
test_mixed_precision_autocast_fp32_compute
(
self
):
def
test_mixed_precision_autocast_fp32_compute
(
self
):
self
.
_spawn_test_case
(
self
.
_spawn_test_case
(
{
"mixed_precision"
:
True
,
"compute_dtype"
:
torch
.
float32
},
{
"mixed_precision"
:
True
,
"compute_dtype"
:
torch
.
float32
},
...
@@ -118,6 +130,7 @@ class TestMixedPrecision(DistributedTest):
...
@@ -118,6 +130,7 @@ class TestMixedPrecision(DistributedTest):
torch
.
float32
,
# expected_param_dtype
torch
.
float32
,
# expected_param_dtype
torch
.
float32
,
# expected_loss_dtype
torch
.
float32
,
# expected_loss_dtype
torch
.
float32
,
# expected_reduce_dtype
torch
.
float32
,
# expected_reduce_dtype
expected_buffer_type
=
torch
.
float32
,
)
)
def
test_fp32_reduce_scatter
(
self
):
def
test_fp32_reduce_scatter
(
self
):
...
@@ -128,6 +141,7 @@ class TestMixedPrecision(DistributedTest):
...
@@ -128,6 +141,7 @@ class TestMixedPrecision(DistributedTest):
torch
.
float16
,
# expected_param_dtype
torch
.
float16
,
# expected_param_dtype
torch
.
float16
,
# expected_loss_dtype
torch
.
float16
,
# expected_loss_dtype
torch
.
float32
,
# expected_reduce_dtype
torch
.
float32
,
# expected_reduce_dtype
expected_buffer_type
=
torch
.
float16
,
)
)
def
test_fp32_reduce_scatter_autocast
(
self
):
def
test_fp32_reduce_scatter_autocast
(
self
):
...
@@ -140,18 +154,42 @@ class TestMixedPrecision(DistributedTest):
...
@@ -140,18 +154,42 @@ class TestMixedPrecision(DistributedTest):
torch
.
float32
,
# expected_reduce_dtype
torch
.
float32
,
# expected_reduce_dtype
)
)
def
_spawn_test_case
(
self
,
cfg
,
autocast_enabled
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
,
world_size
=
2
):
def
_spawn_test_case
(
self
,
cfg
,
autocast_enabled
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
,
expected_buffer_type
=
None
,
world_size
=
2
,
):
"""Call test_dtypes inside of torch.multiprocessing.spawn"""
"""Call test_dtypes inside of torch.multiprocessing.spawn"""
fn
=
functools
.
partial
(
self
.
_test_dtypes
,
cfg
,
autocast_enabled
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
)
fn
=
functools
.
partial
(
self
.
_test_dtypes
,
cfg
,
autocast_enabled
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
,
expected_buffer_type
=
expected_buffer_type
,
)
spawn_and_init
(
fn
,
world_sizes
=
[
world_size
])
spawn_and_init
(
fn
,
world_sizes
=
[
world_size
])
@
staticmethod
@
staticmethod
def
_test_dtypes
(
cfg
:
Dict
,
autocast
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
,
rank
,
group
):
def
_test_dtypes
(
cfg
:
Dict
,
autocast
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
,
rank
,
group
,
expected_buffer_type
=
None
):
# Patch torch.distributed.reduce_scatter to check the dtype of the reduction
# Patch torch.distributed.reduce_scatter to check the dtype of the reduction
orig_reduce_scatter
=
torch
.
distributed
.
reduce_scatter
orig_reduce_scatter
=
torch
.
distributed
.
reduce_scatter
model
:
nn
.
Module
=
DeviceAndTypeCheckModule
(
model
:
nn
.
Module
=
DeviceAndTypeCheckModule
(
expected_input_dtype
=
in_dtype
,
expected_param_dtype
=
p_dtype
,
expected_loss_dtype
=
loss_dtype
,
expected_input_dtype
=
in_dtype
,
expected_param_dtype
=
p_dtype
,
expected_loss_dtype
=
loss_dtype
,
expected_buffer_dtype
=
expected_buffer_type
,
)
)
def
_reduce_scatter
(
output
,
input_list
,
**
kwargs
):
def
_reduce_scatter
(
output
,
input_list
,
**
kwargs
):
...
@@ -265,7 +303,7 @@ class TestComparisonToPyTorchDDP(DistributedTest):
...
@@ -265,7 +303,7 @@ class TestComparisonToPyTorchDDP(DistributedTest):
def
_test_identical_outputs
(
def
_test_identical_outputs
(
cls
,
model_init_fn
,
config
,
rank
,
group
,
num_steps
=
2
,
use_cuda
=
True
,
lr
=
0.01
,
ref_ddp_fn
=
None
,
norm_type
=
2
,
cls
,
model_init_fn
,
config
,
rank
,
group
,
num_steps
=
2
,
use_cuda
=
True
,
lr
=
0.01
,
ref_ddp_fn
=
None
,
norm_type
=
2
,
):
):
if
config
[
"mixed_precision"
]
:
if
config
.
get
(
"mixed_precision"
,
False
)
:
autocast
=
True
autocast
=
True
# Force the compute dtype to be torch.float32 so that we get
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
# identical results as PyTorch DDP when using autocast. Note that
...
@@ -399,7 +437,9 @@ class TestLocalStateDict(DistributedTest):
...
@@ -399,7 +437,9 @@ class TestLocalStateDict(DistributedTest):
@
classmethod
@
classmethod
def
_load_local_and_train
(
self
,
config
,
rank
,
group
,
d_model
=
16
,
d_vocab
=
23
):
def
_load_local_and_train
(
self
,
config
,
rank
,
group
,
d_model
=
16
,
d_vocab
=
23
):
"""Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
"""Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
model
=
self
.
get_wrapped_model
(
group
,
cuda_first
=
False
,
config
=
config
,
d_vocab
=
d_vocab
,
d_model
=
d_model
)
model
=
self
.
get_wrapped_model
(
group
,
cuda_first
=
False
,
config
=
config
,
d_vocab
=
d_vocab
,
d_model
=
d_model
,
add_bn
=
False
)
# Set bn=True here to show that BN doesn't get updated
state_1
=
model
.
local_state_dict
()
state_1
=
model
.
local_state_dict
()
state_before_training
=
{
k
:
v
.
cpu
().
clone
()
for
k
,
v
in
state_1
.
items
()}
state_before_training
=
{
k
:
v
.
cpu
().
clone
()
for
k
,
v
in
state_1
.
items
()}
assert
len
(
state_1
)
>
0
assert
len
(
state_1
)
>
0
...
@@ -639,7 +679,7 @@ class TestNoSync(DistributedTest):
...
@@ -639,7 +679,7 @@ class TestNoSync(DistributedTest):
def
test_no_sync_before_first_forward
(
self
):
def
test_no_sync_before_first_forward
(
self
):
group
=
DummyProcessGroup
(
rank
=
0
,
size
=
1
)
group
=
DummyProcessGroup
(
rank
=
0
,
size
=
1
)
model
=
self
.
get_wrapped_model
(
group
,
config
=
{})
model
=
self
.
get_wrapped_model
(
group
,
config
=
{}
,
add_bn
=
False
)
batch
=
model
.
module
.
get_input
(
torch
.
device
(
"cuda"
))
batch
=
model
.
module
.
get_input
(
torch
.
device
(
"cuda"
))
with
model
.
no_sync
():
with
model
.
no_sync
():
output
=
model
(
*
batch
)
output
=
model
(
*
batch
)
...
@@ -651,7 +691,7 @@ class TestNoSync(DistributedTest):
...
@@ -651,7 +691,7 @@ class TestNoSync(DistributedTest):
@
classmethod
@
classmethod
def
_test_transformer
(
self
,
rank
,
group
,
config
):
def
_test_transformer
(
self
,
rank
,
group
,
config
):
model
=
self
.
get_wrapped_model
(
group
,
config
=
config
)
model
=
self
.
get_wrapped_model
(
group
,
config
=
config
,
add_bn
=
False
)
model
.
eval
()
# turn off dropout for the test
model
.
eval
()
# turn off dropout for the test
self
.
_test_no_sync
(
model
,
batch_dim
=
1
)
self
.
_test_no_sync
(
model
,
batch_dim
=
1
)
...
@@ -703,7 +743,7 @@ class TestNoSync(DistributedTest):
...
@@ -703,7 +743,7 @@ class TestNoSync(DistributedTest):
class
TransformerWithSharedParams
(
nn
.
Module
):
class
TransformerWithSharedParams
(
nn
.
Module
):
def
__init__
(
self
,
group
,
*
unused_args
,
d_vocab
=
23
,
d_model
=
16
,
**
unused_kwargs
):
def
__init__
(
self
,
group
,
*
unused_args
,
d_vocab
=
23
,
d_model
=
16
,
add_bn
=
True
,
**
unused_kwargs
):
super
().
__init__
()
super
().
__init__
()
self
.
rank
=
group
.
rank
()
self
.
rank
=
group
.
rank
()
self
.
world_size
=
group
.
size
()
self
.
world_size
=
group
.
size
()
...
@@ -714,21 +754,26 @@ class TransformerWithSharedParams(nn.Module):
...
@@ -714,21 +754,26 @@ class TransformerWithSharedParams(nn.Module):
d_model
=
d_model
,
num_encoder_layers
=
2
,
num_decoder_layers
=
2
,
dim_feedforward
=
8
,
dropout
=
0.1
,
d_model
=
d_model
,
num_encoder_layers
=
2
,
num_decoder_layers
=
2
,
dim_feedforward
=
8
,
dropout
=
0.1
,
)
)
self
.
output_proj
=
nn
.
Linear
(
d_model
,
d_vocab
)
self
.
output_proj
=
nn
.
Linear
(
d_model
,
d_vocab
)
# share the embedding and output projection weights
# share the embedding and output projection weights
self
.
output_proj
.
weight
=
self
.
embed_tokens
.
weight
self
.
output_proj
.
weight
=
self
.
embed_tokens
.
weight
self
.
register_buffer
(
"vocab_bias"
,
self
.
embed_tokens
.
weight
.
new_ones
((
d_model
,)))
self
.
register_buffer
(
"vocab_bias"
,
self
.
embed_tokens
.
weight
.
new_ones
((
d_model
,)))
self
.
register_buffer
(
"long_buffer"
,
torch
.
zeros_like
(
self
.
vocab_bias
,
dtype
=
torch
.
long
))
self
.
register_buffer
(
"long_buffer"
,
torch
.
zeros_like
(
self
.
vocab_bias
,
dtype
=
torch
.
long
))
self
.
bs
=
2
self
.
bn
=
torch
.
nn
.
BatchNorm1d
(
self
.
bs
)
if
add_bn
else
torch
.
nn
.
Identity
()
def
get_input
(
self
,
device
):
def
get_input
(
self
,
device
):
torch
.
manual_seed
(
1
+
self
.
rank
)
# keep everything deterministic
torch
.
manual_seed
(
1
+
self
.
rank
)
# keep everything deterministic
src
=
torch
.
arange
(
12
,
device
=
device
).
view
(
6
,
2
)
# T x B
src
=
torch
.
arange
(
12
,
device
=
device
).
view
(
6
,
self
.
bs
)
# T x B
tgt
=
torch
.
arange
(
8
,
device
=
device
).
view
(
4
,
2
)
# T x B
tgt
=
torch
.
arange
(
self
.
bs
*
4
,
device
=
device
).
view
(
4
,
self
.
bs
)
# T x B
return
(
src
,
tgt
)
return
(
src
,
tgt
)
def
forward
(
self
,
src_ids
,
tgt_ids
):
def
forward
(
self
,
src_ids
,
tgt_ids
):
src
=
self
.
embed_tokens
(
src_ids
)
src
=
self
.
embed_tokens
(
src_ids
)
src
=
src
+
self
.
vocab_bias
+
self
.
long_buffer
.
type_as
(
src
)
src
=
src
+
self
.
vocab_bias
+
self
.
long_buffer
.
type_as
(
src
)
tgt
=
self
.
embed_tokens
(
tgt_ids
)
tgt
=
self
.
embed_tokens
(
tgt_ids
)
tgt
=
self
.
bn
(
tgt
)
x
=
self
.
transformer
(
src
,
tgt
)
x
=
self
.
transformer
(
src
,
tgt
)
return
self
.
output_proj
(
x
)
return
self
.
output_proj
(
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment