Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
b36e01d5
Unverified
Commit
b36e01d5
authored
Mar 04, 2021
by
Sam Shleifer
Committed by
GitHub
Mar 04, 2021
Browse files
[feat] add buffer_dtype kwarg for more control of batchnorm (#458)
parent
103d33c1
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
69 additions
and
16 deletions
+69
-16
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
+7
-3
fairscale/utils/testing.py
fairscale/utils/testing.py
+6
-2
tests/nn/data_parallel/test_fsdp.py
tests/nn/data_parallel/test_fsdp.py
+56
-11
No files found.
fairscale/nn/data_parallel/fully_sharded_data_parallel.py
View file @
b36e01d5
...
...
@@ -127,6 +127,8 @@ class FullyShardedDataParallel(nn.Module):
dtype for full parameters for computation. This defaults to
``torch.float32`` unless *``mixed_precision``* is set, in which case
it defaults to ``torch.float16``.
buffer_dtype (torch.dtype, Optional):
dtype for buffers for computation. This defaults to ``compute_dtype``.
move_grads_to_cpu (bool, Optional):
move gradient shard to CPU after reduction. This is useful when
combined with CPU-based optimizers. It defaults to the value of
...
...
@@ -150,6 +152,7 @@ class FullyShardedDataParallel(nn.Module):
flatten_parameters
:
bool
=
True
,
cpu_offload
:
bool
=
False
,
compute_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
buffer_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
move_grads_to_cpu
:
Optional
[
bool
]
=
None
,
bucket_cap_mb
:
int
=
25
,
):
...
...
@@ -163,6 +166,7 @@ class FullyShardedDataParallel(nn.Module):
self
.
flatten_parameters
=
flatten_parameters
self
.
cpu_offload
=
cpu_offload
self
.
compute_dtype
=
compute_dtype
or
(
torch
.
float16
if
mixed_precision
else
torch
.
float32
)
self
.
buffer_dtype
=
buffer_dtype
or
self
.
compute_dtype
self
.
move_grads_to_cpu
=
cpu_offload
if
move_grads_to_cpu
is
None
else
move_grads_to_cpu
self
.
bucket_cap_mb
=
bucket_cap_mb
...
...
@@ -446,7 +450,7 @@ class FullyShardedDataParallel(nn.Module):
if
self
.
mixed_precision
:
# In case we are in mixed precision, restore buffers back to buffer_dtype (fp16 by default).
self
.
_all_buffers_to
(
dtype
=
self
.
compute
_dtype
)
self
.
_all_buffers_to
(
dtype
=
self
.
buffer
_dtype
)
return
state_dict
# TODO (Min): figure out how to do typing for this overloaded function.
...
...
@@ -619,9 +623,9 @@ class FullyShardedDataParallel(nn.Module):
self
.
_setup_streams
()
if
self
.
cpu_offload
:
# Buffers stay on GPU, and don't get sharded
self
.
_all_buffers_to
(
device
=
torch
.
device
(
"cuda"
),
dtype
=
self
.
compute
_dtype
)
self
.
_all_buffers_to
(
device
=
torch
.
device
(
"cuda"
),
dtype
=
self
.
buffer
_dtype
)
else
:
self
.
_all_buffers_to
(
dtype
=
self
.
compute
_dtype
)
self
.
_all_buffers_to
(
dtype
=
self
.
buffer
_dtype
)
if
self
.
_is_root
:
# Don't free the full params for the outer-most (root) instance,
...
...
fairscale/utils/testing.py
View file @
b36e01d5
...
...
@@ -471,6 +471,7 @@ class DeviceAndTypeCheckModule(Base):
expected_param_device
:
Optional
[
torch
.
device
]
=
None
,
expected_loss_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
expected_loss_device
:
Optional
[
torch
.
device
]
=
None
,
expected_buffer_dtype
:
Optional
[
torch
.
device
]
=
None
,
):
super
().
__init__
()
self
.
expected_input_dtype
=
expected_input_dtype
...
...
@@ -479,8 +480,10 @@ class DeviceAndTypeCheckModule(Base):
self
.
expected_param_device
=
expected_param_device
self
.
expected_loss_dtype
=
expected_loss_dtype
self
.
expected_loss_device
=
expected_loss_device
self
.
expected_buffer_dtype
=
expected_buffer_dtype
self
.
linear
=
nn
.
Linear
(
5
,
5
)
self
.
register_buffer
(
"buffer"
,
torch
.
rand
((
5
,)))
def
_check
(
self
,
...
...
@@ -498,8 +501,9 @@ class DeviceAndTypeCheckModule(Base):
param
=
self
.
linear
.
weight
self
.
_check
(
"param.dtype"
,
param
.
dtype
,
self
.
expected_param_dtype
)
self
.
_check
(
"param.device"
,
param
.
device
,
self
.
expected_param_device
)
loss
=
self
.
linear
(
x
).
sum
()
self
.
_check
(
"buffer.dtype"
,
self
.
buffer
.
dtype
,
self
.
expected_buffer_dtype
)
# type: ignore
x
=
x
+
self
.
buffer
loss
=
(
self
.
linear
(
x
)
+
self
.
buffer
).
sum
()
self
.
_check
(
"loss.dtype"
,
loss
.
dtype
,
self
.
expected_loss_dtype
)
self
.
_check
(
"loss.device"
,
loss
.
device
,
self
.
expected_loss_device
)
...
...
tests/nn/data_parallel/test_fsdp.py
View file @
b36e01d5
...
...
@@ -110,6 +110,18 @@ class TestMixedPrecision(DistributedTest):
torch
.
float16
,
# expected_reduce_dtype
)
def
test_mixed_precision_autocast_buffer_type_fp32
(
self
):
"""If autocast enabled, loss should be fp32."""
self
.
_spawn_test_case
(
{
"mixed_precision"
:
True
,
"buffer_dtype"
:
torch
.
float32
},
True
,
# autocast enabled
torch
.
float16
,
# expected_input_dtype
torch
.
float16
,
# expected_param_dtype
torch
.
float32
,
# expected_loss_dtype
torch
.
float16
,
# expected_reduce_dtype
expected_buffer_type
=
torch
.
float32
,
)
def
test_mixed_precision_autocast_fp32_compute
(
self
):
self
.
_spawn_test_case
(
{
"mixed_precision"
:
True
,
"compute_dtype"
:
torch
.
float32
},
...
...
@@ -118,6 +130,7 @@ class TestMixedPrecision(DistributedTest):
torch
.
float32
,
# expected_param_dtype
torch
.
float32
,
# expected_loss_dtype
torch
.
float32
,
# expected_reduce_dtype
expected_buffer_type
=
torch
.
float32
,
)
def
test_fp32_reduce_scatter
(
self
):
...
...
@@ -128,6 +141,7 @@ class TestMixedPrecision(DistributedTest):
torch
.
float16
,
# expected_param_dtype
torch
.
float16
,
# expected_loss_dtype
torch
.
float32
,
# expected_reduce_dtype
expected_buffer_type
=
torch
.
float16
,
)
def
test_fp32_reduce_scatter_autocast
(
self
):
...
...
@@ -140,18 +154,42 @@ class TestMixedPrecision(DistributedTest):
torch
.
float32
,
# expected_reduce_dtype
)
def
_spawn_test_case
(
self
,
cfg
,
autocast_enabled
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
,
world_size
=
2
):
def
_spawn_test_case
(
self
,
cfg
,
autocast_enabled
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
,
expected_buffer_type
=
None
,
world_size
=
2
,
):
"""Call test_dtypes inside of torch.multiprocessing.spawn"""
fn
=
functools
.
partial
(
self
.
_test_dtypes
,
cfg
,
autocast_enabled
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
)
fn
=
functools
.
partial
(
self
.
_test_dtypes
,
cfg
,
autocast_enabled
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
,
expected_buffer_type
=
expected_buffer_type
,
)
spawn_and_init
(
fn
,
world_sizes
=
[
world_size
])
@
staticmethod
def
_test_dtypes
(
cfg
:
Dict
,
autocast
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
,
rank
,
group
):
def
_test_dtypes
(
cfg
:
Dict
,
autocast
,
in_dtype
,
p_dtype
,
loss_dtype
,
reduce_dtype
,
rank
,
group
,
expected_buffer_type
=
None
):
# Patch torch.distributed.reduce_scatter to check the dtype of the reduction
orig_reduce_scatter
=
torch
.
distributed
.
reduce_scatter
model
:
nn
.
Module
=
DeviceAndTypeCheckModule
(
expected_input_dtype
=
in_dtype
,
expected_param_dtype
=
p_dtype
,
expected_loss_dtype
=
loss_dtype
,
expected_input_dtype
=
in_dtype
,
expected_param_dtype
=
p_dtype
,
expected_loss_dtype
=
loss_dtype
,
expected_buffer_dtype
=
expected_buffer_type
,
)
def
_reduce_scatter
(
output
,
input_list
,
**
kwargs
):
...
...
@@ -265,7 +303,7 @@ class TestComparisonToPyTorchDDP(DistributedTest):
def
_test_identical_outputs
(
cls
,
model_init_fn
,
config
,
rank
,
group
,
num_steps
=
2
,
use_cuda
=
True
,
lr
=
0.01
,
ref_ddp_fn
=
None
,
norm_type
=
2
,
):
if
config
[
"mixed_precision"
]
:
if
config
.
get
(
"mixed_precision"
,
False
)
:
autocast
=
True
# Force the compute dtype to be torch.float32 so that we get
# identical results as PyTorch DDP when using autocast. Note that
...
...
@@ -399,7 +437,9 @@ class TestLocalStateDict(DistributedTest):
@
classmethod
def
_load_local_and_train
(
self
,
config
,
rank
,
group
,
d_model
=
16
,
d_vocab
=
23
):
"""Check that local_state_dict can be saved and loaded for a given worker, and that training updates it"""
model
=
self
.
get_wrapped_model
(
group
,
cuda_first
=
False
,
config
=
config
,
d_vocab
=
d_vocab
,
d_model
=
d_model
)
model
=
self
.
get_wrapped_model
(
group
,
cuda_first
=
False
,
config
=
config
,
d_vocab
=
d_vocab
,
d_model
=
d_model
,
add_bn
=
False
)
# Set add_bn=True here to show that BN doesn't get updated
state_1
=
model
.
local_state_dict
()
state_before_training
=
{
k
:
v
.
cpu
().
clone
()
for
k
,
v
in
state_1
.
items
()}
assert
len
(
state_1
)
>
0
...
...
@@ -639,7 +679,7 @@ class TestNoSync(DistributedTest):
def
test_no_sync_before_first_forward
(
self
):
group
=
DummyProcessGroup
(
rank
=
0
,
size
=
1
)
model
=
self
.
get_wrapped_model
(
group
,
config
=
{})
model
=
self
.
get_wrapped_model
(
group
,
config
=
{}
,
add_bn
=
False
)
batch
=
model
.
module
.
get_input
(
torch
.
device
(
"cuda"
))
with
model
.
no_sync
():
output
=
model
(
*
batch
)
...
...
@@ -651,7 +691,7 @@ class TestNoSync(DistributedTest):
@
classmethod
def
_test_transformer
(
self
,
rank
,
group
,
config
):
model
=
self
.
get_wrapped_model
(
group
,
config
=
config
)
model
=
self
.
get_wrapped_model
(
group
,
config
=
config
,
add_bn
=
False
)
model
.
eval
()
# turn off dropout for the test
self
.
_test_no_sync
(
model
,
batch_dim
=
1
)
...
...
@@ -703,7 +743,7 @@ class TestNoSync(DistributedTest):
class
TransformerWithSharedParams
(
nn
.
Module
):
def
__init__
(
self
,
group
,
*
unused_args
,
d_vocab
=
23
,
d_model
=
16
,
**
unused_kwargs
):
def
__init__
(
self
,
group
,
*
unused_args
,
d_vocab
=
23
,
d_model
=
16
,
add_bn
=
True
,
**
unused_kwargs
):
super
().
__init__
()
self
.
rank
=
group
.
rank
()
self
.
world_size
=
group
.
size
()
...
...
@@ -714,21 +754,26 @@ class TransformerWithSharedParams(nn.Module):
d_model
=
d_model
,
num_encoder_layers
=
2
,
num_decoder_layers
=
2
,
dim_feedforward
=
8
,
dropout
=
0.1
,
)
self
.
output_proj
=
nn
.
Linear
(
d_model
,
d_vocab
)
# share the embedding and output projection weights
self
.
output_proj
.
weight
=
self
.
embed_tokens
.
weight
self
.
register_buffer
(
"vocab_bias"
,
self
.
embed_tokens
.
weight
.
new_ones
((
d_model
,)))
self
.
register_buffer
(
"long_buffer"
,
torch
.
zeros_like
(
self
.
vocab_bias
,
dtype
=
torch
.
long
))
self
.
bs
=
2
self
.
bn
=
torch
.
nn
.
BatchNorm1d
(
self
.
bs
)
if
add_bn
else
torch
.
nn
.
Identity
()
def
get_input
(
self
,
device
):
torch
.
manual_seed
(
1
+
self
.
rank
)
# keep everything deterministic
src
=
torch
.
arange
(
12
,
device
=
device
).
view
(
6
,
2
)
# T x B
tgt
=
torch
.
arange
(
8
,
device
=
device
).
view
(
4
,
2
)
# T x B
src
=
torch
.
arange
(
12
,
device
=
device
).
view
(
6
,
self
.
bs
)
# T x B
tgt
=
torch
.
arange
(
self
.
bs
*
4
,
device
=
device
).
view
(
4
,
self
.
bs
)
# T x B
return
(
src
,
tgt
)
def
forward
(
self
,
src_ids
,
tgt_ids
):
src
=
self
.
embed_tokens
(
src_ids
)
src
=
src
+
self
.
vocab_bias
+
self
.
long_buffer
.
type_as
(
src
)
tgt
=
self
.
embed_tokens
(
tgt_ids
)
tgt
=
self
.
bn
(
tgt
)
x
=
self
.
transformer
(
src
,
tgt
)
return
self
.
output_proj
(
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment