OpenDAS / Megatron-LM / Commits / a33e1b35

Commit a33e1b35, authored Sep 22, 2021 by rprenger

Fixing bug where temperature was never actually broadcast

Parent: 5ab64637
Showing 1 changed file with 11 additions and 10 deletions

megatron/text_generation_utils.py  +11 -10
@@ -108,13 +108,13 @@ def tokenize_batch(sentences, max_len, add_BOS):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
 
-def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs]
-    input_info_tensor = torch.cuda.LongTensor(input_info)
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs, temperature]
+    input_info_tensor = torch.cuda.FloatTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)
 
     # Send variables to all ranks
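The send side shows the core of the fix: the floating-point temperature now travels in the same metadata payload as the integer sizes, which is why the packed tensor switches from torch.cuda.LongTensor to torch.cuda.FloatTensor. Below is a minimal standalone sketch of that packing pattern, assuming torch.distributed is already initialized with a CUDA-capable backend; the function and variable names are illustrative, not the repository's.

import torch
import torch.distributed as dist

def broadcast_generation_params(batch_size, seq_len, tokens_to_generate,
                                all_probs, temperature, src_rank=0):
    # Mixed int/float metadata must share one dtype to travel in a single
    # broadcast; float32 represents these small integers exactly.
    payload = torch.cuda.FloatTensor(
        [batch_size, seq_len, tokens_to_generate, all_probs, temperature])
    dist.broadcast(payload, src_rank)
    return payload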
@@ -125,12 +125,13 @@ def receive_generate_info():
     """
     Needs to be synced up with send_generate_info
     """
-    input_info_tensor = torch.empty(4, dtype=torch.int64, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(5, dtype=torch.float32, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, 0)
-    batch_size = input_info_tensor[0].item()
-    seq_len = input_info_tensor[1].item()
-    tokens_to_generate = input_info_tensor[2].item()
-    all_probs = input_info_tensor[3].item()
+    batch_size = int(input_info_tensor[0].item())
+    seq_len = int(input_info_tensor[1].item())
+    tokens_to_generate = int(input_info_tensor[2].item())
+    all_probs = int(input_info_tensor[3].item())
+    temperature = float(input_info_tensor[4].item())
 
     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())
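On the receive side, every non-zero rank must allocate a buffer with the same shape and dtype before joining the collective, then cast each slot back to its native Python type. A hedged sketch of that pattern, under the same assumptions and with the same illustrative names as the sketch above:

import torch
import torch.distributed as dist

def receive_generation_params(src_rank=0):
    # The buffer must match the sender's shape and dtype exactly.
    payload = torch.empty(5, dtype=torch.float32,
                          device=torch.cuda.current_device())
    dist.broadcast(payload, src_rank)
    batch_size = int(payload[0].item())
    seq_len = int(payload[1].item())
    tokens_to_generate = int(payload[2].item())
    all_probs = int(payload[3].item())
    temperature = float(payload[4].item())
    return batch_size, seq_len, tokens_to_generate, all_probs, temperature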
@@ -139,7 +140,7 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)
 
-    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature
 
 def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     context_length = context_length_tensor.min().item()
@@ -182,7 +183,7 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
-        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
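Putting both halves together, the control flow in generate() follows the usual rank-0-producer pattern: rank 0 tokenizes the request and publishes the parameters, while every other rank blocks on the matching broadcast. A compact, self-contained sketch of that round trip follows; names are illustrative, and torch.distributed is assumed to be initialized with a CUDA device available.

import torch
import torch.distributed as dist

def sync_generation_params(tokens_to_generate, all_probs, temperature,
                           context_tokens_tensor=None):
    # Rank 0 owns the request and packs batch shape plus sampling
    # parameters into one float payload; all other ranks receive it.
    device = torch.cuda.current_device()
    if dist.get_rank() == 0:
        payload = torch.tensor(
            [context_tokens_tensor.size(0), context_tokens_tensor.size(1),
             tokens_to_generate, all_probs, temperature],
            dtype=torch.float32, device=device)
    else:
        payload = torch.empty(5, dtype=torch.float32, device=device)
    dist.broadcast(payload, 0)
    batch_size, seq_len, tokens_to_generate, all_probs = (
        int(v) for v in payload[:4].tolist())
    temperature = float(payload[4].item())
    return batch_size, seq_len, tokens_to_generate, all_probs, temperature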