Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
e722c4a9
Commit
e722c4a9
authored
Sep 23, 2021
by
mshoeybi
Browse files
tested and woking
parent
107c29e8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
5 deletions
+18
-5
megatron/inference/communication.py
megatron/inference/communication.py
+18
-5
No files found.
megatron/inference/communication.py
View file @
e722c4a9
...
...
@@ -38,12 +38,25 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
return
tensor
def
broadcast_list
(
size
,
dtype
,
list_values
=
None
,
rank
=
0
):
"""Broadcast a list of values with a given type."""
tensor
=
None
if
torch
.
distributed
.
get_rank
()
==
rank
:
tensor
=
torch
.
tensor
(
list_values
,
dtype
=
dtype
,
device
=
torch
.
cuda
.
current_device
())
return
broadcast_tensor
(
size
,
dtype
,
tensor
=
tensor
,
rank
=
rank
)
def
broadcast_int_list
(
size
,
int_list
=
None
,
rank
=
0
):
"""Broadcast a list of interger values."""
long_tensor
=
None
if
torch
.
distributed
.
get_rank
()
==
rank
:
long_tensor
=
torch
.
tensor
(
int_list
,
dtype
=
torch
.
int64
,
device
=
torch
.
cuda
.
current_device
())
return
broadcast_list
(
size
,
torch
.
int64
,
list_values
=
int_list
,
rank
=
rank
)
def
broadcast_float_list
(
size
,
float_list
=
None
,
rank
=
0
):
"""Broadcast a list of float values."""
return
broadcast_tensor
(
size
,
torch
.
int64
,
tensor
=
long_tensor
,
rank
=
rank
)
return
broadcast_list
(
size
,
torch
.
float32
,
list_values
=
float_list
,
rank
=
rank
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment