Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
ce2f64f9
Unverified
Commit
ce2f64f9
authored
Jan 19, 2021
by
Benjamin Lefaudeux
Committed by
GitHub
Jan 19, 2021
Browse files
[fix] OSS tensor view corner case + corresponding unit tests (#315)
parent
44b9bcd8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
50 deletions
+48
-50
fairscale/optim/oss.py
fairscale/optim/oss.py
+18
-31
pyproject.toml
pyproject.toml
+1
-1
tests/optim/test_oss.py
tests/optim/test_oss.py
+29
-18
No files found.
fairscale/optim/oss.py
View file @
ce2f64f9
...
@@ -109,25 +109,10 @@ class OSS(Optimizer):
...
@@ -109,25 +109,10 @@ class OSS(Optimizer):
# Current default device is set by the parameters allocated to this rank
# Current default device is set by the parameters allocated to this rank
self
.
_device
=
list
(
self
.
per_device_params
.
keys
())[
0
]
self
.
_device
=
list
(
self
.
per_device_params
.
keys
())[
0
]
self
.
buckets
:
Dict
[
torch
.
device
,
List
[
torch
.
Tensor
]]
=
{}
self
.
buckets
:
Dict
[
torch
.
device
,
List
[
torch
.
Tensor
]]
=
{}
self
.
buffer_max_size
=
broadcast_buffer_size
# Get the correct size for the buckets, cannot be bigger than the model
model_size
=
sum
([
p
.
numel
()
for
p
in
self
.
param_to_rank
.
keys
()])
self
.
bucket_size
=
min
(
broadcast_buffer_size
,
model_size
)
logging
.
info
(
"Bucket size: {:.2f}M parameters, model size {:.2f}M parameters"
.
format
(
self
.
bucket_size
/
2
**
20
,
model_size
/
2
**
20
)
)
# Allocate one buffer per rank and per device to group the small parameters
for
device
,
per_device
in
self
.
per_device_params
.
items
():
self
.
buckets
[
device
]
=
[
torch
.
zeros
(
self
.
bucket_size
,
dtype
=
per_device
[
0
][
0
].
dtype
,
device
=
device
)
for
_
in
range
(
len
(
per_device
))
]
self
.
should_bucket_param
:
List
[
bool
]
=
[]
self
.
should_bucket_param
:
List
[
bool
]
=
[]
self
.
work_handles
:
Deque
[
Workhandle
]
=
deque
()
self
.
work_handles
:
Deque
[
Workhandle
]
=
deque
()
self
.
_max_work_handles
=
-
1
self
.
_setup_bucket_strategy
()
self
.
_setup_bucket_strategy
()
# Partition helpers
# Partition helpers
...
@@ -624,10 +609,24 @@ class OSS(Optimizer):
...
@@ -624,10 +609,24 @@ class OSS(Optimizer):
network requests have been issued.
network requests have been issued.
"""
"""
# Determine the max work handles in flight:
# (re) allocate the buckets
# - count all the buckets on the fly
# - Get the correct size for the buckets, cannot be bigger than the model
self
.
_max_work_handles
=
0
model_size
=
sum
([
p
.
numel
()
for
p
in
self
.
param_to_rank
.
keys
()])
self
.
bucket_size
=
min
(
self
.
buffer_max_size
,
model_size
)
logging
.
info
(
"Bucket size: {:.2f}M parameters, model size {:.2f}M parameters"
.
format
(
self
.
bucket_size
/
2
**
20
,
model_size
/
2
**
20
)
)
# - Allocate one buffer per rank and per device to group the small parameters
for
device
,
per_device
in
self
.
per_device_params
.
items
():
self
.
buckets
[
device
]
=
[
torch
.
zeros
(
self
.
bucket_size
,
dtype
=
per_device
[
0
][
0
].
dtype
,
device
=
device
)
for
_
in
range
(
len
(
per_device
))
]
# Devise the bucketing strategy
for
device
,
per_rank_params
in
self
.
per_device_params
.
items
():
for
device
,
per_rank_params
in
self
.
per_device_params
.
items
():
for
dst_rank
,
params
in
enumerate
(
per_rank_params
):
for
dst_rank
,
params
in
enumerate
(
per_rank_params
):
offset
=
0
offset
=
0
...
@@ -638,10 +637,6 @@ class OSS(Optimizer):
...
@@ -638,10 +637,6 @@ class OSS(Optimizer):
if
param
.
requires_grad
and
(
offset
+
param
.
numel
())
<
self
.
bucket_size
:
if
param
.
requires_grad
and
(
offset
+
param
.
numel
())
<
self
.
bucket_size
:
self
.
should_bucket_param
.
append
(
True
)
self
.
should_bucket_param
.
append
(
True
)
if
offset
==
0
:
# count this bucket, only once
self
.
_max_work_handles
+=
1
# This parameter becomes a view of the bucket
# This parameter becomes a view of the bucket
offset_next
=
offset
+
param
.
numel
()
offset_next
=
offset
+
param
.
numel
()
...
@@ -654,11 +649,3 @@ class OSS(Optimizer):
...
@@ -654,11 +649,3 @@ class OSS(Optimizer):
# Resize the bucket to remove lost space in the end
# Resize the bucket to remove lost space in the end
self
.
buckets
[
device
][
dst_rank
].
resize_
(
offset
)
self
.
buckets
[
device
][
dst_rank
].
resize_
(
offset
)
# Make sure that the memory previously taken by the bucketed parameters is released
if
self
.
_device
.
type
==
"cuda"
:
torch
.
cuda
.
empty_cache
()
# Determine the max work handles in flight:
# - all the direct reduce/broadcast
self
.
_max_work_handles
+=
sum
(
not
value
for
value
in
self
.
should_bucket_param
)
pyproject.toml
View file @
ce2f64f9
...
@@ -28,4 +28,4 @@ use_parentheses = true
...
@@ -28,4 +28,4 @@ use_parentheses = true
skip_glob
=
[
"build/*"
,
"stubs/*"
]
skip_glob
=
[
"build/*"
,
"stubs/*"
]
# Don't split "import" and "from".
# Don't split "import" and "from".
force_sort_within_sections
=
true
force_sort_within_sections
=
true
known_third_party
=
[
"
benchmark_
dataset"
,
"
dataset
s"
,
"helpers"
,
"models"
,
"numpy"
,
"pytest"
,
"recommonmark"
,
"setuptools"
,
"torch"
,
"torch_pg"
,
"torchtext"
,
"torchvision"
]
known_third_party
=
[
"dataset
s
"
,
"
golden_config
s"
,
"helpers"
,
"models"
,
"numpy"
,
"pytest"
,
"recommonmark"
,
"setuptools"
,
"torch"
,
"torch_pg"
,
"torchtext"
,
"torchvision"
]
tests/optim/test_oss.py
View file @
ce2f64f9
...
@@ -191,7 +191,9 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
...
@@ -191,7 +191,9 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
# Test with all parameters trainable to begin with
# Test with all parameters trainable to begin with
def
all_trainable
():
def
all_trainable
():
params
=
[]
params
=
[]
for
size
in
[
4
,
5
,
2
,
6
,
4
]:
sizes
=
[
9
,
7
,
5
,
3
]
sizes_world
=
sizes
*
world_size
for
size
in
sizes_world
[:
-
1
]:
params
.
append
(
torch
.
rand
(
size
,
1
))
params
.
append
(
torch
.
rand
(
size
,
1
))
# Make sure that the params are trainable, enforces size-based partitioning
# Make sure that the params are trainable, enforces size-based partitioning
...
@@ -204,8 +206,9 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
...
@@ -204,8 +206,9 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
o
.
add_param_group
({
"params"
:
[
torch
.
rand
(
3
,
1
)]})
o
.
add_param_group
({
"params"
:
[
torch
.
rand
(
3
,
1
)]})
assert
len
(
o
.
param_groups
)
==
2
assert
len
(
o
.
param_groups
)
==
2
# Verify that added group is added to the correct partition making all have 8 elements.
assert
sum
([
x
.
numel
()
for
g
in
o
.
optim
.
param_groups
for
x
in
g
[
"params"
]])
==
8
# Verify that added group is added to the correct partition making all have the same number of elements
assert
sum
([
x
.
numel
()
for
g
in
o
.
optim
.
param_groups
for
x
in
g
[
"params"
]])
==
sum
(
sizes
)
assert
len
(
o
.
optim
.
param_groups
)
==
2
assert
len
(
o
.
optim
.
param_groups
)
==
2
# Test a pathological config with a first big non-trainable param
# Test a pathological config with a first big non-trainable param
...
@@ -233,9 +236,10 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
...
@@ -233,9 +236,10 @@ def run_test_add_param_group(rank, world_size, tempfile_name):
def
test_add_param_group
():
def
test_add_param_group
():
world_size
=
3
world_size
=
4
if
not
torch
.
cuda
.
is_available
()
or
torch
.
cuda
.
device_count
()
<
world_size
:
if
not
torch
.
cuda
.
is_available
()
or
torch
.
cuda
.
device_count
()
<
world_size
:
pytest
.
skip
(
"Not enough GPUs for NCCL-based test"
)
world_size
=
min
(
world_size
,
torch
.
cuda
.
device_count
())
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
temp_file_name
=
tempfile
.
mkstemp
()[
1
]
mp
.
spawn
(
run_test_add_param_group
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
mp
.
spawn
(
run_test_add_param_group
,
args
=
(
world_size
,
temp_file_name
),
nprocs
=
world_size
,
join
=
True
)
...
@@ -591,10 +595,11 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
...
@@ -591,10 +595,11 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
target
=
torch
.
rand
((
batch
,
target_width
),
device
=
device
)
target
=
torch
.
rand
((
batch
,
target_width
),
device
=
device
)
inputs
=
torch
.
rand
((
batch
,
input_width
),
device
=
device
)
inputs
=
torch
.
rand
((
batch
,
input_width
),
device
=
device
)
model_oss1
=
torch
.
nn
.
Sequential
(
model_oss1
=
torch
.
nn
.
Sequential
(
torch
.
nn
.
Linear
(
input_width
,
hidden
),
torch
.
nn
.
Linear
(
hidden
,
hidden
),).
to
(
device
)
torch
.
nn
.
Linear
(
input_width
,
hidden
),
torch
.
nn
.
Linear
(
hidden
,
hidden
),
torch
.
nn
.
Linear
(
hidden
,
target_width
)
,
head_oss1
=
torch
.
nn
.
Linear
(
hidden
,
target_width
)
.
to
(
device
)
).
to
(
device
)
model_oss2
=
copy
.
deepcopy
(
model_oss1
)
model_oss2
=
copy
.
deepcopy
(
model_oss1
)
head_oss2
=
copy
.
deepcopy
(
head_oss1
)
# For this test the gradients are (all) reduced in the same way in between the torch reference and fairscale.
# For this test the gradients are (all) reduced in the same way in between the torch reference and fairscale.
# Normally OSS would use ShardedDDP and only reduce to the proper rank, but this does not change the
# Normally OSS would use ShardedDDP and only reduce to the proper rank, but this does not change the
...
@@ -602,16 +607,19 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
...
@@ -602,16 +607,19 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# to keep the comparison apples-to-apples DDP is used in both cases
# to keep the comparison apples-to-apples DDP is used in both cases
model_oss1
=
DDP
(
module
=
model_oss1
,
device_ids
=
[
rank
],)
model_oss1
=
DDP
(
module
=
model_oss1
,
device_ids
=
[
rank
],)
sharded_optimizer1
=
optim
.
OSS
(
model_oss1
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
)
sharded_optimizer1
=
optim
.
OSS
(
model_oss1
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
)
sharded_optimizer1
.
add_param_group
({
"params"
:
head_oss1
.
parameters
()})
model_oss2
=
DDP
(
module
=
model_oss2
,
device_ids
=
[
rank
],)
model_oss2
=
DDP
(
module
=
model_oss2
,
device_ids
=
[
rank
],)
sharded_optimizer2
=
optim
.
OSS
(
model_oss2
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
)
sharded_optimizer2
=
optim
.
OSS
(
model_oss2
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
)
sharded_optimizer2
.
add_param_group
({
"params"
:
head_oss2
.
parameters
()})
def
run_grad_step
(
device
,
model
,
optimizer
):
def
run_grad_step
(
device
,
model
,
head
,
optimizer
):
loss_fn
=
torch
.
nn
.
L1Loss
()
loss_fn
=
torch
.
nn
.
L1Loss
()
loss_fn
.
to
(
device
)
loss_fn
.
to
(
device
)
model
.
zero_grad
()
model
.
zero_grad
()
outputs
=
model
(
inputs
)
outputs
=
head
(
model
(
inputs
)
)
loss
=
loss_fn
(
outputs
,
target
)
loss
=
loss_fn
(
outputs
,
target
)
loss
.
backward
()
loss
.
backward
()
...
@@ -622,21 +630,23 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
...
@@ -622,21 +630,23 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# save and reload without taking any steps
# save and reload without taking any steps
sharded_optimizer2
.
consolidate_state_dict
()
sharded_optimizer2
.
consolidate_state_dict
()
state_dict2
=
sharded_optimizer2
.
state_dict
()
state_dict2
=
sharded_optimizer2
.
state_dict
()
sharded_optimizer2
=
optim
.
OSS
(
model_oss2
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
)
sharded_optimizer2
=
optim
.
OSS
(
model_oss2
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
)
sharded_optimizer2
.
add_param_group
({
"params"
:
head_oss2
.
parameters
()})
sharded_optimizer2
.
load_state_dict
(
state_dict2
)
sharded_optimizer2
.
load_state_dict
(
state_dict2
)
# now take a step and check that parameters are equal
# now take a step and check that parameters are equal
# take a step
# take a step
run_grad_step
(
device
,
model_oss1
,
sharded_optimizer1
)
run_grad_step
(
device
,
model_oss1
,
head_oss1
,
sharded_optimizer1
)
run_grad_step
(
device
,
model_oss2
,
sharded_optimizer2
)
run_grad_step
(
device
,
model_oss2
,
head_oss2
,
sharded_optimizer2
)
# check that model parameters are equal
# check that model parameters are equal
for
param1
,
param2
in
zip
(
model_oss1
.
parameters
(),
model_oss2
.
parameters
()):
for
param1
,
param2
in
zip
(
model_oss1
.
parameters
(),
model_oss2
.
parameters
()):
assert
torch
.
allclose
(
param1
,
param2
),
"parameters of the two identical models have diverged (before any steps)"
assert
torch
.
allclose
(
param1
,
param2
),
"parameters of the two identical models have diverged (before any steps)"
# take a step
# take a step
run_grad_step
(
device
,
model_oss1
,
sharded_optimizer1
)
run_grad_step
(
device
,
model_oss1
,
head_oss1
,
sharded_optimizer1
)
run_grad_step
(
device
,
model_oss2
,
sharded_optimizer2
)
run_grad_step
(
device
,
model_oss2
,
head_oss2
,
sharded_optimizer2
)
# check that model parameters are equal
# check that model parameters are equal
for
param1
,
param2
in
zip
(
model_oss1
.
parameters
(),
model_oss2
.
parameters
()):
for
param1
,
param2
in
zip
(
model_oss1
.
parameters
(),
model_oss2
.
parameters
()):
...
@@ -653,8 +663,8 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
...
@@ -653,8 +663,8 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
assert
state_dict2
[
"param_groups"
][
replica
][
k
]
==
sharded_optimizer2
.
param_groups
[
0
][
k
]
assert
state_dict2
[
"param_groups"
][
replica
][
k
]
==
sharded_optimizer2
.
param_groups
[
0
][
k
]
# take a step
# take a step
run_grad_step
(
device
,
model_oss1
,
sharded_optimizer1
)
run_grad_step
(
device
,
model_oss1
,
head_oss1
,
sharded_optimizer1
)
run_grad_step
(
device
,
model_oss2
,
sharded_optimizer2
)
run_grad_step
(
device
,
model_oss2
,
head_oss2
,
sharded_optimizer2
)
# check that saving did not cause a change in the parameters
# check that saving did not cause a change in the parameters
for
param1
,
param2
in
zip
(
model_oss1
.
parameters
(),
model_oss2
.
parameters
()):
for
param1
,
param2
in
zip
(
model_oss1
.
parameters
(),
model_oss2
.
parameters
()):
...
@@ -668,11 +678,12 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
...
@@ -668,11 +678,12 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
# reload the state_dict
# reload the state_dict
sharded_optimizer2
=
optim
.
OSS
(
model_oss2
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
)
sharded_optimizer2
=
optim
.
OSS
(
model_oss2
.
parameters
(),
lr
=
0.1
,
momentum
=
0.99
)
sharded_optimizer2
.
add_param_group
({
"params"
:
head_oss2
.
parameters
()})
sharded_optimizer2
.
load_state_dict
(
state_dict2
)
sharded_optimizer2
.
load_state_dict
(
state_dict2
)
# take a step
# take a step
run_grad_step
(
device
,
model_oss1
,
sharded_optimizer1
)
run_grad_step
(
device
,
model_oss1
,
head_oss1
,
sharded_optimizer1
)
run_grad_step
(
device
,
model_oss2
,
sharded_optimizer2
)
run_grad_step
(
device
,
model_oss2
,
head_oss2
,
sharded_optimizer2
)
# check that reloading a saved state dict does not change the parameters
# check that reloading a saved state dict does not change the parameters
for
param1
,
param2
in
zip
(
model_oss1
.
parameters
(),
model_oss2
.
parameters
()):
for
param1
,
param2
in
zip
(
model_oss1
.
parameters
(),
model_oss2
.
parameters
()):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment