OpenDAS / fairscale · Commits

Commit 81ac5b28 (unverified)
Authored Oct 08, 2020 by Benjamin Lefaudeux, committed by GitHub on Oct 08, 2020

[fix] OSS unit test to check data group (#129)

* new unit test to catch rank issues in OSS

Parent: 22ff665d

Showing 3 changed files with 92 additions and 11 deletions:

  .circleci/config.yml      +9   -1
  fairscale/optim/oss.py    +7   -10
  tests/optim/test_oss.py   +76  -0
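The commit message is terse, so here is a minimal sketch of the rank confusion the new test is designed to catch (illustrative only, not part of the commit; the show_ranks helper and the port number are made up for the example): a process's rank inside a subgroup created with torch.distributed.new_group is not the same as its rank in the default group, and collectives such as dist.broadcast identify the source by the global rank.

# Illustrative sketch, not from the commit: group-local rank vs. global rank.
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def show_ranks(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29502"  # arbitrary free port for this example
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    # Subgroup made of the even ranks only, mirroring the new unit test.
    subgroup = dist.new_group(ranks=[0, 2, 4])

    if rank in (0, 2, 4):
        local_rank = dist.get_rank(group=subgroup)  # 0, 1, 2 inside the subgroup
        global_rank = dist.get_rank()               # 0, 2, 4 in the default group
        # dist.broadcast(src=...) expects the global rank, so passing the
        # subgroup-local rank as src would silently target the wrong process.
        print(f"global rank {global_rank} is local rank {local_rank} in the subgroup")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(show_ranks, args=(6,), nprocs=6, join=True)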
.circleci/config.yml

@@ -101,7 +101,12 @@ run_oss_benchmark: &run_oss_benchmark
       name: Run OSS Benchmark
       command: |
         python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 21.2 --reference_memory 4220 --reference_loss 0.63
-        python benchmarks/oss.py --gloo --optim_type oss
+
+run_oss_gloo: &run_oss_gloo
+  - run:
+      name: Run OSS with Gloo
+      command: |
+        python benchmarks/oss.py --gloo --optim_type oss
 
 # -------------------------------------------------------------------------------------

@@ -254,6 +259,9 @@ jobs:
       - <<: *run_oss_benchmark
+
+      - <<: *run_oss_gloo
+
 workflows:
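This splits the gloo run out of the regression benchmark into its own CI step, run_oss_gloo, referenced alongside run_oss_benchmark in the jobs section. The command itself is unchanged, `python benchmarks/oss.py --gloo --optim_type oss`, and, assuming the gloo backend is available locally, it should be runnable outside CI with the same flags.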
fairscale/optim/oss.py

@@ -194,7 +194,7 @@ class OSS(Optimizer):
         for (
             device,
             device_params,
         ) in self.per_device_params.items():  # all the params on this device (inc all ranks)
-            self._broadcast_params(self._broadcast_buffers[device], device_params, self.group, self.global_rank)
+            self._broadcast_params(self._broadcast_buffers[device], device_params)
 
         return loss

@@ -408,10 +408,7 @@ class OSS(Optimizer):
             global_rank = dist.distributed_c10d._get_global_rank(group, rank)  # type: ignore
         return global_rank
 
-    @staticmethod
-    def _broadcast_params(
-        buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]], group: Any, self_rank: int
-    ) -> None:
+    def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
         """Helper function to broadcast all the parameters from a given device
         """
         buffer_size = buffers[0].numel()

@@ -425,7 +422,7 @@ class OSS(Optimizer):
             if len(params) == 0:
                 continue
 
-            global_rank = OSS.get_global_rank(group, rank)
+            global_rank = OSS.get_global_rank(self.group, rank)
 
             # Copy small parameters into per-GPU buffers
             i_bucketed = 0  # the number of tensors packed in the buffer

@@ -434,14 +431,14 @@ class OSS(Optimizer):
             # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
             while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
                 end = offset + params[i_bucketed].numel()
-                if global_rank == self_rank:
+                if global_rank == self.global_rank:
                     buffer[offset:end].copy_(params[i_bucketed].data.view(-1))  # type: ignore
                 offset = end
                 i_bucketed += 1
 
             if i_bucketed > 0:
-                future = dist.broadcast(tensor=buffer, src=global_rank, group=group, async_op=True)
-                if global_rank != self_rank:
+                future = dist.broadcast(tensor=buffer, src=global_rank, group=self.group, async_op=True)
+                if global_rank != self.global_rank:
                     # This request will need to be unrolled
                     bucket_requests.append((future, rank))

@@ -455,7 +452,7 @@ class OSS(Optimizer):
                     restore_require_grad.append(param)
                     param.requires_grad = False
 
-                requests.append(dist.broadcast(tensor=param, src=global_rank, group=group, async_op=True))
+                requests.append(dist.broadcast(tensor=param, src=global_rank, group=self.group, async_op=True))
 
         # Unroll the initial packed small parameters
         for gate, rank in bucket_requests:
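The oss.py change above removes the group and self_rank arguments from _broadcast_params and makes it a bound method that reads self.group and self.global_rank, so every broadcast names the shard owner by its rank in the default process group. A minimal sketch of the underlying mapping, reusing the same private PyTorch helper already called in get_global_rank (the broadcast_from_shard_owner wrapper is hypothetical, not code from the commit):

import torch
import torch.distributed as dist


def broadcast_from_shard_owner(buffer: torch.Tensor, local_rank: int, group):
    # local_rank indexes a process *inside* group, but dist.broadcast expects the
    # rank in the default (global) process group as src, hence the translation.
    global_rank = dist.distributed_c10d._get_global_rank(group, local_rank)  # type: ignore
    return dist.broadcast(tensor=buffer, src=global_rank, group=group, async_op=True)

With async_op=True the call returns a request handle that the optimizer later waits on, which is why the diff collects these handles in bucket_requests and requests.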
tests/optim/test_oss.py

@@ -9,6 +9,7 @@
 import os
 
+import numpy as np
 import pytest
 import torch
 import torch.distributed as dist

@@ -334,3 +335,78 @@ def test_collect_shards():
     mp.spawn(
         run_test_collect_shards, args=(world_size, reference_rank), nprocs=world_size, join=True,
     )
+
+
+def run_test_multiple_groups(rank, world_size):
+    # Only work with the even ranks, to check that the global_rank indexing is properly used
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "29501"
+    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
+    sub_group_ranks = [0, 2, 4]
+    process_group = torch.distributed.new_group(ranks=sub_group_ranks, backend="gloo")
+
+    # Make sure that all the ranks get different training data
+    # So that the sync check in between their models is meaningful
+    torch.manual_seed(rank)
+    np.random.seed(rank)
+
+    # Standard deep learning setup
+    device = "cpu"
+    epochs, batch, input_width, hidden, target_width = 5, 3, 20, 10, 5
+    loss_fn = torch.nn.L1Loss().to(device)
+
+    def check(optimizer):
+        # Just run a couple of epochs, check that the model is properly updated
+        for _ in range(epochs):
+            target = torch.rand((batch, target_width), device=device)
+            inputs = torch.rand((batch, input_width), device=device)
+
+            def closure():
+                optimizer.zero_grad()
+                output = model(inputs)
+                loss = loss_fn(output, target)
+                loss /= world_size
+                loss.backward()
+                dist.all_reduce(loss, group=process_group)  # Not strictly needed for the test below
+                return loss
+
+            _ = optimizer.step(closure=closure)
+
+            # Check that all the params are the same on all ranks
+            for pg in optimizer.param_groups:
+                for p in pg["params"]:
+                    receptacle = [p.clone() for _ in sub_group_ranks] if rank == 0 else []
+                    dist.gather(p, receptacle, dst=0, group=process_group)
+                    if rank == 0:
+                        for sync_p in receptacle[1:]:
+                            assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
+
+    if rank in sub_group_ranks:
+        # Model fitting in the broadcast bucket
+        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
+            device
+        )
+
+        # With SGD, Momentum is required to get a state to shard
+        optimizer = optim.OSS(
+            model.parameters(), lr=0.1, momentum=0.99, group=process_group, broadcast_buffer_size=2 ** 20
+        )
+        check(optimizer)
+
+        # Model not-fitting in the broadcast bucket
+        model = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, target_width)).to(
+            device
+        )
+
+        # With SGD, Momentum is required to get a state to shard
+        optimizer = optim.OSS(
+            model.parameters(), lr=0.1, momentum=0.99, group=process_group, broadcast_buffer_size=0
+        )
+        check(optimizer)
+
+
+def test_multiple_groups():
+    world_size = 6
+    mp.spawn(
+        run_test_multiple_groups, args=(world_size,), nprocs=world_size, join=True,
+    )
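The new test spawns six processes but builds the OSS optimizer only on the subgroup of even ranks, which exercises exactly the local-vs-global rank path fixed above. Assuming the repository's standard pytest setup, it can be run on its own with something like `python -m pytest tests/optim/test_oss.py -k multiple_groups`.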