OpenDAS / fairscale · commit b36e01d5 (unverified)

[feat] add buffer_dtype kwarg for more control of batchnorm (#458)

Authored by Sam Shleifer on Mar 04, 2021; committed via GitHub on Mar 04, 2021.
Parent: 103d33c1
Showing 3 changed files with 69 additions and 16 deletions (+69 -16):

  fairscale/nn/data_parallel/fully_sharded_data_parallel.py (+7 -3)
  fairscale/utils/testing.py (+6 -2)
  tests/nn/data_parallel/test_fsdp.py (+56 -11)
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
@@ -127,6 +127,8 @@ class FullyShardedDataParallel(nn.Module):
             dtype for full parameters for computation. This defaults to
             ``torch.float32`` unless *``mixed_precision``* is set, in which case
             it defaults to ``torch.float16``.
+        buffer_dtype (torch.dtype, Optional):
+            dtype for buffers for computation. This defaults to ``compute_dtype``.
         move_grads_to_cpu (bool, Optional):
             move gradient shard to CPU after reduction. This is useful when
             combined with CPU-based optimizers. It defaults to the value of
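For context, a minimal usage sketch of the new kwarg (not part of this commit's diff). It assumes the class is importable from fairscale.nn.data_parallel, matching the path of the file above, and that a distributed process group and a CUDA device are already set up; the toy module is purely illustrative.

import torch
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Toy model with a BatchNorm layer, whose running statistics live in buffers.
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8)).cuda()

# Parameters and compute run in fp16 under mixed precision, while the new
# buffer_dtype kwarg keeps buffers (e.g. running_mean/running_var) in fp32.
fsdp_model = FSDP(model, mixed_precision=True, buffer_dtype=torch.float32)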
@@ -150,6 +152,7 @@ class FullyShardedDataParallel(nn.Module):
         flatten_parameters: bool = True,
         cpu_offload: bool = False,
         compute_dtype: Optional[torch.dtype] = None,
+        buffer_dtype: Optional[torch.dtype] = None,
         move_grads_to_cpu: Optional[bool] = None,
         bucket_cap_mb: int = 25,
     ):
@@ -163,6 +166,7 @@ class FullyShardedDataParallel(nn.Module):
         self.flatten_parameters = flatten_parameters
         self.cpu_offload = cpu_offload
         self.compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
+        self.buffer_dtype = buffer_dtype or self.compute_dtype
         self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
         self.bucket_cap_mb = bucket_cap_mb
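Restated outside the class, the defaulting rules added above are just two `or` expressions. A standalone sketch (the helper name is invented here) with the resulting dtypes for common configurations:

import torch

def resolve_dtypes(mixed_precision, compute_dtype=None, buffer_dtype=None):
    # Mirrors the two assignments in __init__ above.
    compute_dtype = compute_dtype or (torch.float16 if mixed_precision else torch.float32)
    buffer_dtype = buffer_dtype or compute_dtype
    return compute_dtype, buffer_dtype

assert resolve_dtypes(False) == (torch.float32, torch.float32)
assert resolve_dtypes(True) == (torch.float16, torch.float16)
assert resolve_dtypes(True, buffer_dtype=torch.float32) == (torch.float16, torch.float32)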
@@ -446,7 +450,7 @@ class FullyShardedDataParallel(nn.Module):
         if self.mixed_precision:
             # In case we are in mixed precision, restore buffers back to fp16.
-            self._all_buffers_to(dtype=self.compute_dtype)
+            self._all_buffers_to(dtype=self.buffer_dtype)
         return state_dict

     # TODO (Min): figuring out how to do typing for this overloaded function.
@@ -619,9 +623,9 @@ class FullyShardedDataParallel(nn.Module):
         self._setup_streams()

         if self.cpu_offload:  # Buffers stay on GPU, and don't get sharded
-            self._all_buffers_to(device=torch.device("cuda"), dtype=self.compute_dtype)
+            self._all_buffers_to(device=torch.device("cuda"), dtype=self.buffer_dtype)
         else:
-            self._all_buffers_to(dtype=self.compute_dtype)
+            self._all_buffers_to(dtype=self.buffer_dtype)

         if self._is_root:
             # Don't free the full params for the outer-most (root) instance,
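The _all_buffers_to helper itself is not shown in this diff. As a rough standalone sketch of what such a cast involves (hypothetical name and behavior, not FSDP's internal implementation), the function below moves every registered buffer in the module tree and casts only the floating-point ones, since integer buffers such as the long_buffer used in the tests below should not be coerced to a float dtype.

import torch.nn as nn

def cast_buffers(module: nn.Module, device=None, dtype=None):
    # Illustrative only: walk the module tree and rewrite buffers in place.
    for submodule in module.modules():
        for name, buf in submodule._buffers.items():
            if buf is None:
                continue
            if device is not None:
                buf = buf.to(device=device)
            if dtype is not None and buf.is_floating_point():
                buf = buf.to(dtype=dtype)
            submodule._buffers[name] = buf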
fairscale/utils/testing.py
@@ -471,6 +471,7 @@ class DeviceAndTypeCheckModule(Base):
         expected_param_device: Optional[torch.device] = None,
         expected_loss_dtype: Optional[torch.dtype] = None,
         expected_loss_device: Optional[torch.device] = None,
+        expected_buffer_dtype: Optional[torch.device] = None,
     ):
         super().__init__()
         self.expected_input_dtype = expected_input_dtype
@@ -479,8 +480,10 @@ class DeviceAndTypeCheckModule(Base):
         self.expected_param_device = expected_param_device
         self.expected_loss_dtype = expected_loss_dtype
         self.expected_loss_device = expected_loss_device
+        self.expected_buffer_dtype = expected_buffer_dtype

         self.linear = nn.Linear(5, 5)
+        self.register_buffer("buffer", torch.rand((5,)))

     def _check(
         self,
@@ -498,8 +501,9 @@ class DeviceAndTypeCheckModule(Base):
         param = self.linear.weight
         self._check("param.dtype", param.dtype, self.expected_param_dtype)
         self._check("param.device", param.device, self.expected_param_device)
+        self._check("buffer.dtype", self.buffer.dtype, self.expected_buffer_dtype)  # type: ignore

-        loss = self.linear(x).sum()
+        x = x + self.buffer
+        loss = (self.linear(x) + self.buffer).sum()
         self._check("loss.dtype", loss.dtype, self.expected_loss_dtype)
         self._check("loss.device", loss.device, self.expected_loss_device)
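Independent of this test helper, the same property can be checked on any model by walking named_buffers(); a small illustrative assertion, not part of the diff:

import torch
import torch.nn as nn

def assert_float_buffer_dtype(module: nn.Module, expected: torch.dtype) -> None:
    # Only floating-point buffers are expected to follow buffer_dtype; integer
    # buffers (like the long_buffer in the tests below) keep their own dtype.
    for name, buf in module.named_buffers():
        if buf.is_floating_point():
            assert buf.dtype == expected, f"{name}: {buf.dtype} != {expected}"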
tests/nn/data_parallel/test_fsdp.py
@@ -110,6 +110,18 @@ class TestMixedPrecision(DistributedTest):
             torch.float16,  # expected_reduce_dtype
         )

+    def test_mixed_precision_autocast_buffer_type_fp32(self):
+        """If autocast enabled, loss should be fp32."""
+        self._spawn_test_case(
+            {"mixed_precision": True, "buffer_dtype": torch.float32},
+            True,  # autocast enabled
+            torch.float16,  # expected_input_dtype
+            torch.float16,  # expected_param_dtype
+            torch.float32,  # expected_loss_dtype
+            torch.float16,  # expected_reduce_dtype
+            expected_buffer_type=torch.float32,
+        )
+
     def test_mixed_precision_autocast_fp32_compute(self):
         self._spawn_test_case(
             {"mixed_precision": True, "compute_dtype": torch.float32},
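The new test pins down the interaction the commit targets: under autocast the Linear compute comes out in fp16, while a buffer deliberately kept in fp32 stays fp32 and promotes the loss to fp32. A stripped-down sketch of that behavior with a bare nn.Linear, no FSDP involved (illustrative; assumes a CUDA device):

import torch
import torch.nn as nn

linear = nn.Linear(5, 5).cuda()
buf = torch.rand(5, device="cuda")  # fp32, standing in for a buffer kept in full precision
x = torch.rand(2, 5, device="cuda")

with torch.cuda.amp.autocast():
    y = linear(x)            # autocast runs the matmul in fp16
    loss = (y + buf).sum()   # addition promotes to the wider dtype, so the loss is fp32

assert y.dtype == torch.float16
assert loss.dtype == torch.float32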
@@ -118,6 +130,7 @@ class TestMixedPrecision(DistributedTest):
             torch.float32,  # expected_param_dtype
             torch.float32,  # expected_loss_dtype
             torch.float32,  # expected_reduce_dtype
+            expected_buffer_type=torch.float32,
         )

     def test_fp32_reduce_scatter(self):
@@ -128,6 +141,7 @@ class TestMixedPrecision(DistributedTest):
             torch.float16,  # expected_param_dtype
             torch.float16,  # expected_loss_dtype
             torch.float32,  # expected_reduce_dtype
+            expected_buffer_type=torch.float16,
         )

     def test_fp32_reduce_scatter_autocast(self):
@@ -140,18 +154,42 @@ class TestMixedPrecision(DistributedTest):
             torch.float32,  # expected_reduce_dtype
         )

-    def _spawn_test_case(self, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype, world_size=2):
+    def _spawn_test_case(
+        self, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype, expected_buffer_type=None, world_size=2,
+    ):
         """Call test_dtypes inside of torch.multiprocessing.spawn"""
-        fn = functools.partial(self._test_dtypes, cfg, autocast_enabled, in_dtype, p_dtype, loss_dtype, reduce_dtype)
+        fn = functools.partial(
+            self._test_dtypes,
+            cfg,
+            autocast_enabled,
+            in_dtype,
+            p_dtype,
+            loss_dtype,
+            reduce_dtype,
+            expected_buffer_type=expected_buffer_type,
+        )
         spawn_and_init(fn, world_sizes=[world_size])

     @staticmethod
-    def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group):
+    def _test_dtypes(cfg: Dict, autocast, in_dtype, p_dtype, loss_dtype, reduce_dtype, rank, group, expected_buffer_type=None):
         # Patch torch.distributed.reduce_scatter to check the dtype of the reduction
         orig_reduce_scatter = torch.distributed.reduce_scatter

         model: nn.Module = DeviceAndTypeCheckModule(
             expected_input_dtype=in_dtype, expected_param_dtype=p_dtype, expected_loss_dtype=loss_dtype,
+            expected_buffer_dtype=expected_buffer_type,
         )

         def _reduce_scatter(output, input_list, **kwargs):
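The refactor above threads the new expectation through torch.multiprocessing.spawn by binding it as a keyword argument in functools.partial; positional arguments supplied later (rank and group from spawn_and_init) are appended after the bound ones, and bound keywords are merged into the final call. A minimal standalone example:

import functools

def report(cfg, rank, group, expected_buffer_type=None):
    return (cfg, rank, group, expected_buffer_type)

fn = functools.partial(report, {"mixed_precision": True}, expected_buffer_type="fp32")
assert fn(0, "group0") == ({"mixed_precision": True}, 0, "group0", "fp32")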
@@ -265,7 +303,7 @@ class TestComparisonToPyTorchDDP(DistributedTest):
     def _test_identical_outputs(
         cls, model_init_fn, config, rank, group, num_steps=2, use_cuda=True, lr=0.01, ref_ddp_fn=None, norm_type=2,
     ):
-        if config["mixed_precision"]:
+        if config.get("mixed_precision", False):
             autocast = True
             # Force the compute dtype to be torch.float32 so that we get
             # identical results as PyTorch DDP when using autocast. Note that
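The one-line switch from config["mixed_precision"] to config.get("mixed_precision", False) lets configurations that omit the key reuse this helper: subscripting raises KeyError for a missing key, while dict.get falls back to the supplied default.

config = {"buffer_dtype": "fp32"}  # no "mixed_precision" key

try:
    config["mixed_precision"]      # old form: raises KeyError
except KeyError:
    pass

assert config.get("mixed_precision", False) is False  # new form: defaults to False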
@@ -399,7 +437,9 @@ class TestLocalStateDict(DistributedTest):
     @classmethod
     def _load_local_and_train(self, config, rank, group, d_model=16, d_vocab=23):
         """Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
-        model = self.get_wrapped_model(group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model)
+        model = self.get_wrapped_model(
+            group, cuda_first=False, config=config, d_vocab=d_vocab, d_model=d_model, add_bn=False
+        )  # Set bn=True here to show that BN doesn't get updated
         state_1 = model.local_state_dict()
         state_before_training = {k: v.cpu().clone() for k, v in state_1.items()}
         assert len(state_1) > 0
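BatchNorm is a special case in state-dict tests because its running statistics are buffers that change as a side effect of every training-mode forward pass, not through the optimizer. A short illustration of that side effect (standalone, unrelated to FSDP):

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
before = bn.running_mean.clone()

bn.train()
bn(torch.randn(8, 4))  # the forward pass alone updates running_mean / running_var

assert not torch.equal(bn.running_mean, before)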
@@ -639,7 +679,7 @@ class TestNoSync(DistributedTest):
     def test_no_sync_before_first_forward(self):
         group = DummyProcessGroup(rank=0, size=1)
-        model = self.get_wrapped_model(group, config={})
+        model = self.get_wrapped_model(group, config={}, add_bn=False)
         batch = model.module.get_input(torch.device("cuda"))
         with model.no_sync():
             output = model(*batch)
@@ -651,7 +691,7 @@ class TestNoSync(DistributedTest):
     @classmethod
     def _test_transformer(self, rank, group, config):
-        model = self.get_wrapped_model(group, config=config)
+        model = self.get_wrapped_model(group, config=config, add_bn=False)
         model.eval()  # turn off dropout for the test
         self._test_no_sync(model, batch_dim=1)
@@ -703,7 +743,7 @@ class TestNoSync(DistributedTest):
 class TransformerWithSharedParams(nn.Module):
-    def __init__(self, group, *unused_args, d_vocab=23, d_model=16, **unused_kwargs):
+    def __init__(self, group, *unused_args, d_vocab=23, d_model=16, add_bn=True, **unused_kwargs):
         super().__init__()
         self.rank = group.rank()
         self.world_size = group.size()
@@ -714,21 +754,26 @@ class TransformerWithSharedParams(nn.Module):
             d_model=d_model, num_encoder_layers=2, num_decoder_layers=2, dim_feedforward=8, dropout=0.1,
         )
         self.output_proj = nn.Linear(d_model, d_vocab)

         # share the embedding and output projection weights
         self.output_proj.weight = self.embed_tokens.weight
         self.register_buffer("vocab_bias", self.embed_tokens.weight.new_ones((d_model,)))
         self.register_buffer("long_buffer", torch.zeros_like(self.vocab_bias, dtype=torch.long))

+        self.bs = 2
+        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
+
     def get_input(self, device):
         torch.manual_seed(1 + self.rank)  # keep everything deterministic
-        src = torch.arange(12, device=device).view(6, 2)  # T x B
-        tgt = torch.arange(8, device=device).view(4, 2)  # T x B
+        src = torch.arange(12, device=device).view(6, self.bs)  # T x B
+        tgt = torch.arange(self.bs * 4, device=device).view(4, self.bs)  # T x B
         return (src, tgt)

     def forward(self, src_ids, tgt_ids):
         src = self.embed_tokens(src_ids)
         src = src + self.vocab_bias + self.long_buffer.type_as(src)
         tgt = self.embed_tokens(tgt_ids)
+        tgt = self.bn(tgt)
         x = self.transformer(src, tgt)
         return self.output_proj(x)
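The BatchNorm1d added to this test model is what ties the change back to the commit title: BatchNorm carries floating-point buffers (running_mean, running_var) alongside its parameters, plus an integer num_batches_tracked buffer, and buffer_dtype now controls the precision of such buffers separately from compute_dtype. For reference, which tensors of a BatchNorm layer are parameters versus buffers:

import torch.nn as nn

bn = nn.BatchNorm1d(2)

print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']  (the last is an integer buffer)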