OpenDAS / fairscale · Commit ce2f64f9

Unverified commit ce2f64f9, authored Jan 19, 2021 by Benjamin Lefaudeux, committed by GitHub on Jan 19, 2021.

[fix] OSS tensor view corner case + corresponding unit tests (#315)

Parent: 44b9bcd8
Showing 3 changed files with 48 additions and 50 deletions:

fairscale/optim/oss.py     +18  -31
pyproject.toml              +1   -1
tests/optim/test_oss.py    +29  -18
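The commit touches OSS's broadcast bucketing, where small trainable parameters are copied into a flat per-rank buffer and then kept as views of that buffer, so a single collective can move all of them at once. A minimal, self-contained sketch of that idea (illustrative only, not fairscale's implementation; pack_params_into_bucket is a hypothetical helper):

import torch

def pack_params_into_bucket(params, bucket):
    # Hypothetical helper: copy each small tensor into the flat bucket and
    # return views over it, so one broadcast of `bucket` refreshes them all.
    offset, views = 0, []
    for p in params:
        n = p.numel()
        bucket[offset : offset + n].copy_(p.detach().view(-1))
        views.append(bucket[offset : offset + n].view_as(p))
        offset += n
    return views

params = [torch.rand(3, 2), torch.rand(4)]
bucket = torch.zeros(64)  # would be min(broadcast_buffer_size, model_size) in OSS
views = pack_params_into_bucket(params, bucket)
assert all(torch.equal(v, p) for v, p in zip(views, params))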
fairscale/optim/oss.py
@@ -109,25 +109,10 @@ class OSS(Optimizer):

        # Current default device is set by the parameters allocated to this rank
        self._device = list(self.per_device_params.keys())[0]
        self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
        self.buffer_max_size = broadcast_buffer_size

        # Get the correct size for the buckets, cannot be bigger than the model
        model_size = sum([p.numel() for p in self.param_to_rank.keys()])
        self.bucket_size = min(broadcast_buffer_size, model_size)
        logging.info(
            "Bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
                self.bucket_size / 2 ** 20, model_size / 2 ** 20
            )
        )

        # Allocate one buffer per rank and per device to group the small parameters
        for device, per_device in self.per_device_params.items():
            self.buckets[device] = [
                torch.zeros(self.bucket_size, dtype=per_device[0][0].dtype, device=device)
                for _ in range(len(per_device))
            ]

        self.should_bucket_param: List[bool] = []
        self.work_handles: Deque[Workhandle] = deque()
        self._max_work_handles = -1
        self._setup_bucket_strategy()

        # Partition helpers
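For reference, the sizing rule above caps every flat buffer at the total parameter count, bucket_size = min(broadcast_buffer_size, model_size), with one buffer per (device, destination rank). A stand-alone sketch of that allocation (allocate_buckets is a hypothetical helper, assuming the same per_device_params layout of device -> per-rank parameter lists):

from typing import Dict, List

import torch

def allocate_buckets(
    per_device_params: Dict[torch.device, List[List[torch.Tensor]]],
    broadcast_buffer_size: int,
) -> Dict[torch.device, List[torch.Tensor]]:
    # One flat zero buffer per (device, destination rank), never larger than the model.
    model_size = sum(
        p.numel() for per_rank in per_device_params.values() for params in per_rank for p in params
    )
    bucket_size = min(broadcast_buffer_size, model_size)
    return {
        device: [
            torch.zeros(bucket_size, dtype=per_rank[0][0].dtype, device=device)
            for _ in per_rank
        ]
        for device, per_rank in per_device_params.items()
    }

# Example: two ranks' worth of parameters on CPU, with a 1 MiB element cap.
cpu = torch.device("cpu")
buckets = allocate_buckets({cpu: [[torch.rand(10, 10)], [torch.rand(5)]]}, broadcast_buffer_size=2 ** 20)
assert buckets[cpu][0].numel() == 105  # min(2**20, 100 + 5)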
@@ -624,10 +609,24 @@ class OSS(Optimizer):

            network requests have been issued.
        """
        # Determine the max work handles in flight:
        # - count all the buckets on the fly
        self._max_work_handles = 0

        # (re) allocate the buckets
        # - Get the correct size for the buckets, cannot be bigger than the model
        model_size = sum([p.numel() for p in self.param_to_rank.keys()])
        self.bucket_size = min(self.buffer_max_size, model_size)
        logging.info(
            "Bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
                self.bucket_size / 2 ** 20, model_size / 2 ** 20
            )
        )

        # - Allocate one buffer per rank and per device to group the small parameters
        for device, per_device in self.per_device_params.items():
            self.buckets[device] = [
                torch.zeros(self.bucket_size, dtype=per_device[0][0].dtype, device=device)
                for _ in range(len(per_device))
            ]

        # Devise the bucketing strategy
        for device, per_rank_params in self.per_device_params.items():
            for dst_rank, params in enumerate(per_rank_params):
                offset = 0
@@ -638,10 +637,6 @@ class OSS(Optimizer):

                    if param.requires_grad and (offset + param.numel()) < self.bucket_size:
                        self.should_bucket_param.append(True)

                        if offset == 0:
                            # count this bucket, only once
                            self._max_work_handles += 1

                        # This parameter becomes a view of the bucket
                        offset_next = offset + param.numel()
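The offset bookkeeping above is what lets a small parameter later become a view of its rank's bucket. A hedged, stand-alone illustration of that step (not the exact fairscale code, but the same storage-aliasing technique the comment describes):

import torch

bucket = torch.zeros(16)
param = torch.nn.Parameter(torch.rand(2, 3))

offset = 0
offset_next = offset + param.numel()
bucket[offset:offset_next].copy_(param.data.flatten())       # seed the bucket slice
param.data = bucket[offset:offset_next].view_as(param.data)  # parameter now aliases the bucket
offset = offset_next

# A write through the bucket is visible through the parameter, and vice versa:
bucket[0] = 42.0
assert param.data.flatten()[0] == 42.0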
@@ -654,11 +649,3 @@ class OSS(Optimizer):

                # Resize the bucket to remove lost space in the end
                self.buckets[device][dst_rank].resize_(offset)

        # Make sure that the memory previously taken by the bucketed parameters is released
        if self._device.type == "cuda":
            torch.cuda.empty_cache()

        # Determine the max work handles in flight:
        # - all the direct reduce/broadcast
        self._max_work_handles += sum(not value for value in self.should_bucket_param)
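The handle accounting in this function amounts to one asynchronous work handle per bucket that actually received parameters, plus one per parameter reduced or broadcast directly. A tiny sketch of that count (count_work_handles is a hypothetical helper following the same should_bucket_param convention):

def count_work_handles(should_bucket_param, non_empty_buckets):
    # One handle per non-empty bucket, plus one per parameter handled outside the buckets.
    direct = sum(not bucketed for bucketed in should_bucket_param)
    return non_empty_buckets + direct

# Example: three parameters share a single bucket, two are sent individually.
assert count_work_handles([True, True, True, False, False], non_empty_buckets=1) == 3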
pyproject.toml

@@ -28,4 +28,4 @@ use_parentheses = true

skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
known_third_party = ["benchmark_dataset", "datasets", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
known_third_party = ["datasets", "golden_configs", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
tests/optim/test_oss.py

@@ -191,7 +191,9 @@ def run_test_add_param_group(rank, world_size, tempfile_name):

    # Test with all parameters trainable to begin with
    def all_trainable():
        params = []
        for size in [4, 5, 2, 6, 4]:
        sizes = [9, 7, 5, 3]
        sizes_world = sizes * world_size
        for size in sizes_world[:-1]:
            params.append(torch.rand(size, 1))

        # Make sure that the params are trainable, enforces size-based partitioning
@@ -204,8 +206,9 @@ def run_test_add_param_group(rank, world_size, tempfile_name):

        o.add_param_group({"params": [torch.rand(3, 1)]})
        assert len(o.param_groups) == 2
        # Verify that added group is added to the correct partition making all have 8 elements.
        assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == 8
        # Verify that added group is added to the correct partition making all have the same number of elements
        assert sum([x.numel() for g in o.optim.param_groups for x in g["params"]]) == sum(sizes)
        assert len(o.optim.param_groups) == 2

    # Test a pathological config with a first big non-trainable param
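The updated assertion relies on OSS's size-based partitioning: the trainable tensors are built from sizes * world_size with the last entry dropped, so after greedy assignment one rank is exactly three elements short, and the added torch.rand(3, 1) group is expected to land there, leaving every shard with sum(sizes) elements. A rough sketch of that balancing argument (an illustrative greedy partition, not fairscale's internal code):

def greedy_partition(sizes, world_size):
    # Assign each tensor (largest first) to the currently least-loaded rank and
    # return the per-rank element counts; an approximation of OSS's partitioning.
    loads = [0] * world_size
    for size in sorted(sizes, reverse=True):
        loads[loads.index(min(loads))] += size
    return loads

sizes, world_size = [9, 7, 5, 3], 3
trainable = (sizes * world_size)[:-1]           # as in all_trainable() above
loads = greedy_partition(trainable, world_size)
assert sorted(loads) == [21, 24, 24]            # one rank has room for the extra 3x1 tensor
assert sum(sizes) == 24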
@@ -233,9 +236,10 @@ def run_test_add_param_group(rank, world_size, tempfile_name):

def test_add_param_group():
    world_size = 3
    world_size = 4
    if not torch.cuda.is_available() or torch.cuda.device_count() < world_size:
        pytest.skip("Not enough GPUs for NCCL-based test")
    world_size = min(world_size, torch.cuda.device_count())
    temp_file_name = tempfile.mkstemp()[1]

    mp.spawn(run_test_add_param_group, args=(world_size, temp_file_name), nprocs=world_size, join=True)
@@ -591,10 +595,11 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):

    target = torch.rand((batch, target_width), device=device)
    inputs = torch.rand((batch, input_width), device=device)

    model_oss1 = torch.nn.Sequential(
        torch.nn.Linear(input_width, hidden),
        torch.nn.Linear(hidden, hidden),
        torch.nn.Linear(hidden, target_width),
    ).to(device)
    model_oss1 = torch.nn.Sequential(torch.nn.Linear(input_width, hidden), torch.nn.Linear(hidden, hidden),).to(device)
    head_oss1 = torch.nn.Linear(hidden, target_width).to(device)

    model_oss2 = copy.deepcopy(model_oss1)
    head_oss2 = copy.deepcopy(head_oss1)

    # For this test the gradients are (all) reduced in the same way in between the torch reference and fairscale.
    # Normally OSS would use ShardedDDP and only reduce to the proper rank, but this does not change the
@@ -602,16 +607,19 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):

    # to keep the comparison apples-to-apples DDP is used in both cases
    model_oss1 = DDP(module=model_oss1, device_ids=[rank],)
    sharded_optimizer1 = optim.OSS(model_oss1.parameters(), lr=0.1, momentum=0.99)
    sharded_optimizer1.add_param_group({"params": head_oss1.parameters()})

    model_oss2 = DDP(module=model_oss2, device_ids=[rank],)
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
    sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})

    def run_grad_step(device, model, optimizer):
    def run_grad_step(device, model, head, optimizer):
        loss_fn = torch.nn.L1Loss()
        loss_fn.to(device)

        model.zero_grad()
        outputs = model(inputs)
        outputs = head(model(inputs))

        loss = loss_fn(outputs, target)
        loss.backward()
@@ -622,21 +630,23 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):

    # save and reload without taking any steps
    sharded_optimizer2.consolidate_state_dict()
    state_dict2 = sharded_optimizer2.state_dict()
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
    sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
    sharded_optimizer2.load_state_dict(state_dict2)

    # now take a step and check that parameters are equal
    # take a step
    run_grad_step(device, model_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, sharded_optimizer2)
    run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)

    # check that model parameters are equal
    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
        assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before any steps)"

    # take a step
    run_grad_step(device, model_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, sharded_optimizer2)
    run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)

    # check that model parameters are equal
    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
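The sequence exercised above is the OSS save/reload round trip: every rank calls consolidate_state_dict() so a full state_dict() can be produced, and a freshly constructed optimizer, with the same param groups added in the same order, restores it via load_state_dict(). A condensed sketch of that pattern, assuming a model, a head, and an initialized process group already exist (variable names here are hypothetical):

from fairscale.optim import OSS

# Shard the optimizer state, then add the head as a second param group.
sharded_optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
sharded_optimizer.add_param_group({"params": head.parameters()})

sharded_optimizer.consolidate_state_dict()  # gather the shards
state = sharded_optimizer.state_dict()

# Rebuild with the same groups, in the same order, then restore.
sharded_optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
sharded_optimizer.add_param_group({"params": head.parameters()})
sharded_optimizer.load_state_dict(state)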
@@ -653,8 +663,8 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):

        assert state_dict2["param_groups"][replica][k] == sharded_optimizer2.param_groups[0][k]

    # take a step
    run_grad_step(device, model_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, sharded_optimizer2)
    run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)

    # check that saving did not cause a change in the parameters
    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
@@ -668,11 +678,12 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):

    # reload the state_dict
    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
    sharded_optimizer2.add_param_group({"params": head_oss2.parameters()})
    sharded_optimizer2.load_state_dict(state_dict2)

    # take a step
    run_grad_step(device, model_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, sharded_optimizer2)
    run_grad_step(device, model_oss1, head_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, head_oss2, sharded_optimizer2)

    # check that reloading a saved state dict does not change the parameters
    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):