Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
a31b08a5
Unverified
Commit
a31b08a5
authored
Oct 23, 2020
by
Benjamin Lefaudeux
Committed by
GitHub
Oct 23, 2020
Browse files
[refactor] OSS - broadcasts - getting rid of the while loop (#165)
* small refactor, getting rid of the while loop
parent
339cf060
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
42 deletions
+45
-42
fairscale/optim/oss.py
fairscale/optim/oss.py
+44
-42
stubs/torch/__init__.pyi
stubs/torch/__init__.pyi
+1
-0
No files found.
fairscale/optim/oss.py
View file @
a31b08a5
...
@@ -418,58 +418,60 @@ class OSS(Optimizer):
...
@@ -418,58 +418,60 @@ class OSS(Optimizer):
return
global_rank
return
global_rank
def
_broadcast_params
(
self
,
buffers
:
List
[
torch
.
Tensor
],
per_rank_params
:
List
[
List
[
Parameter
]])
->
None
:
def
_broadcast_params
(
self
,
buffers
:
List
[
torch
.
Tensor
],
per_rank_params
:
List
[
List
[
Parameter
]])
->
None
:
"""Helper function to broadcast all the parameters from a given device
"""Helper function to broadcast all the parameters from a given device"""
"""
buffer_size
=
buffers
[
0
].
numel
()
buffer_size
=
buffers
[
0
].
numel
()
bucket_requests
=
[]
bucket_requests
=
[]
direct_requests
=
[]
direct_requests
=
[]
# Bucket and issue all the async calls
# Bucket and issue all the async calls
for
(
dst_rank
,
params
),
buffer
in
zip
(
enumerate
(
per_rank_params
),
buffers
):
for
(
src_rank
,
params
),
buffer
in
zip
(
enumerate
(
per_rank_params
),
buffers
):
# All the params are sorted per rank and per increasing size
global_src_rank
=
self
.
get_global_rank
(
self
.
group
,
src_rank
)
if
len
(
params
)
==
0
:
continue
global_dst_rank
=
OSS
.
get_global_rank
(
self
.
group
,
dst_rank
)
# Copy small parameters into per-GPU buffers
# Copy small parameters into per-GPU buffers and then async broadcast
i_bucketed
=
0
# the number of tensors packed in the buffer
offset
=
0
offset
=
0
bucket_sent
=
False
bucket_params
=
[]
# All the params are sorted per rank and per increasing size
for
p
in
params
:
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
while
i_bucketed
<
len
(
params
)
and
offset
+
params
[
i_bucketed
]
.
numel
()
<
buffer_size
:
if
not
bucket_sent
and
offset
+
p
.
numel
()
<
buffer_size
:
end
=
offset
+
p
arams
[
i_bucketed
]
.
numel
()
end
=
offset
+
p
.
numel
()
if
global_dst_rank
==
self
.
global_rank
:
buffer
[
offset
:
end
].
copy_
(
p
.
data
.
view
(
-
1
))
bu
ffer
[
offset
:
end
].
copy_
(
params
[
i_bucketed
].
data
.
view
(
-
1
))
# type: ignore
bu
cket_params
.
append
((
p
,
offset
,
end
))
offset
=
end
offset
=
end
i_bucketed
+=
1
else
:
if
offset
>
0
and
not
bucket_sent
:
bucket_requests
.
append
(
(
dist
.
broadcast
(
tensor
=
buffer
,
src
=
global_src_rank
,
group
=
self
.
group
,
async_op
=
True
),
src_rank
,
bucket_params
,
)
)
if
i_bucketed
>
0
:
bucket_sent
=
True
future
=
dist
.
broadcast
(
tensor
=
buffer
,
src
=
global_dst_rank
,
group
=
self
.
group
,
async_op
=
True
)
if
global_dst_rank
!=
self
.
global_rank
:
# This request will need to be unrolled
bucket_requests
.
append
((
future
,
dst_rank
))
# Directly broadcast the rest
for
param
in
params
[
i_bucketed
:]:
direct_requests
.
append
(
direct_requests
.
append
(
dist
.
broadcast
(
tensor
=
p
aram
.
data
,
src
=
global_
dst
_rank
,
group
=
self
.
group
,
async_op
=
True
)
,
dist
.
broadcast
(
tensor
=
p
.
data
,
src
=
global_
src
_rank
,
group
=
self
.
group
,
async_op
=
True
)
)
)
# Unroll the initial packed small parameters
# Catch a trailing bucket
for
gate
,
rank
in
bucket_requests
:
if
not
bucket_sent
:
gate
.
wait
()
bucket_requests
.
append
(
(
params
=
per_rank_params
[
rank
]
dist
.
broadcast
(
tensor
=
buffer
,
src
=
global_src_rank
,
group
=
self
.
group
,
async_op
=
True
),
buffer
=
buffers
[
rank
]
src_rank
,
i_bucketed
=
0
# the number of tensors packed in the buffer
bucket_params
,
offset
=
0
)
)
while
i_bucketed
<
len
(
params
)
and
offset
+
params
[
i_bucketed
].
numel
()
<
buffer_size
:
# Unroll the initial packed small parameters
end
=
offset
+
params
[
i_bucketed
].
numel
()
for
work_handle
,
src_rank
,
bucket_params
in
bucket_requests
:
params
[
i_bucketed
].
data
.
copy_
(
buffer
[
offset
:
end
].
view_as
(
params
[
i_bucketed
]))
# type: ignore
work_handle
.
wait
()
offset
=
end
if
src_rank
!=
self
.
rank
:
i_bucketed
+=
1
for
p
,
offset
,
end
in
bucket_params
:
p
.
data
.
copy_
(
buffers
[
src_rank
][
offset
:
end
].
view_as
(
p
.
data
))
# Unroll all the async work items,
wait for completion
# Unroll all the async work items,
just in case
_
=
list
(
map
(
lambda
x
:
x
.
wait
(),
direct_requests
))
_
=
list
(
map
(
lambda
x
:
x
.
wait
(),
direct_requests
))
stubs/torch/__init__.pyi
View file @
a31b08a5
...
@@ -320,6 +320,7 @@ class Tensor:
...
@@ -320,6 +320,7 @@ class Tensor:
def coalesce(self) -> Tensor: ...
def coalesce(self) -> Tensor: ...
def conj(self) -> Tensor: ...
def conj(self) -> Tensor: ...
def contiguous(self) -> Tensor: ...
def contiguous(self) -> Tensor: ...
def copy_(self, other: Tensor) -> None: ...
def cos(self) -> Tensor: ...
def cos(self) -> Tensor: ...
def cos_(self) -> Tensor: ...
def cos_(self) -> Tensor: ...
def cosh(self) -> Tensor: ...
def cosh(self) -> Tensor: ...
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment