OpenDAS / Megatron-LM · Commits

Commit 21d2b0fe
Authored Oct 15, 2021 by rprenger
Parent: 8d405805

Allowing for a 0 tokens/just scoring mode
Showing 3 changed files with 25 additions and 14 deletions (+25 -14)

  megatron/text_generation/api.py           +12  -7
  megatron/text_generation/generation.py     +6  -4
  megatron/text_generation_server.py         +7  -3
megatron/text_generation/api.py  (+12 -7)

@@ -37,7 +37,8 @@ def generate_and_post_process(model,
                               top_p_sampling=0.0,
                               temperature=1.0,
                               add_BOS=False,
-                              use_eod_token_for_early_termination=True):
+                              use_eod_token_for_early_termination=True,
+                              just_score=False):
     """Run inference and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""

@@ -53,7 +54,8 @@ def generate_and_post_process(model,
         top_p_sampling=top_p_sampling,
         temperature=temperature,
         add_BOS=add_BOS,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
+        just_score=just_score)
 
     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():

@@ -83,7 +85,8 @@ def generate(model,
              top_p_sampling=0.0,
              temperature=1.0,
              add_BOS=False,
-             use_eod_token_for_early_termination=True):
+             use_eod_token_for_early_termination=True,
+             just_score=False):
     """Given prompts and input parameters, run inference and return:
        tokens: prompts plus the generated tokens.
        lengths: length of the prompt + generations. Note that we can

@@ -97,8 +100,8 @@ def generate(model,
     values = [tokens_to_generate,
               return_output_log_probs, return_all_log_probs,
               greedy_sampling, top_k_sampling, top_p_sampling,
-              temperature, add_BOS, use_eod_token_for_early_termination]
-    values_float_tensor = broadcast_float_list(9, float_list=values)
+              temperature, add_BOS, use_eod_token_for_early_termination, just_score]
+    values_float_tensor = broadcast_float_list(10, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
     return_all_log_probs = bool(values_float_tensor[2].item())

@@ -108,12 +111,13 @@ def generate(model,
     temperature = values_float_tensor[6].item()
     add_BOS = bool(values_float_tensor[7].item())
     use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
+    just_score = bool(values_float_tensor[9].item())
 
     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcaseted to all ranks.
     if torch.distributed.get_rank() == 0:
         assert prompts is not None
-        assert tokens_to_generate > 0
+        # assert tokens_to_generate > 0
 
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

@@ -125,4 +129,5 @@ def generate(model,
         return_all_log_probs=return_all_log_probs,
         greedy=greedy_sampling, top_k=top_k_sampling,
         top_p=top_p_sampling, temperature=temperature,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
+        just_score=just_score)
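
Note on the api.py change: generate() runs on every rank, but only rank 0 sees the caller's arguments, so the scalar options are packed into a float list, broadcast, and unpacked with int()/bool() on the other ranks. Appending just_score to that list is why the broadcast length grows from 9 to 10 and why it is read back as bool(values_float_tensor[9].item()). broadcast_float_list itself is not part of this diff; a minimal sketch of the pattern it implements, assuming a standard torch.distributed setup (the helper's real signature may differ), could look like:

    import torch

    def broadcast_float_list(size, float_list=None, src_rank=0):
        # Sketch only: the source rank packs the Python scalars into a CUDA
        # float tensor; every other rank allocates an empty tensor of the
        # same size, and the broadcast fills it in.
        if torch.distributed.get_rank() == src_rank:
            assert len(float_list) == size
            tensor = torch.tensor(float_list, dtype=torch.float32,
                                  device=torch.cuda.current_device())
        else:
            tensor = torch.empty(size, dtype=torch.float32,
                                 device=torch.cuda.current_device())
        torch.distributed.broadcast(tensor, src_rank)
        return tensor

Because every option has to survive the round trip through float32, boolean flags such as just_score are re-materialized on the receiving side with bool(tensor[i].item()).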
megatron/text_generation/generation.py  (+6 -4)

@@ -34,7 +34,8 @@ def generate_tokens_probs_and_return_on_first_stage(
         return_all_log_probs=False,
         greedy=False, top_k=0, top_p=0.0,
         temperature=1.0,
-        use_eod_token_for_early_termination=True):
+        use_eod_token_for_early_termination=True,
+        just_score=False):
     """Main token generation function.
 
     Arguments:
         model: no interleaving is supported.

@@ -107,8 +108,9 @@ def generate_tokens_probs_and_return_on_first_stage(
                                    dtype=torch.float32,
                                    device=torch.cuda.current_device())
-    generated_sequence_lengths = torch.ones(batch_size, dtype=torch.int64,
-        device=torch.cuda.current_device()) * max_sequence_length
+    generated_sequence_lengths = torch.ones(
+        batch_size, dtype=torch.int64,
+        device=torch.cuda.current_device()) * max_sequence_length
 
     # Whether we have reached a termination id.
     is_generation_done = torch.zeros(batch_size, dtype=torch.uint8,
                                      device=torch.cuda.current_device())

@@ -207,7 +209,7 @@ def generate_tokens_probs_and_return_on_first_stage(
     tokens = tokens[:, :(context_length + 1)]
     if mpu.is_pipeline_last_stage():
         if return_output_log_probs:
-            output_log_probs = output_log_probs[:, :context_length]
+            output_log_probs = output_log_probs[:, :context_length].contiguous()
         if return_all_log_probs:
             all_log_probs = all_log_probs[:, :context_length, :]
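
Note on the .contiguous() call added in generation.py: slicing output_log_probs along its second dimension returns a view that keeps the original row stride, so the result is no longer contiguous in memory, and downstream consumers that require contiguous storage (view-style reshapes, some communication ops) can fail on it. A quick illustration of that PyTorch behaviour, with made-up shapes rather than the ones used by Megatron:

    import torch

    # (batch, sequence) tensor of log probs; slicing dim 1 keeps the original
    # row stride of 128, so the sliced view is not contiguous.
    log_probs = torch.randn(4, 128)
    sliced = log_probs[:, :16]
    print(sliced.is_contiguous())               # False
    print(sliced.contiguous().is_contiguous())  # True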
megatron/text_generation_server.py  (+7 -3)

@@ -54,12 +54,15 @@ class MegatronGenerate(Resource):
             return "Maximum number of prompts is 128", 400
 
         tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
+        just_score = False
         if "tokens_to_generate" in request.get_json():
             tokens_to_generate = request.get_json()["tokens_to_generate"]
             if not isinstance(tokens_to_generate, int):
                 return "tokens_to_generate must be an integer greater than 0"
-            if tokens_to_generate < 1:
-                return "tokens_to_generate must be an integer greater than 0"
+            if tokens_to_generate < 0:
+                return "tokens_to_generate must be an integer greater than or equal to 0"
+        if tokens_to_generate == 0:
+            just_score = True
 
         logprobs = False
         if "logprobs" in request.get_json():

@@ -113,7 +116,8 @@ class MegatronGenerate(Resource):
             top_p_sampling=top_p,
             temperature=temperature,
             add_BOS=add_BOS,
-            use_eod_token_for_early_termination=True)
+            use_eod_token_for_early_termination=True,
+            just_score=just_score)
 
         return jsonify({"text": response,
                         "segments": response_seg,
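
Note on the server change: a request with tokens_to_generate set to 0 is now accepted and flips just_score on, so the model only scores the supplied prompts instead of generating. A hedged client example, assuming the resource is served at http://localhost:5000/api over PUT (the route, port, and HTTP verb are not shown in this diff) and that the body carries "prompts" and "logprobs" fields as the surrounding handler code suggests:

    import requests

    # Endpoint is an assumption; adjust to wherever MegatronGenerate is mounted.
    URL = "http://localhost:5000/api"

    payload = {
        "prompts": ["Megatron-LM scales transformer training across many GPUs."],
        "tokens_to_generate": 0,   # 0 now means "just score the prompt"
        "logprobs": True,
    }

    # Use POST instead of PUT if the server registers the resource differently.
    response = requests.put(URL, json=payload)
    print(response.json())         # expected keys include "text" and "segments"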