Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
a31b08a5
Unverified
Commit
a31b08a5
authored
Oct 23, 2020
by
Benjamin Lefaudeux
Committed by
GitHub
Oct 23, 2020
Browse files
[refactor] OSS - broadcasts - getting rid of the while loop (#165)
* small refactor, getting rid of the while loop
parent
339cf060
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
42 deletions
+45
-42
fairscale/optim/oss.py
fairscale/optim/oss.py
+44
-42
stubs/torch/__init__.pyi
stubs/torch/__init__.pyi
+1
-0
No files found.
fairscale/optim/oss.py
View file @
a31b08a5
...
...
@@ -418,58 +418,60 @@ class OSS(Optimizer):
return
global_rank
def
_broadcast_params
(
self
,
buffers
:
List
[
torch
.
Tensor
],
per_rank_params
:
List
[
List
[
Parameter
]])
->
None
:
"""Helper function to broadcast all the parameters from a given device
"""
"""Helper function to broadcast all the parameters from a given device"""
buffer_size
=
buffers
[
0
].
numel
()
bucket_requests
=
[]
direct_requests
=
[]
# Bucket and issue all the async calls
for
(
dst_rank
,
params
),
buffer
in
zip
(
enumerate
(
per_rank_params
),
buffers
):
# All the params are sorted per rank and per increasing size
if
len
(
params
)
==
0
:
continue
global_dst_rank
=
OSS
.
get_global_rank
(
self
.
group
,
dst_rank
)
for
(
src_rank
,
params
),
buffer
in
zip
(
enumerate
(
per_rank_params
),
buffers
):
global_src_rank
=
self
.
get_global_rank
(
self
.
group
,
src_rank
)
# Copy small parameters into per-GPU buffers
i_bucketed
=
0
# the number of tensors packed in the buffer
# Copy small parameters into per-GPU buffers and then async broadcast
offset
=
0
bucket_sent
=
False
bucket_params
=
[]
# All the params are sorted per rank and per increasing size
for
p
in
params
:
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
while
i_bucketed
<
len
(
params
)
and
offset
+
params
[
i_bucketed
]
.
numel
()
<
buffer_size
:
end
=
offset
+
p
arams
[
i_bucketed
]
.
numel
()
if
global_dst_rank
==
self
.
global_rank
:
bu
ffer
[
offset
:
end
].
copy_
(
params
[
i_bucketed
].
data
.
view
(
-
1
))
# type: ignore
if
not
bucket_sent
and
offset
+
p
.
numel
()
<
buffer_size
:
end
=
offset
+
p
.
numel
()
buffer
[
offset
:
end
].
copy_
(
p
.
data
.
view
(
-
1
))
bu
cket_params
.
append
((
p
,
offset
,
end
))
offset
=
end
i_bucketed
+=
1
else
:
if
offset
>
0
and
not
bucket_sent
:
bucket_requests
.
append
(
(
dist
.
broadcast
(
tensor
=
buffer
,
src
=
global_src_rank
,
group
=
self
.
group
,
async_op
=
True
),
src_rank
,
bucket_params
,
)
)
if
i_bucketed
>
0
:
future
=
dist
.
broadcast
(
tensor
=
buffer
,
src
=
global_dst_rank
,
group
=
self
.
group
,
async_op
=
True
)
if
global_dst_rank
!=
self
.
global_rank
:
# This request will need to be unrolled
bucket_requests
.
append
((
future
,
dst_rank
))
bucket_sent
=
True
# Directly broadcast the rest
for
param
in
params
[
i_bucketed
:]:
direct_requests
.
append
(
dist
.
broadcast
(
tensor
=
p
aram
.
data
,
src
=
global_
dst
_rank
,
group
=
self
.
group
,
async_op
=
True
)
,
dist
.
broadcast
(
tensor
=
p
.
data
,
src
=
global_
src
_rank
,
group
=
self
.
group
,
async_op
=
True
)
)
# Unroll the initial packed small parameters
for
gate
,
rank
in
bucket_requests
:
gate
.
wait
()
params
=
per_rank_params
[
rank
]
buffer
=
buffers
[
rank
]
i_bucketed
=
0
# the number of tensors packed in the buffer
offset
=
0
# Catch a trailing bucket
if
not
bucket_sent
:
bucket_requests
.
append
(
(
dist
.
broadcast
(
tensor
=
buffer
,
src
=
global_src_rank
,
group
=
self
.
group
,
async_op
=
True
),
src_rank
,
bucket_params
,
)
)
while
i_bucketed
<
len
(
params
)
and
offset
+
params
[
i_bucketed
].
numel
()
<
buffer_size
:
end
=
offset
+
params
[
i_bucketed
].
numel
()
params
[
i_bucketed
].
data
.
copy_
(
buffer
[
offset
:
end
].
view_as
(
params
[
i_bucketed
]))
# type: ignore
offset
=
end
i_bucketed
+=
1
# Unroll the initial packed small parameters
for
work_handle
,
src_rank
,
bucket_params
in
bucket_requests
:
work_handle
.
wait
()
if
src_rank
!=
self
.
rank
:
for
p
,
offset
,
end
in
bucket_params
:
p
.
data
.
copy_
(
buffers
[
src_rank
][
offset
:
end
].
view_as
(
p
.
data
))
# Unroll all the async work items,
wait for completion
# Unroll all the async work items,
just in case
_
=
list
(
map
(
lambda
x
:
x
.
wait
(),
direct_requests
))
stubs/torch/__init__.pyi
View file @
a31b08a5
...
...
@@ -320,6 +320,7 @@ class Tensor:
def coalesce(self) -> Tensor: ...
def conj(self) -> Tensor: ...
def contiguous(self) -> Tensor: ...
def copy_(self, other: Tensor) -> None: ...
def cos(self) -> Tensor: ...
def cos_(self) -> Tensor: ...
def cosh(self) -> Tensor: ...
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment