Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
e3865549
Unverified
Commit
e3865549
authored
Mar 19, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Mar 19, 2021
Browse files
[feat][refactor][OSS] Param buckets + fp16 broadcasts (#540)
* param buckets * unifying the buckets
parent
195d62f1
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
359 additions
and
39 deletions
+359
-39
fairscale/nn/misc/__init__.py
fairscale/nn/misc/__init__.py
+1
-1
fairscale/nn/misc/grad_bucket.py
fairscale/nn/misc/grad_bucket.py
+7
-2
fairscale/nn/misc/param_bucket.py
fairscale/nn/misc/param_bucket.py
+244
-0
fairscale/optim/oss.py
fairscale/optim/oss.py
+36
-28
tests/ci_test_list_2.txt
tests/ci_test_list_2.txt
+1
-0
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_grad_bucket.py
+1
-1
tests/nn/misc/test_param_bucket.py
tests/nn/misc/test_param_bucket.py
+57
-0
tests/optim/test_oss.py
tests/optim/test_oss.py
+12
-7
No files found.
fairscale/nn/misc/__init__.py
View file @
e3865549
...
@@ -5,4 +5,4 @@
...
@@ -5,4 +5,4 @@
from
.checkpoint_activations
import
checkpoint_wrapper
from
.checkpoint_activations
import
checkpoint_wrapper
from
.flatten_params_wrapper
import
FlattenParamsWrapper
from
.flatten_params_wrapper
import
FlattenParamsWrapper
from
.
g
ra
d
_bucket
import
GradBucket
from
.
pa
ra
m
_bucket
import
GradBucket
,
ParamBucket
fairscale/nn/misc/grad_bucket.py
View file @
e3865549
...
@@ -16,6 +16,7 @@ class GradBucket:
...
@@ -16,6 +16,7 @@ class GradBucket:
def
__init__
(
self
,
size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
destination
:
int
)
->
None
:
def
__init__
(
self
,
size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
destination
:
int
)
->
None
:
self
.
_max_size
=
size
self
.
_max_size
=
size
self
.
_params
:
List
[
torch
.
Tensor
]
=
[]
self
.
_params
:
List
[
torch
.
Tensor
]
=
[]
self
.
_param_ids
:
List
[
int
]
=
[]
self
.
_fill
=
0
self
.
_fill
=
0
self
.
_is_collapsed
=
False
self
.
_is_collapsed
=
False
...
@@ -39,9 +40,9 @@ class GradBucket:
...
@@ -39,9 +40,9 @@ class GradBucket:
return
len
(
self
.
_params
)
==
self
.
params_checked_in
return
len
(
self
.
_params
)
==
self
.
params_checked_in
def
can_add_grad_view
(
self
,
param
:
torch
.
Tensor
)
->
bool
:
def
can_add_grad_view
(
self
,
param
:
torch
.
Tensor
)
->
bool
:
""" Is there enough room in the bucket to add this parameter gradient ?
""" Is there enough room in the bucket to add this parameter gradient
, and is this param not already checked in
?
"""
"""
return
self
.
_fill
+
param
.
numel
()
<
self
.
_max_size
return
self
.
_fill
+
param
.
numel
()
<
self
.
_max_size
and
id
(
param
)
not
in
self
.
_param_ids
def
to
(
# type: ignore
def
to
(
# type: ignore
self
,
self
,
...
@@ -70,11 +71,15 @@ class GradBucket:
...
@@ -70,11 +71,15 @@ class GradBucket:
"""
"""
Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
"""
"""
assert
id
(
param
)
not
in
self
.
_param_ids
,
"The same gradients cannot be checked in twice"
if
param
.
grad
is
None
:
if
param
.
grad
is
None
:
param
.
grad
=
torch
.
zeros_like
(
param
)
param
.
grad
=
torch
.
zeros_like
(
param
)
self
.
_add_grad_as_view
(
param
)
self
.
_add_grad_as_view
(
param
)
self
.
_params
.
append
(
param
)
self
.
_params
.
append
(
param
)
self
.
_param_ids
.
append
(
id
(
param
))
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
collapse
(
self
)
->
None
:
def
collapse
(
self
)
->
None
:
...
...
fairscale/nn/misc/param_bucket.py
0 → 100644
View file @
e3865549
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Union
import
torch
class Bucket:
    """
    Helper class to simplify the handling of buckets, which unify the underlying storage of multiple tensors
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
        """
        Allocate a zero-filled flat buffer of `size` elements.

        Args:
            size: number of elements of the flat buffer
            dtype: element type of the flat buffer
            device: device on which the flat buffer is allocated
        """
        # Tensors registered with this bucket, in check-in order
        self._params: List[torch.Tensor] = []
        # id() of every registered tensor, used to reject double check-ins
        self._param_ids: List[int] = []
        # Number of buffer elements already handed out as views
        self._fill = 0

        # The actual flat tensor
        self.buffer: torch.Tensor = torch.zeros(size, dtype=dtype, device=device)

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "Bucket":
        """
        Move the underlying buffer

        Args:
            device: target device
            dtype: optional target dtype
            non_blocking: forwarded to `torch.Tensor.to`
            keep_param_alignment: unused here, the subclasses use it to decide whether
                the registered tensors are re-pointed to the moved buffer

        Returns:
            this bucket, to allow call chaining
        """
        assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"

        # Tensor.to() is out of place: the result has to be stored, else the move is silently lost
        self.buffer = self.buffer.to(device, dtype, non_blocking)
        return self


class ParamBucket(Bucket):
    """
    Helper class to simplify the handling of parameter buckets
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
        """
        Allocate a parameter bucket of `size` elements, see `Bucket.__init__`.
        """
        super().__init__(size, dtype, device)

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "ParamBucket":
        """
        Move the underlying buffer

        Args:
            device: target device
            dtype: optional target dtype
            non_blocking: forwarded to `torch.Tensor.to`
            keep_param_alignment: if True, every registered parameter becomes a view of the
                moved buffer (and so adopts its device and dtype). If False the parameters
                are left untouched and keep pointing to the previous storage.

        Returns:
            this bucket, to allow call chaining
        """
        super().to(device, dtype, non_blocking)

        if keep_param_alignment:
            self._reattach_params()

        return self

    @torch.no_grad()
    def add_param(self, param: torch.Tensor) -> None:
        """
        Add a new parameter to the bucket. param.data becomes a view of this bucket buffer,
        the current value of the parameter is preserved.

        Raises:
            AssertionError: if the parameter is already checked in, does not fit in the
                remaining space, or does not match the buffer dtype / device
        """
        assert id(param) not in self._param_ids, "The same param cannot be checked in twice"

        self._add_param_as_view(param)
        self._params.append(param)
        self._param_ids.append(id(param))

    @torch.no_grad()
    def _add_param_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
        """
        Make `param.data` a view of the next free slice of the buffer.

        Args:
            param: the tensor to re-point
            keep_existing_value: if True, the current value of `param` is first copied into
                the buffer slice, and `param` must match the buffer dtype and device.
                If False the buffer content wins, and `param` adopts the buffer dtype and device.
        """
        assert self.buffer is not None

        # The dtype/device have to agree only when the param value is copied into the buffer.
        # When re-attaching after a move, the param is expected to follow the buffer instead.
        if keep_existing_value:
            assert (
                param.dtype == self.buffer.dtype
            ), f"Different types for the bucket and the param, cannot proceed: {param.dtype} - {self.buffer.dtype}"
            assert (
                param.device == self.buffer.device
            ), f"Different devices for the bucket and the param, cannot proceed: {param.device} - {self.buffer.device}"

        fill_next = self._fill + param.numel()
        assert fill_next <= self.buffer.numel()

        # Copy the current param value
        if keep_existing_value:
            self.buffer[self._fill : fill_next].copy_(param.data.flatten())

        param.data = self.buffer[self._fill : fill_next].view_as(param.data)
        self._fill = fill_next

    @torch.no_grad()
    def _reattach_params(self) -> None:
        """
        Given the parameters which have been registered previously, rebuild the whole bucket

        The parameters are re-pointed to the current buffer in check-in order, the buffer content
        is kept as is. A bucket without any parameter is left untouched.
        """
        self._fill = 0
        for p in self._params:
            self._add_param_as_view(p, keep_existing_value=False)
class GradBucket(Bucket):
    """
    Helper class to simplify the handling of gradient buckets
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
        """
        Allocate a gradient bucket of `size` elements.

        Args:
            size: number of elements of the flat buffer
            dtype: element type of the flat buffer
            device: device on which the flat buffer is allocated
            destination: stored as `self.destination`, not interpreted by this class
        """
        super().__init__(size, dtype, device)

        # Capacity of the bucket, kept separately since the buffer is emptied by collapse()
        self._max_size = size
        self._is_collapsed = False

        # Bookkeeping fields, only reset by this class (reset_checked_in / collapse)
        self.params_checked_in = 0
        self.destination = destination
        self.sent = True
        self.callback: Optional[Callable[[Any], None]] = None

    def reset_checked_in(self) -> None:
        """ Reset the counter of the parameter grads which have been checked in
        """
        self.params_checked_in = 0
        self.sent = False

    @property
    def all_checked_in(self) -> bool:
        """ Have all the expected gradient check-in happened ?"""
        return len(self._params) == self.params_checked_in

    def can_add_grad_view(self, param: torch.Tensor) -> bool:
        """ Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
        """
        # A gradient which exactly fills the remaining space does fit: _add_grad_as_view
        # only requires fill_next <= buffer size, so the bound is inclusive here as well
        return self._fill + param.numel() <= self._max_size and id(param) not in self._param_ids

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "GradBucket":
        """
        Move the underlying buffer

        Args:
            device: target device
            dtype: optional target dtype
            non_blocking: forwarded to the base class `to`
            keep_param_alignment: if True, the registered gradients are re-pointed to the buffer
                after the move

        Returns:
            this bucket, to allow call chaining
        """
        # A collapsed bucket has no storage left, it needs to be re-allocated first
        if self._is_collapsed:
            self.rebuild()

        super().to(device, dtype, non_blocking)

        if keep_param_alignment:
            self._reattach_grads()

        return self

    def zero(self) -> None:
        """
        Set all the grads to zero
        """
        self.buffer.fill_(0.0)

    @torch.no_grad()
    def add_grad(self, param: torch.Tensor) -> None:
        """
        Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer

        Raises:
            AssertionError: if the parameter is already checked in, or if the checks of
                `_add_grad_as_view` fail (collapsed bucket, dtype / device mismatch, overflow)
        """
        assert id(param) not in self._param_ids, "The same gradients cannot be checked in twice"

        # Make sure that there is a gradient to re-point, _add_grad_as_view copies its value
        if param.grad is None:
            param.grad = torch.zeros_like(param)

        self._add_grad_as_view(param)
        self._params.append(param)
        self._param_ids.append(id(param))

    @torch.no_grad()
    def collapse(self) -> None:
        """
        Release the buffer from memory. The bucket will need to be rebuilt before use
        """
        if not self._is_collapsed:
            # Drop every view of the buffer, so that nothing keeps the storage alive
            for p in self._params:
                assert p.grad is not None
                p.grad.detach_()
                p.grad = None

            # Keep an empty tensor around, so that the dtype and device stay available
            self.buffer = torch.zeros(0, dtype=self.buffer.dtype, device=self.buffer.device)
            self._fill = 0
            self.params_checked_in = 0
            self._is_collapsed = True

    @torch.no_grad()
    def rebuild(self) -> None:
        """
        Given the parameter gradients which have been registered previously, rebuild the whole bucket

        This is a no-op if the bucket is not collapsed.
        """
        assert len(self._params) > 0

        if self._is_collapsed:
            # The dtype and device are taken from the first registered parameter
            self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device)

            for p in self._params:
                self._add_grad_as_view(p)

            self._is_collapsed = False

    @torch.no_grad()
    def shrink(self) -> None:
        """
        Shrink the buffer to the size of the parameter gradients currently checked in, release the extra memory
        """
        assert self.buffer.numel() > 0, "Cannot shrink a collapsed bucket, please rebuild"

        # clone() gives the shrunk buffer a storage of its own, the gradients are then
        # copied over and re-pointed to it one by one
        self.buffer = self.buffer.resize_(self._fill).clone()
        self._fill = 0
        for p in self._params:
            self._add_grad_as_view(p)

        self._max_size = self._fill

    @torch.no_grad()
    def _reattach_grads(self) -> None:
        """
        Given the parameters gradients which have been registered previously, rebuild the whole bucket

        The gradients are re-pointed to the current buffer in check-in order, the buffer content is kept.
        """
        assert len(self._params) > 0

        self._fill = 0
        for p in self._params:
            self._add_grad_as_view(p, keep_existing_value=False)

    @torch.no_grad()
    def _add_grad_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
        """
        Make `param.grad` a view of the next free slice of the buffer.

        Args:
            param: the parameter whose gradient is re-pointed
            keep_existing_value: if True and `param.grad` exists, its value is first copied
                into the buffer slice. If False the buffer content wins.
        """
        assert self.buffer.numel() > 0, "Cannot add a gradient to a collapsed bucket, please rebuild"
        assert param.dtype == self.buffer.dtype
        assert param.device == self.buffer.device

        fill_next = self._fill + param.numel()
        assert fill_next <= self.buffer.numel()

        # Copy the current grad value, if any
        if param.grad is not None:
            # keep param.grad in place
            if keep_existing_value:
                self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten())
            param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data)
        else:
            param.grad = self.buffer[self._fill : fill_next].view_as(param.data)
        self._fill = fill_next
fairscale/optim/oss.py
View file @
e3865549
...
@@ -15,6 +15,8 @@ import torch.distributed as dist
...
@@ -15,6 +15,8 @@ import torch.distributed as dist
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
torch.optim
import
SGD
,
Optimizer
from
torch.optim
import
SGD
,
Optimizer
from
fairscale.nn.misc
import
ParamBucket
from
.utils
import
broadcast_object
,
calc_grad_norm
,
recursive_copy_to_device
from
.utils
import
broadcast_object
,
calc_grad_norm
,
recursive_copy_to_device
__all__
=
[
"OSS"
]
__all__
=
[
"OSS"
]
...
@@ -52,6 +54,10 @@ class OSS(Optimizer):
...
@@ -52,6 +54,10 @@ class OSS(Optimizer):
torch.distributed group (default: group.WORLD)
torch.distributed group (default: group.WORLD)
broadcast_buffer_size (int):
broadcast_buffer_size (int):
(deprecated) used to cap the size of the broadcast buffers, not being used anymore.
(deprecated) used to cap the size of the broadcast buffers, not being used anymore.
broadcast_fp16 (bool):
Compress the model shards in fp16 before sharing them in between ranks.
This is safe to use when PyTorch AMP is activated. Without torch AMP this will lead to a slight
degradation in terms of accuracy.
.. warning: the communication patterns that OSS use depend on the "trainability" graph,
.. warning: the communication patterns that OSS use depend on the "trainability" graph,
...
@@ -73,6 +79,7 @@ class OSS(Optimizer):
...
@@ -73,6 +79,7 @@ class OSS(Optimizer):
optim
:
Type
[
Optimizer
]
=
SGD
,
optim
:
Type
[
Optimizer
]
=
SGD
,
group
:
Optional
[
Any
]
=
None
,
group
:
Optional
[
Any
]
=
None
,
broadcast_buffer_size
:
int
=
-
1
,
broadcast_buffer_size
:
int
=
-
1
,
broadcast_fp16
:
bool
=
False
,
**
default
:
Any
,
**
default
:
Any
,
):
):
...
@@ -99,7 +106,8 @@ class OSS(Optimizer):
...
@@ -99,7 +106,8 @@ class OSS(Optimizer):
self
.
global_rank
=
self
.
get_global_rank
(
self
.
group
,
self
.
rank
)
self
.
global_rank
=
self
.
get_global_rank
(
self
.
group
,
self
.
rank
)
self
.
_local_to_global_rank
=
[
self
.
get_global_rank
(
self
.
group
,
i
)
for
i
in
range
(
self
.
world_size
)]
self
.
_local_to_global_rank
=
[
self
.
get_global_rank
(
self
.
group
,
i
)
for
i
in
range
(
self
.
world_size
)]
self
.
buckets
:
Dict
[
torch
.
device
,
List
[
torch
.
Tensor
]]
=
{}
self
.
broadcast_fp16
=
broadcast_fp16
self
.
buckets
:
Dict
[
torch
.
device
,
Dict
[
int
,
ParamBucket
]]
=
{}
self
.
_all_states
:
List
[
Dict
[
str
,
Any
]]
=
[]
# Optional consolidated optimizer state
self
.
_all_states
:
List
[
Dict
[
str
,
Any
]]
=
[]
# Optional consolidated optimizer state
self
.
_default_device
=
torch
.
device
(
"cpu"
)
self
.
_default_device
=
torch
.
device
(
"cpu"
)
...
@@ -542,21 +550,32 @@ class OSS(Optimizer):
...
@@ -542,21 +550,32 @@ class OSS(Optimizer):
work_handles
=
[]
# Work handles are consumed within this scope, no callback
work_handles
=
[]
# Work handles are consumed within this scope, no callback
# Populate the fp16 shards
if
self
.
broadcast_fp16
:
for
device
in
self
.
buckets
.
keys
():
for
device
in
self
.
buckets
.
keys
():
for
src_rank
,
bucket
in
enumerate
(
self
.
buckets
[
device
]):
for
dst_rank
,
bucket
in
self
.
buckets
[
device
].
items
():
if
bucket
.
numel
()
>
0
:
bucket
.
to
(
dtype
=
torch
.
float16
,
device
=
device
,
non_blocking
=
True
,
keep_param_alignment
=
False
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
synchronize
()
# Exchange all the shards with the other ranks
for
device
in
self
.
buckets
.
keys
():
for
dst_rank
,
bucket
in
self
.
buckets
[
device
].
items
():
work_handles
.
append
(
work_handles
.
append
(
dist
.
broadcast
(
dist
.
broadcast
(
tensor
=
bucket
,
src
=
self
.
_local_to_global_rank
[
src
_rank
],
group
=
self
.
group
,
async_op
=
True
tensor
=
bucket
.
buffer
,
src
=
self
.
_local_to_global_rank
[
dst
_rank
],
group
=
self
.
group
,
async_op
=
True
,
)
)
)
)
# Only check on the last handle, they're all inlined on the same CUDA stream
if
work_handles
and
self
.
backend
==
dist
.
Backend
.
NCCL
:
work_handles
[
-
1
].
wait
()
else
:
_
=
list
(
filter
(
lambda
x
:
x
.
wait
(),
work_handles
))
_
=
list
(
filter
(
lambda
x
:
x
.
wait
(),
work_handles
))
# Populate back the fp32 shards
if
self
.
broadcast_fp16
:
for
device
in
self
.
buckets
.
keys
():
for
dst_rank
in
self
.
buckets
[
device
].
keys
():
bucket
.
to
(
dtype
=
torch
.
float32
,
device
=
device
,
non_blocking
=
True
,
keep_param_alignment
=
True
)
def
_setup_flat_buffers
(
self
)
->
None
:
def
_setup_flat_buffers
(
self
)
->
None
:
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
"""Make all params which are on the same device and tied to the same rank views of a single buffer.
This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and
This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and
...
@@ -567,7 +586,7 @@ class OSS(Optimizer):
...
@@ -567,7 +586,7 @@ class OSS(Optimizer):
# Only wipe the existing buckets if there are none
# Only wipe the existing buckets if there are none
# (could be that this is called twice, when trainability changes)
# (could be that this is called twice, when trainability changes)
if
device
not
in
self
.
buckets
.
keys
():
if
device
not
in
self
.
buckets
.
keys
():
self
.
buckets
[
device
]
=
[]
self
.
buckets
[
device
]
=
{}
# Make parameters a view of the bucket
# Make parameters a view of the bucket
for
dst_rank
,
params
in
enumerate
(
per_rank_params
):
for
dst_rank
,
params
in
enumerate
(
per_rank_params
):
...
@@ -580,23 +599,12 @@ class OSS(Optimizer):
...
@@ -580,23 +599,12 @@ class OSS(Optimizer):
# Merge all the trainable params in a single bucket
# Merge all the trainable params in a single bucket
trainable_params
=
list
(
filter
(
lambda
x
:
x
.
requires_grad
,
params
))
trainable_params
=
list
(
filter
(
lambda
x
:
x
.
requires_grad
,
params
))
buffer_size
=
sum
(
map
(
lambda
x
:
x
.
numel
(),
trainable_params
))
buffer_size
=
sum
(
map
(
lambda
x
:
x
.
numel
(),
trainable_params
))
bucket
=
torch
.
empty
(
buffer_size
,
dtype
=
params
[
0
].
dtype
,
device
=
device
)
bucket
=
ParamBucket
(
size
=
buffer_size
,
dtype
=
params
[
0
].
dtype
,
device
=
device
)
offset
=
0
for
param
in
trainable_params
:
for
param
in
trainable_params
:
offset_next
=
offset
+
param
.
numel
()
bucket
.
add_param
(
param
)
bucket
[
offset
:
offset_next
].
copy_
(
param
.
data
.
flatten
())
param
.
data
=
bucket
[
offset
:
offset_next
].
view_as
(
param
.
data
)
offset
=
offset_next
# Either replace the existing bucket, or create it
if
len
(
self
.
buckets
[
device
])
==
dst_rank
:
self
.
buckets
[
device
].
append
(
bucket
)
else
:
self
.
buckets
[
device
][
dst_rank
]
=
bucket
self
.
buckets
[
device
][
dst_rank
]
=
bucket
else
:
# This rank has an empty shard, that's fine
self
.
buckets
[
device
].
append
(
torch
.
zeros
(
0
,
device
=
device
))
# Clear the buffer keys which are not in use anymore (could be that the devices changed)
# Clear the buffer keys which are not in use anymore (could be that the devices changed)
devices_in_use
=
list
(
self
.
per_device_params
.
keys
())
devices_in_use
=
list
(
self
.
per_device_params
.
keys
())
...
...
tests/ci_test_list_2.txt
View file @
e3865549
...
@@ -5,6 +5,7 @@ tests/utils/test_state_dict.py
...
@@ -5,6 +5,7 @@ tests/utils/test_state_dict.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations.py
tests/nn/misc/test_checkpoint_activations_norm.py
tests/nn/misc/test_checkpoint_activations_norm.py
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_grad_bucket.py
tests/nn/misc/test_param_bucket.py
tests/nn/wrap/test_wrap.py
tests/nn/wrap/test_wrap.py
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_pipe.py
tests/nn/pipe_process/test_transparency.py
tests/nn/pipe_process/test_transparency.py
...
...
tests/nn/misc/test_grad_bucket.py
View file @
e3865549
...
@@ -55,7 +55,7 @@ def test_collapse():
...
@@ -55,7 +55,7 @@ def test_collapse():
bucket
.
shrink
()
bucket
.
shrink
()
bucket
.
collapse
()
bucket
.
collapse
()
assert
bucket
.
buffer
is
None
assert
bucket
.
buffer
.
numel
()
==
0
assert
param
.
grad
is
None
assert
param
.
grad
is
None
bucket
.
rebuild
()
bucket
.
rebuild
()
...
...
tests/nn/misc/test_param_bucket.py
0 → 100644
View file @
e3865549
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import
pytest
import
torch
from
fairscale.nn.misc
import
ParamBucket
def test_param_values_conserved():
    """Checking a parameter in must not alter its value."""
    param = torch.rand((2, 3))

    bucket = ParamBucket(10, param.dtype, param.device)
    # Work on a copy, so that the original stays available as a reference
    param_ = param.clone()

    bucket.add_param(param_)

    # The comparison has to be asserted, else the test cannot fail
    assert torch.allclose(param, param_)
def test_max_size():
    """A parameter which is bigger than the bucket must be rejected."""
    oversized = torch.rand((20, 30))
    small_bucket = ParamBucket(5, oversized.dtype, oversized.device)

    # 600 elements cannot fit in a 5 elements buffer
    with pytest.raises(AssertionError):
        small_bucket.add_param(oversized)
def test_double_check_int():
    """Checking the very same parameter in twice must be rejected."""
    tensor = torch.rand((5, 6))
    roomy_bucket = ParamBucket(300, tensor.dtype, tensor.device)

    # The first check-in is legitimate
    roomy_bucket.add_param(tensor)

    # The second one targets the same tensor, and has to trip the guard
    with pytest.raises(AssertionError):
        roomy_bucket.add_param(tensor)
def test_type_change():
    """Move a bucket to fp16 and back, the parameter must survive up to the fp16 precision."""
    size = (5, 6)
    param = torch.rand(size, requires_grad=True)
    param_ = param.detach().clone()

    bucket = ParamBucket(30, param.dtype, param.device)
    bucket.add_param(param)

    # Move the bucket to fp16 and back
    bucket.to(dtype=torch.float16, device=param.device)
    bucket.to(dtype=torch.float32, device=param.device, keep_param_alignment=True)

    # Same with the reference tensor. Tensor.to() is out of place, the result has to be kept
    param_ = param_.to(dtype=torch.float16).to(dtype=torch.float32)

    assert param.dtype == torch.float32

    # The values are in [0, 1), so a trip through fp16 (10 bits of mantissa) moves them
    # by less than 2^-11 ~ 4.9e-4, which the absolute tolerance covers
    assert torch.allclose(param, param_, rtol=0.0, atol=1e-3)
tests/optim/test_oss.py
View file @
e3865549
...
@@ -484,7 +484,7 @@ def test_collect_shards():
...
@@ -484,7 +484,7 @@ def test_collect_shards():
)
)
def
run_test_reproducibility
(
rank
,
world_size
,
tempfile_name
):
def
run_test_reproducibility
(
rank
,
world_size
,
tempfile_name
,
broadcast_fp16
):
dist_init
(
rank
,
world_size
,
tempfile_name
)
dist_init
(
rank
,
world_size
,
tempfile_name
)
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
device
=
torch
.
device
(
rank
)
if
torch
.
cuda
.
device_count
()
>
1
else
DEVICE
torch
.
cuda
.
set_device
(
rank
)
torch
.
cuda
.
set_device
(
rank
)
...
@@ -501,7 +501,7 @@ def run_test_reproducibility(rank, world_size, tempfile_name):
...
@@ -501,7 +501,7 @@ def run_test_reproducibility(rank, world_size, tempfile_name):
loss_fn
=
torch
.
nn
.
L1Loss
()
loss_fn
=
torch
.
nn
.
L1Loss
()
loss_fn
.
to
(
device
)
loss_fn
.
to
(
device
)
optimizer
=
optim
.
OSS
(
model
.
parameters
(),
optim
=
torch
.
optim
.
RMSprop
,
lr
=
0.1
)
optimizer
=
optim
.
OSS
(
model
.
parameters
(),
optim
=
torch
.
optim
.
RMSprop
,
lr
=
0.1
,
broadcast_fp16
=
broadcast_fp16
)
def
closure
():
def
closure
():
optimizer
.
zero_grad
()
optimizer
.
zero_grad
()
...
@@ -534,12 +534,13 @@ def run_test_reproducibility(rank, world_size, tempfile_name):
...
@@ -534,12 +534,13 @@ def run_test_reproducibility(rank, world_size, tempfile_name):
@
skip_if_single_gpu
@
skip_if_single_gpu
def
test_reproducibility
():
@
pytest
.
mark
.
parametrize
(
"broadcast_fp16"
,
[
False
,
True
])
def
test_reproducibility
(
broadcast_fp16
:
bool
):
world_size
=
2
world_size
=
2
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
mp
.
spawn
(
run_test_reproducibility
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
,
run_test_reproducibility
,
args
=
(
world_size
,
temp_file_name
,
broadcast_fp16
),
nprocs
=
world_size
,
join
=
True
,
)
)
...
@@ -810,7 +811,7 @@ def test_state_dict_distributed():
...
@@ -810,7 +811,7 @@ def test_state_dict_distributed():
)
)
def
run_ddp_parity
(
rank
,
world_size
,
backend
,
temp_file_name
,
change_train_graph
):
def
run_ddp_parity
(
rank
,
world_size
,
backend
,
temp_file_name
,
change_train_graph
,
broadcast_fp16
):
url
=
"file://"
+
temp_file_name
url
=
"file://"
+
temp_file_name
dist
.
init_process_group
(
init_method
=
url
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
dist
.
init_process_group
(
init_method
=
url
,
backend
=
backend
,
rank
=
rank
,
world_size
=
world_size
)
...
@@ -937,9 +938,13 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph
...
@@ -937,9 +938,13 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph
@
skip_if_single_gpu
@
skip_if_single_gpu
@
pytest
.
mark
.
parametrize
(
"change_train_graph"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"change_train_graph"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
dist
.
Backend
.
NCCL
,
dist
.
Backend
.
GLOO
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
dist
.
Backend
.
NCCL
,
dist
.
Backend
.
GLOO
])
def
test_ddp_parity
(
change_train_graph
:
bool
,
backend
:
dist
.
Backend
):
@
pytest
.
mark
.
parametrize
(
"broadcast_fp16"
,
[
False
,
True
])
def
test_ddp_parity
(
change_train_graph
:
bool
,
backend
:
dist
.
Backend
,
broadcast_fp16
:
bool
):
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
world_size
=
torch
.
cuda
.
device_count
()
world_size
=
torch
.
cuda
.
device_count
()
mp
.
spawn
(
mp
.
spawn
(
run_ddp_parity
,
args
=
(
world_size
,
backend
,
temp_file_name
,
change_train_graph
),
nprocs
=
world_size
,
join
=
True
run_ddp_parity
,
args
=
(
world_size
,
backend
,
temp_file_name
,
change_train_graph
,
broadcast_fp16
),
nprocs
=
world_size
,
join
=
True
,
)
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment