OpenDAS / fairscale / Commits / a31b08a5

Unverified commit a31b08a5, authored Oct 23, 2020 by Benjamin Lefaudeux, committed by GitHub on Oct 23, 2020.

[refactor] OSS - broadcasts - getting rid of the while loop (#165)

* small refactor, getting rid of the while loop

Parent: 339cf060
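The refactor replaces the index-driven while loop that packed small parameters into the flat broadcast buffer with a single for loop and a sent flag. A minimal sketch of that control-flow idea, assuming a list of tensors sorted by increasing size; this is an illustration only, not the fairscale code, and pack_small_params is a hypothetical helper:

    from typing import List, Tuple

    import torch


    def pack_small_params(
        params: List[torch.Tensor], buffer: torch.Tensor
    ) -> Tuple[List[Tuple[int, int]], List[torch.Tensor]]:
        """Pack the leading (smallest) tensors into `buffer`; return the (offset, end)
        slot of each packed tensor and the tensors left over for direct broadcast."""
        offset = 0
        slots: List[Tuple[int, int]] = []
        leftovers: List[torch.Tensor] = []
        bucket_full = False

        for p in params:  # assumed sorted by increasing numel()
            if not bucket_full and offset + p.numel() <= buffer.numel():
                end = offset + p.numel()
                buffer[offset:end].copy_(p.view(-1))  # flatten into the bucket
                slots.append((offset, end))
                offset = end
            else:
                bucket_full = True  # once a tensor no longer fits, stop packing
                leftovers.append(p)

        return slots, leftovers

The single pass removes the separate bookkeeping index and makes the transition to "bucket is full, everything else is broadcast directly" explicit, which is the shape of the diff below.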
Changes: 2 changed files with 45 additions and 42 deletions.

  fairscale/optim/oss.py    +44 -42
  stubs/torch/__init__.pyi   +1  -0
fairscale/optim/oss.py  (view file @ a31b08a5)

@@ -418,58 +418,60 @@ class OSS(Optimizer):
         return global_rank

     def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
-        """Helper function to broadcast all the parameters from a given device
-        """
+        """Helper function to broadcast all the parameters from a given device"""
         buffer_size = buffers[0].numel()
         bucket_requests = []
         direct_requests = []

         # Bucket and issue all the async calls
-        for (dst_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
-            # All the params are sorted per rank and per increasing size
-            if len(params) == 0:
-                continue
-
-            global_dst_rank = OSS.get_global_rank(self.group, dst_rank)
-
-            # Copy small parameters into per-GPU buffers
-            i_bucketed = 0  # the number of tensors packed in the buffer
+        for (src_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
+            global_src_rank = self.get_global_rank(self.group, src_rank)
+
+            # Copy small parameters into per-GPU buffers and then async broadcast
             offset = 0
+            bucket_sent = False
+            bucket_params = []

-            # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
-            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
-                end = offset + params[i_bucketed].numel()
-                if global_dst_rank == self.global_rank:
-                    buffer[offset:end].copy_(params[i_bucketed].data.view(-1))  # type: ignore
-                offset = end
-                i_bucketed += 1
-
-            if i_bucketed > 0:
-                future = dist.broadcast(tensor=buffer, src=global_dst_rank, group=self.group, async_op=True)
-                if global_dst_rank != self.global_rank:
-                    # This request will need to be unrolled
-                    bucket_requests.append((future, dst_rank))
-
-            # Directly broadcast the rest
-            for param in params[i_bucketed:]:
-                direct_requests.append(
-                    dist.broadcast(tensor=param.data, src=global_dst_rank, group=self.group, async_op=True)
-                )
+            # All the params are sorted per rank and per increasing size
+            for p in params:
+                # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
+                if not bucket_sent and offset + p.numel() < buffer_size:
+                    end = offset + p.numel()
+                    if global_src_rank == self.global_rank:
+                        buffer[offset:end].copy_(p.data.view(-1))
+                    bucket_params.append((p, offset, end))
+                    offset = end
+                else:
+                    if offset > 0 and not bucket_sent:
+                        bucket_requests.append(
+                            (
+                                dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
+                                src_rank,
+                                bucket_params,
+                            )
+                        )
+                        bucket_sent = True
+
+                    direct_requests.append(
+                        dist.broadcast(tensor=p.data, src=global_src_rank, group=self.group, async_op=True)
+                    )
+
+            # Catch a trailing bucket
+            if not bucket_sent:
+                bucket_requests.append(
+                    (
+                        dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
+                        src_rank,
+                        bucket_params,
+                    )
+                )

         # Unroll the initial packed small parameters
-        for gate, rank in bucket_requests:
-            gate.wait()
-
-            params = per_rank_params[rank]
-            buffer = buffers[rank]
-            i_bucketed = 0  # the number of tensors packed in the buffer
-            offset = 0
-
-            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
-                end = offset + params[i_bucketed].numel()
-                params[i_bucketed].data.copy_(buffer[offset:end].view_as(params[i_bucketed]))  # type: ignore
-                offset = end
-                i_bucketed += 1
-
-        # Unroll all the async work items, just in case
+        for work_handle, src_rank, bucket_params in bucket_requests:
+            work_handle.wait()
+
+            if src_rank != self.rank:
+                for p, offset, end in bucket_params:
+                    p.data.copy_(buffers[src_rank][offset:end].view_as(p.data))
+
+        # Unroll all the async work items, wait for completion
         _ = list(map(lambda x: x.wait(), direct_requests))
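For context, the bucketed path boils down to one async broadcast of the flat buffer per source rank, followed by an unpack on the receiving ranks. A hedged sketch of that consumption side, assuming an initialized torch.distributed process group and a slots layout like the one produced by the packing helper sketched earlier; broadcast_and_unpack is a hypothetical function, not part of the fairscale API:

    import torch
    import torch.distributed as dist


    def broadcast_and_unpack(
        buffer: torch.Tensor, params: list, slots: list, src: int
    ) -> None:
        # Kick off the broadcast without blocking, mirroring async_op=True in the diff.
        handle = dist.broadcast(tensor=buffer, src=src, async_op=True)
        handle.wait()  # block until the buffer holds the source rank's values

        if dist.get_rank() != src:
            # Receivers copy each slice back into its parameter; the source rank
            # already owns the up-to-date values and skips the copy.
            for p, (offset, end) in zip(params, slots):
                p.data.copy_(buffer[offset:end].view_as(p))

Waiting on the direct (non-bucketed) requests as well, as the diff does with list(map(lambda x: x.wait(), direct_requests)), guarantees every parameter is in sync before returning.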
stubs/torch/__init__.pyi  (view file @ a31b08a5)

@@ -320,6 +320,7 @@ class Tensor:
     def coalesce(self) -> Tensor: ...
     def conj(self) -> Tensor: ...
     def contiguous(self) -> Tensor: ...
+    def copy_(self, other: Tensor) -> None: ...
     def cos(self) -> Tensor: ...
     def cos_(self) -> Tensor: ...
     def cosh(self) -> Tensor: ...
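The one-line stub addition lets mypy accept the in-place copy_ calls used above; torch.Tensor.copy_ copies the elements of its argument into the tensor it is called on. A quick usage check, for illustration only:

    import torch

    dst = torch.zeros(4)
    src = torch.arange(4, dtype=torch.float32)
    dst.copy_(src)  # in-place, element-wise copy of src into dst
    assert torch.equal(dst, src)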