Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
fairscale
Commits
a31b08a5
Unverified
Commit
a31b08a5
authored
Oct 23, 2020
by
Benjamin Lefaudeux
Committed by
GitHub
Oct 23, 2020
Browse files
[refactor] OSS - broadcasts - getting rid of the while loop (#165)
* small refactor, getting rid of the while loop
parent
339cf060
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
45 additions
and
42 deletions
+45
-42
fairscale/optim/oss.py
fairscale/optim/oss.py
+44
-42
stubs/torch/__init__.pyi
stubs/torch/__init__.pyi
+1
-0
No files found.
fairscale/optim/oss.py
View file @
a31b08a5
...
@@ -418,58 +418,60 @@ class OSS(Optimizer):
...
@@ -418,58 +418,60 @@ class OSS(Optimizer):
return
global_rank
return
global_rank
def
_broadcast_params
(
self
,
buffers
:
List
[
torch
.
Tensor
],
per_rank_params
:
List
[
List
[
Parameter
]])
->
None
:
def
_broadcast_params
(
self
,
buffers
:
List
[
torch
.
Tensor
],
per_rank_params
:
List
[
List
[
Parameter
]])
->
None
:
"""Helper function to broadcast all the parameters from a given device
"""Helper function to broadcast all the parameters from a given device"""
"""
buffer_size
=
buffers
[
0
].
numel
()
buffer_size
=
buffers
[
0
].
numel
()
bucket_requests
=
[]
bucket_requests
=
[]
direct_requests
=
[]
direct_requests
=
[]
# Bucket and issue all the async calls
# Bucket and issue all the async calls
for
(
dst_rank
,
params
),
buffer
in
zip
(
enumerate
(
per_rank_params
),
buffers
):
for
(
src_rank
,
params
),
buffer
in
zip
(
enumerate
(
per_rank_params
),
buffers
):
# All the params are sorted per rank and per increasing size
global_src_rank
=
self
.
get_global_rank
(
self
.
group
,
src_rank
)
if
len
(
params
)
==
0
:
continue
global_dst_rank
=
OSS
.
get_global_rank
(
self
.
group
,
dst_rank
)
# Copy small parameters into per-GPU buffers
# Copy small parameters into per-GPU buffers and then async broadcast
i_bucketed
=
0
# the number of tensors packed in the buffer
offset
=
0
offset
=
0
bucket_sent
=
False
bucket_params
=
[]
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
# All the params are sorted per rank and per increasing size
while
i_bucketed
<
len
(
params
)
and
offset
+
params
[
i_bucketed
].
numel
()
<
buffer_size
:
for
p
in
params
:
end
=
offset
+
params
[
i_bucketed
].
numel
()
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
if
global_dst_rank
==
self
.
global_rank
:
if
not
bucket_sent
and
offset
+
p
.
numel
()
<
buffer_size
:
buffer
[
offset
:
end
].
copy_
(
params
[
i_bucketed
].
data
.
view
(
-
1
))
# type: ignore
end
=
offset
+
p
.
numel
()
offset
=
end
buffer
[
offset
:
end
].
copy_
(
p
.
data
.
view
(
-
1
))
i_bucketed
+=
1
bucket_params
.
append
((
p
,
offset
,
end
))
offset
=
end
if
i_bucketed
>
0
:
else
:
future
=
dist
.
broadcast
(
tensor
=
buffer
,
src
=
global_dst_rank
,
group
=
self
.
group
,
async_op
=
True
)
if
offset
>
0
and
not
bucket_sent
:
if
global_dst_rank
!=
self
.
global_rank
:
bucket_requests
.
append
(
# This request will need to be unrolled
(
bucket_requests
.
append
((
future
,
dst_rank
))
dist
.
broadcast
(
tensor
=
buffer
,
src
=
global_src_rank
,
group
=
self
.
group
,
async_op
=
True
),
src_rank
,
# Directly broadcast the rest
bucket_params
,
for
param
in
params
[
i_bucketed
:]:
)
direct_requests
.
append
(
)
dist
.
broadcast
(
tensor
=
param
.
data
,
src
=
global_dst_rank
,
group
=
self
.
group
,
async_op
=
True
),
bucket_sent
=
True
direct_requests
.
append
(
dist
.
broadcast
(
tensor
=
p
.
data
,
src
=
global_src_rank
,
group
=
self
.
group
,
async_op
=
True
)
)
# Catch a trailing bucket
if
not
bucket_sent
:
bucket_requests
.
append
(
(
dist
.
broadcast
(
tensor
=
buffer
,
src
=
global_src_rank
,
group
=
self
.
group
,
async_op
=
True
),
src_rank
,
bucket_params
,
)
)
)
# Unroll the initial packed small parameters
# Unroll the initial packed small parameters
for
gate
,
rank
in
bucket_requests
:
for
work_handle
,
src_rank
,
bucket_params
in
bucket_requests
:
gate
.
wait
()
work_handle
.
wait
()
if
src_rank
!=
self
.
rank
:
params
=
per_rank_params
[
rank
]
for
p
,
offset
,
end
in
bucket_params
:
buffer
=
buffers
[
rank
]
p
.
data
.
copy_
(
buffers
[
src_rank
][
offset
:
end
].
view_as
(
p
.
data
))
i_bucketed
=
0
# the number of tensors packed in the buffer
offset
=
0
while
i_bucketed
<
len
(
params
)
and
offset
+
params
[
i_bucketed
].
numel
()
<
buffer_size
:
end
=
offset
+
params
[
i_bucketed
].
numel
()
params
[
i_bucketed
].
data
.
copy_
(
buffer
[
offset
:
end
].
view_as
(
params
[
i_bucketed
]))
# type: ignore
offset
=
end
i_bucketed
+=
1
# Unroll all the async work items,
wait for completion
# Unroll all the async work items,
just in case
_
=
list
(
map
(
lambda
x
:
x
.
wait
(),
direct_requests
))
_
=
list
(
map
(
lambda
x
:
x
.
wait
(),
direct_requests
))
stubs/torch/__init__.pyi
View file @
a31b08a5
...
@@ -320,6 +320,7 @@ class Tensor:
...
@@ -320,6 +320,7 @@ class Tensor:
def coalesce(self) -> Tensor: ...
def coalesce(self) -> Tensor: ...
def conj(self) -> Tensor: ...
def conj(self) -> Tensor: ...
def contiguous(self) -> Tensor: ...
def contiguous(self) -> Tensor: ...
def copy_(self, other: Tensor) -> None: ...
def cos(self) -> Tensor: ...
def cos(self) -> Tensor: ...
def cos_(self) -> Tensor: ...
def cos_(self) -> Tensor: ...
def cosh(self) -> Tensor: ...
def cosh(self) -> Tensor: ...
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment