OpenDAS / fairscale / Commits / 77d94861

Unverified commit 77d94861, authored Feb 07, 2021 by Benjamin Lefaudeux, committed by GitHub on Feb 07, 2021
[refactor] OSS only use flat buffers (#371)
* flat params all along, way simpler
* updating the docstring
Parent: 8778fa66
Showing 1 changed file with 19 additions and 64 deletions.
fairscale/optim/oss.py (+19, -64)
@@ -51,9 +51,7 @@ class OSS(Optimizer):
         group (group):
             torch.distributed group (default: group.WORLD)
         broadcast_buffer_size (int):
-            the max size of the buffer used to batch the small parameter tensors, in number of elements (default 16M).
-            this will not impact the long term memory consumption, but the peak memory can be impacted by the moment
-            when the buffers are allocated and the bucketed params have not yet been relocated to them.
+            (deprecated) used to cap the size of the broadcast buffers, not being used anymore.
     """

     #: The optimizer used for a given shard
@@ -66,7 +64,7 @@ class OSS(Optimizer):
         params: _params_t,
         optim: Type[Optimizer] = SGD,
         group: Optional[Any] = None,
-        broadcast_buffer_size: int = 2 ** 24,
+        broadcast_buffer_size: int = -1,
         **default: Any,
     ):
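With this change, broadcast_buffer_size no longer controls any runtime behavior: it defaults to -1 and is only kept so existing call sites do not break. A minimal construction sketch, assuming an already-initialized process group; the model, learning rate, and momentum are illustrative, not part of this commit:

# Sketch of building OSS after this change; hyper-parameters are illustrative assumptions.
import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS

def build_sharded_optimizer(model: torch.nn.Module) -> OSS:
    # Assumes dist.init_process_group(...) has already been called by the launcher.
    return OSS(
        params=model.parameters(),
        optim=torch.optim.SGD,     # wrapped optimizer type, sharded across ranks
        group=dist.group.WORLD,    # the default group, shown explicitly
        # broadcast_buffer_size is deprecated by this commit: it defaults to -1 and is
        # ignored, since each rank's shard is now a single flat buffer.
        lr=0.01,
        momentum=0.9,
    )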
@@ -101,12 +99,9 @@ class OSS(Optimizer):
         # Current default device is set by the parameters allocated to this rank
         self._device = list(self.per_device_params.keys())[0]

-        self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
-        self.buffer_max_size = broadcast_buffer_size
-        self.should_bucket_param: List[bool] = []
         self.work_handles: Deque[Workhandle] = deque()
-
-        self._setup_bucket_strategy()
+        self.buckets: Dict[torch.device, List[torch.Tensor]] = {}
+        self._setup_flat_buffers()

     # Partition helpers
     def partition_parameters(self) -> List[List[dict]]:
@@ -509,7 +504,7 @@ class OSS(Optimizer):
         self.optim.add_param_group(param_groups[-1])

         # Update the bucketing strategy accordingly
-        self._setup_bucket_strategy()
+        self._setup_flat_buffers()

     def _clear_cache(self) -> None:
         self._partition_parameters.clear()
@@ -540,25 +535,11 @@ class OSS(Optimizer):
     def _broadcast_params(self) -> None:
         """Helper function to broadcast all the parameters from a given device"""
-        i_param = 0
         last_work_handle = None  # Work handles are consumed within this scope, no callback

-        for (device, device_params,) in self.per_device_params.items():  # all the params on this device (inc all ranks)
-            buckets = self.buckets[device]
-
-            # Bucket and issue all the async calls
-            for (src_rank, params), bucket in zip(enumerate(device_params), buckets):
+        for device in self.buckets.keys():
+            for src_rank, bucket in enumerate(self.buckets[device]):
                 global_src_rank = self.get_global_rank(self.group, src_rank)
-
-                # Direct broadcasts only
-                for param in params:
-                    if not self.should_bucket_param[i_param]:
-                        last_work_handle = dist.broadcast(
-                            tensor=param.data, src=global_src_rank, group=self.group, async_op=True
-                        )
-                    i_param += 1
-
-                # Bucket broadcasts
                 last_work_handle = dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True)

         # Only check on the last handle, they're all inlined on the same CUDA stream
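The rewritten _broadcast_params above drops all per-parameter broadcasts: each rank's shard is one flat bucket, so a single async broadcast per (device, rank) pair suffices, and only the last handle needs to be waited on because collectives on the same process group are queued in order. A standalone sketch of that pattern, assuming an already-initialized process group and a dict of flat buckets shaped like OSS.buckets; the local-to-global rank mapping is omitted:

# Sketch of "broadcast one flat bucket per rank, wait only on the last handle".
# `buckets` mirrors the shape of OSS.buckets: {device: [one flat tensor per rank]}.
from typing import Dict, List
import torch
import torch.distributed as dist

def broadcast_flat_buckets(buckets: Dict[torch.device, List[torch.Tensor]]) -> None:
    last_work_handle = None

    for device_buckets in buckets.values():
        for src_rank, bucket in enumerate(device_buckets):
            # Each bucket is owned (and kept up to date) by src_rank; everyone else receives it.
            last_work_handle = dist.broadcast(tensor=bucket, src=src_rank, async_op=True)

    # Broadcasts on the same group execute in order, so waiting on the last
    # handle guarantees that all of them have completed.
    if last_work_handle is not None:
        last_work_handle.wait()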
@@ -569,7 +550,6 @@ class OSS(Optimizer):
         """Consume all the futures which are tied to this optimizer's buckets.
         We start from the first/older ones, since they are the most likely to be ready and non-blocking
         """
         while len(self.work_handles) > 0:
             work_handle = self.work_handles.popleft()
             work_handle.handle.wait()
@@ -583,51 +563,26 @@ class OSS(Optimizer):
             if work_handle.callback is not None:
                 work_handle.callback()

-    def _setup_bucket_strategy(self) -> None:
-        """Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
-        (smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
-        over the wire.
-        Generating the partition once and for all allows us to save some time at runtime, and to know when all the
-        network requests have been issued.
-        """
-        # - Get the correct size for the buckets, cannot be bigger than the model
-        model_size = sum([p.numel() for p in self.param_to_rank.keys()])
-        self.bucket_size = min(self.buffer_max_size, model_size)
-        logging.info(
-            "Bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
-                self.bucket_size / 2 ** 20, model_size / 2 ** 20
-            )
-        )
-
-        # - Allocate one buffer per rank and per device to group the small parameters
-        for device, per_device in self.per_device_params.items():
-            self.buckets[device] = [
-                torch.zeros(self.bucket_size, dtype=per_device[0][0].dtype, device=device)
-                for _ in range(len(per_device))
-            ]
-
-        # Devise the bucketing strategy
-        for device, per_rank_params in self.per_device_params.items():
-            for dst_rank, params in enumerate(per_rank_params):
-                offset = 0
-
-                for param in params:
-                    # Criteria to decide whether this parameter is to be bucketed or not:
-                    # - enough room in the bucket
-                    if param.requires_grad and (offset + param.numel()) < self.bucket_size:
-                        self.should_bucket_param.append(True)
-
-                        # This parameter becomes a view of the bucket
-                        offset_next = offset + param.numel()
-                        self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten())
-                        param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data)
-                        offset = offset_next
-                    else:
-                        self.should_bucket_param.append(False)
-
-                # Resize the bucket to remove lost space in the end
-                self.buckets[device][dst_rank].resize_(offset)
+    def _setup_flat_buffers(self) -> None:
+        """Make all params which are on the same device and tied to the same rank views of a single buffer.
+        This is used at construction time, and anytime parameter trainability is changed (frozen or unfrozen) and
+        `refresh_trainability` is called.
+        """
+        # (re) allocate the buckets
+        for device, per_rank_params in self.per_device_params.items():
+            self.buckets[device] = []
+
+            for dst_rank, params in enumerate(per_rank_params):
+                if len(params) > 0:
+                    trainable_params = list(filter(lambda x: x.requires_grad, params))
+                    buffer_size = sum(map(lambda x: x.numel(), trainable_params))
+                    self.buckets[device].append(torch.empty(buffer_size, dtype=params[0].dtype, device=device))
+                    offset = 0
+
+                    for param in trainable_params:
+                        # This parameter becomes a view of the bucket
+                        offset_next = offset + param.numel()
+                        self.buckets[device][dst_rank][offset:offset_next].copy_(param.data.flatten())
+                        param.data = self.buckets[device][dst_rank][offset:offset_next].view_as(param.data)
+                        offset = offset_next
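The core of the refactor is _setup_flat_buffers above: all trainable parameters that live on the same device and belong to the same rank become views into one contiguous buffer, so that shard can be moved or broadcast as a single tensor. A self-contained sketch of that flattening technique in plain PyTorch; the helper name and the toy parameters are illustrative, not fairscale API:

# Sketch of the flat-buffer technique used by _setup_flat_buffers: copy each trainable
# parameter into one contiguous buffer, then make the parameter a view of its slice.
# `flatten_into_buffer` is a hypothetical helper, not part of fairscale.
from typing import List
import torch

def flatten_into_buffer(params: List[torch.Tensor]) -> torch.Tensor:
    trainable = [p for p in params if p.requires_grad]
    buffer = torch.empty(sum(p.numel() for p in trainable), dtype=trainable[0].dtype)

    offset = 0
    for p in trainable:
        next_offset = offset + p.numel()
        buffer[offset:next_offset].copy_(p.data.flatten())   # preserve current values
        p.data = buffer[offset:next_offset].view_as(p.data)  # parameter now aliases the buffer
        offset = next_offset

    return buffer  # broadcasting or moving this tensor updates every parameter in place

# Toy usage: two small parameters end up sharing one storage.
w = torch.nn.Parameter(torch.randn(3, 4))
b = torch.nn.Parameter(torch.randn(4))
flat = flatten_into_buffer([w, b])
assert w.data.data_ptr() == flat.data_ptr()  # w is a view of the start of the buffer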