Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
81ac5b28
Unverified
Commit
81ac5b28
authored
Oct 08, 2020
by
Benjamin Lefaudeux
Committed by
GitHub
Oct 08, 2020
Browse files
[fix] OSS unit test to check data group (#129)
* new unit test to catch rank issues in OSS
parent
22ff665d
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
92 additions
and
11 deletions
+92
-11
.circleci/config.yml
.circleci/config.yml
+9
-1
fairscale/optim/oss.py
fairscale/optim/oss.py
+7
-10
tests/optim/test_oss.py
tests/optim/test_oss.py
+76
-0
No files found.
.circleci/config.yml
View file @
81ac5b28
...
@@ -101,6 +101,11 @@ run_oss_benchmark: &run_oss_benchmark
...
@@ -101,6 +101,11 @@ run_oss_benchmark: &run_oss_benchmark
name
:
Run OSS Benchmark
name
:
Run OSS Benchmark
command
:
|
command
:
|
python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 21.2 --reference_memory 4220 --reference_loss 0.63
python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 21.2 --reference_memory 4220 --reference_loss 0.63
run_oss_gloo
:
&run_oss_gloo
-
run
:
name
:
Run OSS with Gloo
command
:
|
python benchmarks/oss.py --gloo --optim_type oss
python benchmarks/oss.py --gloo --optim_type oss
...
@@ -254,6 +259,9 @@ jobs:
...
@@ -254,6 +259,9 @@ jobs:
-
<<
:
*run_oss_benchmark
-
<<
:
*run_oss_benchmark
-
<<
:
*run_oss_gloo
workflows
:
workflows
:
...
...
fairscale/optim/oss.py
View file @
81ac5b28
...
@@ -194,7 +194,7 @@ class OSS(Optimizer):
...
@@ -194,7 +194,7 @@ class OSS(Optimizer):
device
,
device
,
device_params
,
device_params
,
)
in
self
.
per_device_params
.
items
():
# all the params on this device (inc all ranks)
)
in
self
.
per_device_params
.
items
():
# all the params on this device (inc all ranks)
self
.
_broadcast_params
(
self
.
_broadcast_buffers
[
device
],
device_params
,
self
.
group
,
self
.
global_rank
)
self
.
_broadcast_params
(
self
.
_broadcast_buffers
[
device
],
device_params
)
return
loss
return
loss
...
@@ -408,10 +408,7 @@ class OSS(Optimizer):
...
@@ -408,10 +408,7 @@ class OSS(Optimizer):
global_rank
=
dist
.
distributed_c10d
.
_get_global_rank
(
group
,
rank
)
# type: ignore
global_rank
=
dist
.
distributed_c10d
.
_get_global_rank
(
group
,
rank
)
# type: ignore
return
global_rank
return
global_rank
@
staticmethod
def
_broadcast_params
(
self
,
buffers
:
List
[
torch
.
Tensor
],
per_rank_params
:
List
[
List
[
Parameter
]])
->
None
:
def
_broadcast_params
(
buffers
:
List
[
torch
.
Tensor
],
per_rank_params
:
List
[
List
[
Parameter
]],
group
:
Any
,
self_rank
:
int
)
->
None
:
"""Helper function to broadcast all the parameters from a given device
"""Helper function to broadcast all the parameters from a given device
"""
"""
buffer_size
=
buffers
[
0
].
numel
()
buffer_size
=
buffers
[
0
].
numel
()
...
@@ -425,7 +422,7 @@ class OSS(Optimizer):
...
@@ -425,7 +422,7 @@ class OSS(Optimizer):
if
len
(
params
)
==
0
:
if
len
(
params
)
==
0
:
continue
continue
global_rank
=
OSS
.
get_global_rank
(
group
,
rank
)
global_rank
=
OSS
.
get_global_rank
(
self
.
group
,
rank
)
# Copy small parameters into per-GPU buffers
# Copy small parameters into per-GPU buffers
i_bucketed
=
0
# the number of tensors packed in the buffer
i_bucketed
=
0
# the number of tensors packed in the buffer
...
@@ -434,14 +431,14 @@ class OSS(Optimizer):
...
@@ -434,14 +431,14 @@ class OSS(Optimizer):
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
while
i_bucketed
<
len
(
params
)
and
offset
+
params
[
i_bucketed
].
numel
()
<
buffer_size
:
while
i_bucketed
<
len
(
params
)
and
offset
+
params
[
i_bucketed
].
numel
()
<
buffer_size
:
end
=
offset
+
params
[
i_bucketed
].
numel
()
end
=
offset
+
params
[
i_bucketed
].
numel
()
if
global_rank
==
self_rank
:
if
global_rank
==
self
.
global
_rank
:
buffer
[
offset
:
end
].
copy_
(
params
[
i_bucketed
].
data
.
view
(
-
1
))
# type: ignore
buffer
[
offset
:
end
].
copy_
(
params
[
i_bucketed
].
data
.
view
(
-
1
))
# type: ignore
offset
=
end
offset
=
end
i_bucketed
+=
1
i_bucketed
+=
1
if
i_bucketed
>
0
:
if
i_bucketed
>
0
:
future
=
dist
.
broadcast
(
tensor
=
buffer
,
src
=
global_rank
,
group
=
group
,
async_op
=
True
)
future
=
dist
.
broadcast
(
tensor
=
buffer
,
src
=
global_rank
,
group
=
self
.
group
,
async_op
=
True
)
if
global_rank
!=
self_rank
:
if
global_rank
!=
self
.
global
_rank
:
# This request will need to be unrolled
# This request will need to be unrolled
bucket_requests
.
append
((
future
,
rank
))
bucket_requests
.
append
((
future
,
rank
))
...
@@ -455,7 +452,7 @@ class OSS(Optimizer):
...
@@ -455,7 +452,7 @@ class OSS(Optimizer):
restore_require_grad
.
append
(
param
)
restore_require_grad
.
append
(
param
)
param
.
requires_grad
=
False
param
.
requires_grad
=
False
requests
.
append
(
dist
.
broadcast
(
tensor
=
param
,
src
=
global_rank
,
group
=
group
,
async_op
=
True
))
requests
.
append
(
dist
.
broadcast
(
tensor
=
param
,
src
=
global_rank
,
group
=
self
.
group
,
async_op
=
True
))
# Unroll the initial packed small parameters
# Unroll the initial packed small parameters
for
gate
,
rank
in
bucket_requests
:
for
gate
,
rank
in
bucket_requests
:
...
...
tests/optim/test_oss.py
View file @
81ac5b28
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
import
os
import
os
import
numpy
as
np
import
pytest
import
pytest
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -334,3 +335,78 @@ def test_collect_shards():
...
@@ -334,3 +335,78 @@ def test_collect_shards():
mp
.
spawn
(
mp
.
spawn
(
run_test_collect_shards
,
args
=
(
world_size
,
reference_rank
),
nprocs
=
world_size
,
join
=
True
,
run_test_collect_shards
,
args
=
(
world_size
,
reference_rank
),
nprocs
=
world_size
,
join
=
True
,
)
)
def
run_test_multiple_groups
(
rank
,
world_size
):
# Only work with the even ranks, to check that the global_rank indexing is properly used
os
.
environ
[
"MASTER_ADDR"
]
=
"localhost"
os
.
environ
[
"MASTER_PORT"
]
=
"29501"
dist
.
init_process_group
(
backend
=
"gloo"
,
rank
=
rank
,
world_size
=
world_size
)
sub_group_ranks
=
[
0
,
2
,
4
]
process_group
=
torch
.
distributed
.
new_group
(
ranks
=
sub_group_ranks
,
backend
=
"gloo"
)
# Make sure that all the ranks get different training data
# So that the sync check in between their models is meaningful
torch
.
manual_seed
(
rank
)
np
.
random
.
seed
(
rank
)
# Standard deep learning setup
device
=
"cpu"
epochs
,
batch
,
input_width
,
hidden
,
target_width
=
5
,
3
,
20
,
10
,
5
loss_fn
=
torch
.
nn
.
L1Loss
().
to
(
device
)
def
check
(
optimizer
):
# Just run a couple of epochs, check that the model is properly updated
for
_
in
range
(
epochs
):
target
=
torch
.
rand
((
batch
,
target_width
),
device
=
device
)
inputs
=
torch
.
rand
((
batch
,
input_width
),
device
=
device
)
def
closure
():
optimizer
.
zero_grad
()
output
=
model
(
inputs
)
loss
=
loss_fn
(
output
,
target
)
loss
/=
world_size
loss
.
backward
()
dist
.
all_reduce
(
loss
,
group
=
process_group
)
# Not strictly needed for the test below
return
loss
_
=
optimizer
.
step
(
closure
=
closure
)
# Check that all the params are the same on all ranks
for
pg
in
optimizer
.
param_groups
:
for
p
in
pg
[
"params"
]:
receptacle
=
[
p
.
clone
()
for
_
in
sub_group_ranks
]
if
rank
==
0
else
[]
dist
.
gather
(
p
,
receptacle
,
dst
=
0
,
group
=
process_group
)
if
rank
==
0
:
for
sync_p
in
receptacle
[
1
:]:
assert
torch
.
all
(
torch
.
eq
(
receptacle
[
0
],
sync_p
)),
"Models differ in between ranks"
if
rank
in
sub_group_ranks
:
# Model fitting in the broadcast bucket
model
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
input_width
,
hidden
),
torch
.
nn
.
Linear
(
hidden
,
target_width
)).
to
(
device
)
# With SGD, Momentum is required to get a state to shard
optimizer
=
optim
.
OSS
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
,
group
=
process_group
,
broadcast_buffer_size
=
2
**
20
)
check
(
optimizer
)
# Model not-fitting in the broadcast bucket
model
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
input_width
,
hidden
),
torch
.
nn
.
Linear
(
hidden
,
target_width
)).
to
(
device
)
# With SGD, Momentum is required to get a state to shard
optimizer
=
optim
.
OSS
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
,
group
=
process_group
,
broadcast_buffer_size
=
0
)
check
(
optimizer
)
def
test_multiple_groups
():
world_size
=
6
mp
.
spawn
(
run_test_multiple_groups
,
args
=
(
world_size
,),
nprocs
=
world_size
,
join
=
True
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment