OpenDAS / fairscale · Commits

Commit e3865549 (unverified), authored Mar 19, 2021 by Benjamin Lefaudeux, committed by GitHub on Mar 19, 2021
[feat][refactor][OSS] Param buckets + fp16 broadcasts (#540)

* param buckets
* unifying the buckets

Parent: 195d62f1
Showing 8 changed files, with 359 additions and 39 deletions (+359 / -39)
fairscale/nn/misc/__init__.py        +1   -1
fairscale/nn/misc/grad_bucket.py     +7   -2
fairscale/nn/misc/param_bucket.py    +244 -0
fairscale/optim/oss.py               +36  -28
tests/ci_test_list_2.txt             +1   -0
tests/nn/misc/test_grad_bucket.py    +1   -1
tests/nn/misc/test_param_bucket.py   +57  -0
tests/optim/test_oss.py              +12  -7
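The user-facing change in this commit is the new broadcast_fp16 flag on the OSS optimizer (see fairscale/optim/oss.py below). As a quick orientation before the diffs, here is a minimal usage sketch; it assumes torch.distributed has already been initialized, and the model, the wrapped optimizer and the learning rate are placeholders rather than part of this commit:

import torch

from fairscale.optim.oss import OSS

# Assumes torch.distributed.init_process_group(...) has already been called
# and that `model` lives on this rank's device.
model = torch.nn.Linear(32, 32)

optimizer = OSS(
    model.parameters(),
    optim=torch.optim.SGD,   # optimizer to shard; SGD is the default
    lr=0.1,
    broadcast_fp16=True,     # new in this commit: compress shard broadcasts to fp16
)

Per the docstring added in oss.py, the fp16 compression is meant to be combined with PyTorch AMP; without AMP it can cost a small amount of accuracy.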
fairscale/nn/misc/__init__.py (+1 -1)

@@ -5,4 +5,4 @@
 from .checkpoint_activations import checkpoint_wrapper
 from .flatten_params_wrapper import FlattenParamsWrapper
-from .grad_bucket import GradBucket
+from .param_bucket import GradBucket, ParamBucket
fairscale/nn/misc/grad_bucket.py (+7 -2)

@@ -16,6 +16,7 @@ class GradBucket:
     def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
         self._max_size = size
         self._params: List[torch.Tensor] = []
+        self._param_ids: List[int] = []
         self._fill = 0
         self._is_collapsed = False

@@ -39,9 +40,9 @@ class GradBucket:
         return len(self._params) == self.params_checked_in

     def can_add_grad_view(self, param: torch.Tensor) -> bool:
-        """ Is there enough room in the bucket to add this parameter gradient ?
+        """ Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
         """
-        return self._fill + param.numel() < self._max_size
+        return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids

     def to(  # type: ignore
         self,

@@ -70,11 +71,15 @@ class GradBucket:
         """
         Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
         """
+        assert id(param) not in self._param_ids, "The same gradients cannot be checked in twice"
+
         if param.grad is None:
             param.grad = torch.zeros_like(param)

         self._add_grad_as_view(param)
         self._params.append(param)
+        self._param_ids.append(id(param))

     @torch.no_grad()
     def collapse(self) -> None:
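The net effect of the _param_ids bookkeeping added above is that the same gradient can only be checked into a bucket once, and can_add_grad_view now reports that. A small illustrative sketch (the size and destination values are made up; after this commit fairscale.nn.misc resolves GradBucket from the new param_bucket module, which carries the same guard):

import torch

from fairscale.nn.misc import GradBucket

param = torch.rand((2, 3), requires_grad=True)
bucket = GradBucket(size=16, dtype=param.dtype, device=param.device, destination=0)

assert bucket.can_add_grad_view(param)       # enough room, and not yet checked in
bucket.add_grad(param)                       # param.grad becomes a view of bucket.buffer
assert not bucket.can_add_grad_view(param)   # same param again, so this is now False
# Calling bucket.add_grad(param) a second time would trip the new assertion:
# "The same gradients cannot be checked in twice"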
fairscale/nn/misc/param_bucket.py (new file, +244)

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Callable, List, Optional, Union

import torch


class Bucket:
    """
    Helper class to simplify the handling of buckets, which unify the underlying storage of multiple tensors
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
        self._params: List[torch.Tensor] = []
        self._param_ids: List[int] = []
        self._fill = 0

        # The actual flat tensor
        self.buffer: torch.Tensor = torch.zeros(size, dtype=dtype, device=device)

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "ParamBucket":
        """
        Move the underlying buffer
        """
        assert self.buffer is not None, "Cannot move a collapsed bucket, please rebuild it"
        self.buffer.to(device, dtype, non_blocking)


class ParamBucket(Bucket):
    """
    Helper class to simplify the handling of parameter buckets
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device) -> None:
        super().__init__(size, dtype, device)

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "ParamBucket":
        """
        Move the underlying buffer
        """
        super().to(device, dtype, non_blocking)

        if keep_param_alignment:
            self._reattach_params()

    @torch.no_grad()
    def add_param(self, param: torch.Tensor) -> None:
        """
        Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
        """
        assert id(param) not in self._param_ids, "The same param cannot be checked in twice"

        self._add_param_as_view(param)
        self._params.append(param)
        self._param_ids.append(id(param))

    @torch.no_grad()
    def _add_param_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
        assert self.buffer is not None
        assert param.dtype == self.buffer.dtype
        assert param.device == self.buffer.device

        fill_next = self._fill + param.numel()
        assert fill_next <= self.buffer.numel()

        # Copy the current param value
        if keep_existing_value:
            self.buffer[self._fill : fill_next].copy_(param.data.flatten())
        param.data = self.buffer[self._fill : fill_next].view_as(param.data)
        self._fill = fill_next

    @torch.no_grad()
    def _reattach_params(self) -> None:
        """
        Given the parameters which have been registered previously, rebuild the whole bucket
        """
        assert len(self._params) > 0

        self._fill = 0
        for p in self._params:
            self._add_param_as_view(p, keep_existing_value=False)


class GradBucket(Bucket):
    """
    Helper class to simplify the handling of gradient buckets
    """

    def __init__(self, size: int, dtype: torch.dtype, device: torch.device, destination: int) -> None:
        super().__init__(size, dtype, device)

        self._max_size = size
        self._is_collapsed = False

        self.params_checked_in = 0
        self.destination = destination
        self.sent = True
        self.callback: Optional[Callable[[Any], None]] = None

    def reset_checked_in(self) -> None:
        """ Reset the counter of the parameter grads which have been checked in
        """
        self.params_checked_in = 0
        self.sent = False

    @property
    def all_checked_in(self) -> bool:
        """ Have all the expected gradient check-in happened ?"""
        return len(self._params) == self.params_checked_in

    def can_add_grad_view(self, param: torch.Tensor) -> bool:
        """ Is there enough room in the bucket to add this parameter gradient, and is this param not already checked in ?
        """
        return self._fill + param.numel() < self._max_size and id(param) not in self._param_ids

    def to(  # type: ignore
        self,
        device: Optional[Union[int, torch.device]],
        dtype: Optional[torch.dtype] = None,
        non_blocking: bool = False,
        keep_param_alignment: bool = True,
    ) -> "GradBucket":
        """
        Move the underlying buffer
        """
        if self._is_collapsed:
            self.rebuild()

        super().to(device, dtype, non_blocking)

        if keep_param_alignment:
            self._reattach_grads()

    def zero(self) -> None:
        """
        Set all the grads to zero
        """
        self.buffer.fill_(0.0)

    @torch.no_grad()
    def add_grad(self, param: torch.Tensor) -> None:
        """
        Add a new parameter gradient to the bucket. Param.grad becomes a view of this bucket buffer
        """
        assert id(param) not in self._param_ids, "The same gradients cannot be checked in twice"

        if param.grad is None:
            param.grad = torch.zeros_like(param)

        self._add_grad_as_view(param)
        self._params.append(param)
        self._param_ids.append(id(param))

    @torch.no_grad()
    def collapse(self) -> None:
        """
        Release the buffer from memory. The bucket will need to be rebuilt before use
        """
        if not self._is_collapsed:
            for p in self._params:
                assert p.grad is not None
                p.grad.detach_()
                p.grad = None

            self.buffer = torch.zeros(0, dtype=self.buffer.dtype, device=self.buffer.device)
            self._fill = 0
            self.params_checked_in = 0
            self._is_collapsed = True

    @torch.no_grad()
    def rebuild(self) -> None:
        """
        Given the parameter gradients which have been registered previously, rebuild the whole bucket
        """
        assert len(self._params) > 0

        if self._is_collapsed:
            self.buffer = torch.zeros(self._max_size, dtype=self._params[0].dtype, device=self._params[0].device)

            for p in self._params:
                self._add_grad_as_view(p)

            self._is_collapsed = False

    @torch.no_grad()
    def shrink(self) -> None:
        """
        Shrink the buffer to the size of the parameter gradients currently checked in, release the extra memory
        """
        assert self.buffer.numel() > 0, "Cannot shrink a collapsed bucket, please rebuild"

        self.buffer = self.buffer.resize_(self._fill).clone()
        self._fill = 0
        for p in self._params:
            self._add_grad_as_view(p)

        self._max_size = self._fill

    @torch.no_grad()
    def _reattach_grads(self) -> None:
        """
        Given the parameters gradients which have been registered previously, rebuild the whole bucket
        """
        assert len(self._params) > 0

        self._fill = 0
        for p in self._params:
            self._add_grad_as_view(p, keep_existing_value=False)

    @torch.no_grad()
    def _add_grad_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
        assert self.buffer.numel() > 0, "Cannot add a gradient to a collapsed bucket, please rebuild"
        assert param.dtype == self.buffer.dtype
        assert param.device == self.buffer.device

        fill_next = self._fill + param.numel()
        assert fill_next <= self.buffer.numel()

        # Copy the current grad value, if any
        if param.grad is not None:
            # keep param.grad in place
            if keep_existing_value:
                self.buffer[self._fill : fill_next].copy_(param.grad.data.flatten())
            param.grad.data = self.buffer[self._fill : fill_next].view_as(param.data)
        else:
            param.grad = self.buffer[self._fill : fill_next].view_as(param.data)
        self._fill = fill_next
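For reference, a minimal usage sketch of the new ParamBucket, following the API defined above and the patterns used in tests/nn/misc/test_param_bucket.py later in this commit; the tensor shapes are illustrative:

import torch

from fairscale.nn.misc import ParamBucket

p1 = torch.rand((2, 3))
p2 = torch.rand((4,))

# One flat buffer large enough to hold both params
bucket = ParamBucket(size=p1.numel() + p2.numel(), dtype=p1.dtype, device=p1.device)
bucket.add_param(p1)   # p1.data becomes a view into bucket.buffer
bucket.add_param(p2)   # same for p2, packed right after p1

# Move the bucket to fp16 and back, mirroring the test_type_change test below
bucket.to(dtype=torch.float16, device=p1.device)
bucket.to(dtype=torch.float32, device=p1.device, keep_param_alignment=True)

keep_param_alignment=True (the default) re-attaches the registered parameters as views of the buffer after a move, which is what the fp16 broadcast path in oss.py relies on.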
fairscale/optim/oss.py (+36 -28)

@@ -15,6 +15,8 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import SGD, Optimizer

+from fairscale.nn.misc import ParamBucket
+
 from .utils import broadcast_object, calc_grad_norm, recursive_copy_to_device

 __all__ = ["OSS"]

@@ -52,6 +54,10 @@ class OSS(Optimizer):
         torch.distributed group (default: group.WORLD)
         broadcast_buffer_size (int):
             (deprecated) used to cap the size of the broadcast buffers, not being used anymore.
+        broadcast_fp16 (bool):
+            Compress the model shards in fp16 before sharing them in between ranks.
+            This is safe to use when PyTorch AMP is activated. Without torch AMP this will lead to a slight
+            degradation in terms of accuracy.

     .. warning: the communication patterns that OSS use depend on the "trainability" graph,

@@ -73,6 +79,7 @@ class OSS(Optimizer):
         optim: Type[Optimizer] = SGD,
         group: Optional[Any] = None,
         broadcast_buffer_size: int = -1,
+        broadcast_fp16: bool = False,
         **default: Any,
     ):

@@ -99,7 +106,8 @@ class OSS(Optimizer):
         self.global_rank = self.get_global_rank(self.group, self.rank)
         self._local_to_global_rank = [self.get_global_rank(self.group, i) for i in range(self.world_size)]

-        self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
+        self.broadcast_fp16 = broadcast_fp16
+        self.buckets: Dict[torch.device, Dict[int, ParamBucket]] = {}
         self._all_states: List[Dict[str, Any]] = []  # Optional consolidated optimizer state
         self._default_device = torch.device("cpu")

@@ -542,20 +550,31 @@ class OSS(Optimizer):
         work_handles = []  # Work handles are consumed within this scope, no callback

+        # Populate the fp16 shards
+        if self.broadcast_fp16:
+            for device in self.buckets.keys():
+                for dst_rank, bucket in self.buckets[device].items():
+                    bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)
+
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()
+
         # Exchange all the shards with the other ranks
         for device in self.buckets.keys():
-            for src_rank, bucket in enumerate(self.buckets[device]):
-                if bucket.numel() > 0:
-                    work_handles.append(
-                        dist.broadcast(tensor=bucket, src=self._local_to_global_rank[src_rank], group=self.group, async_op=True)
-                    )
+            for dst_rank, bucket in self.buckets[device].items():
+                work_handles.append(
+                    dist.broadcast(
+                        tensor=bucket.buffer,
+                        src=self._local_to_global_rank[dst_rank],
+                        group=self.group,
+                        async_op=True,
+                    )
+                )

         # Only check on the last handle, they're all inlined on the same CUDA stream
         if work_handles and self.backend == dist.Backend.NCCL:
             work_handles[-1].wait()
         else:
             _ = list(filter(lambda x: x.wait(), work_handles))

+        # Populate back the fp32 shards
+        if self.broadcast_fp16:
+            for device in self.buckets.keys():
+                for dst_rank in self.buckets[device].keys():
+                    bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)
+
     def _setup_flat_buffers(self) -> None:
         """Make all params which are on the same device and tied to the same rank views of a single buffer.

@@ -567,7 +586,7 @@ class OSS(Optimizer):
             # Only wipe the existing buckets if there are none
             # (could be that this is called twice, when trainability changes)
             if device not in self.buckets.keys():
-                self.buckets[device] = []
+                self.buckets[device] = {}

             # Make parameters a view of the bucket
             for dst_rank, params in enumerate(per_rank_params):

@@ -580,23 +599,12 @@ class OSS(Optimizer):
                     # Merge all the trainable params in a single bucket
                     trainable_params = list(filter(lambda x: x.requires_grad, params))
                     buffer_size = sum(map(lambda x: x.numel(), trainable_params))
-                    bucket = torch.empty(buffer_size, dtype=params[0].dtype, device=device)
-                    offset = 0
+                    bucket = ParamBucket(size=buffer_size, dtype=params[0].dtype, device=device)

                     for param in trainable_params:
-                        offset_next = offset + param.numel()
-                        bucket[offset:offset_next].copy_(param.data.flatten())
-                        param.data = bucket[offset:offset_next].view_as(param.data)
-                        offset = offset_next
-
-                    # Either replace the existing bucket, or create it
-                    if len(self.buckets[device]) == dst_rank:
-                        self.buckets[device].append(bucket)
-                    else:
-                        self.buckets[device][dst_rank] = bucket
-                else:
-                    # This rank has an empty shard, that's fine
-                    self.buckets[device].append(torch.zeros(0, device=device))
+                        bucket.add_param(param)
+
+                    self.buckets[device][dst_rank] = bucket

         # Clear the buffer keys which are not in use anymore (could be that the devices changed)
         devices_in_use = list(self.per_device_params.keys())
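Taken together, the oss.py changes replace the per-device lists of flat tensors with one ParamBucket per (device, shard-owning rank), and broadcast_fp16 wraps the shard broadcasts in a cast/uncast round trip. The following is a condensed, standalone sketch of that flow, not the actual OSS method: the local-to-global rank translation, the NCCL-only wait shortcut and the surrounding optimizer state are omitted, and the function name is made up.

from typing import Any, Dict

import torch
import torch.distributed as dist

from fairscale.nn.misc import ParamBucket


def broadcast_param_buckets(
    buckets: Dict[torch.device, Dict[int, ParamBucket]], group: Any, use_fp16: bool
) -> None:
    # 1. Optionally compress the shards to fp16 before communication
    if use_fp16:
        for device, per_rank in buckets.items():
            for bucket in per_rank.values():
                bucket.to(dtype=torch.float16, device=device, non_blocking=True, keep_param_alignment=False)
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    # 2. Broadcast each flat buffer from the rank that owns that shard
    handles = []
    for device, per_rank in buckets.items():
        for src_rank, bucket in per_rank.items():
            handles.append(dist.broadcast(tensor=bucket.buffer, src=src_rank, group=group, async_op=True))
    for handle in handles:
        handle.wait()

    # 3. Cast back to fp32 and re-attach the params as views of the buffers
    if use_fp16:
        for device, per_rank in buckets.items():
            for bucket in per_rank.values():
                bucket.to(dtype=torch.float32, device=device, non_blocking=True, keep_param_alignment=True)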
tests/ci_test_list_2.txt (+1 -0)

@@ -5,6 +5,7 @@ tests/utils/test_state_dict.py
 tests/nn/misc/test_checkpoint_activations.py
 tests/nn/misc/test_checkpoint_activations_norm.py
 tests/nn/misc/test_grad_bucket.py
+tests/nn/misc/test_param_bucket.py
 tests/nn/wrap/test_wrap.py
 tests/nn/pipe_process/test_pipe.py
 tests/nn/pipe_process/test_transparency.py
tests/nn/misc/test_grad_bucket.py (+1 -1)

@@ -55,7 +55,7 @@ def test_collapse():
     bucket.shrink()
     bucket.collapse()

-    assert bucket.buffer is None
+    assert bucket.buffer.numel() == 0
     assert param.grad is None

     bucket.rebuild()
tests/nn/misc/test_param_bucket.py (new file, +57)

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from fairscale.nn.misc import ParamBucket


def test_param_values_conserved():
    param = torch.rand((2, 3))
    bucket = ParamBucket(10, param.dtype, param.device)
    param_ = param.clone()

    bucket.add_param(param_)
    torch.allclose(param, param_)


def test_max_size():
    param = torch.rand((20, 30))
    bucket = ParamBucket(5, param.dtype, param.device)
    with pytest.raises(AssertionError):
        bucket.add_param(param)


def test_double_check_int():
    param = torch.rand((5, 6))
    bucket = ParamBucket(300, param.dtype, param.device)
    bucket.add_param(param)
    with pytest.raises(AssertionError):
        bucket.add_param(param)


def test_type_change():
    size = (5, 6)
    param = torch.rand(size, requires_grad=True)
    param_ = param.clone()

    bucket = ParamBucket(30, param.dtype, param.device)
    bucket.add_param(param)

    # Move the bucket to fp16 and back
    bucket.to(dtype=torch.float16, device=param.device)
    bucket.to(dtype=torch.float32, device=param.device, keep_param_alignment=True)

    # Same with the reference tensor
    param_.to(dtype=torch.float16)
    param_.to(dtype=torch.float32)

    torch.allclose(param, param_)
tests/optim/test_oss.py (+12 -7)

@@ -484,7 +484,7 @@ def test_collect_shards():
     )


-def run_test_reproducibility(rank, world_size, tempfile_name):
+def run_test_reproducibility(rank, world_size, tempfile_name, broadcast_fp16):
     dist_init(rank, world_size, tempfile_name)
     device = torch.device(rank) if torch.cuda.device_count() > 1 else DEVICE
     torch.cuda.set_device(rank)

@@ -501,7 +501,7 @@ def run_test_reproducibility(rank, world_size, tempfile_name):
     loss_fn = torch.nn.L1Loss()
     loss_fn.to(device)

-    optimizer = optim.OSS(model.parameters(), optim=torch.optim.RMSprop, lr=0.1)
+    optimizer = optim.OSS(model.parameters(), optim=torch.optim.RMSprop, lr=0.1, broadcast_fp16=broadcast_fp16)

     def closure():
         optimizer.zero_grad()

@@ -534,12 +534,13 @@ def run_test_reproducibility(rank, world_size, tempfile_name):
 @skip_if_single_gpu
-def test_reproducibility():
+@pytest.mark.parametrize("broadcast_fp16", [False, True])
+def test_reproducibility(broadcast_fp16: bool):
     world_size = 2
     temp_file_name = tempfile.mkstemp()[1]

     mp.spawn(
-        run_test_reproducibility, args=(world_size, temp_file_name), nprocs=world_size, join=True,
+        run_test_reproducibility, args=(world_size, temp_file_name, broadcast_fp16), nprocs=world_size, join=True,
     )

@@ -810,7 +811,7 @@ def test_state_dict_distributed():
     )


-def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph):
+def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph, broadcast_fp16):
     url = "file://" + temp_file_name
     dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)

@@ -937,9 +938,13 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name, change_train_graph
 @skip_if_single_gpu
 @pytest.mark.parametrize("change_train_graph", [True, False])
 @pytest.mark.parametrize("backend", [dist.Backend.NCCL, dist.Backend.GLOO])
-def test_ddp_parity(change_train_graph: bool, backend: dist.Backend):
+@pytest.mark.parametrize("broadcast_fp16", [False, True])
+def test_ddp_parity(change_train_graph: bool, backend: dist.Backend, broadcast_fp16: bool):
     temp_file_name = tempfile.mkstemp()[1]
     world_size = torch.cuda.device_count()
     mp.spawn(
-        run_ddp_parity, args=(world_size, backend, temp_file_name, change_train_graph), nprocs=world_size, join=True
+        run_ddp_parity, args=(world_size, backend, temp_file_name, change_train_graph, broadcast_fp16), nprocs=world_size, join=True,
     )