OpenDAS / Megatron-LM · Commits · fa75238b

Commit fa75238b, authored Feb 18, 2022 by rprenger
Parent: f00d0a3f

Almost working beam search
Showing 4 changed files with 120 additions and 1 deletion (+120 -1):
megatron/text_generation/api.py          +39  -1
megatron/text_generation/generation.py   +69  -0
megatron/text_generation_server.py        +6  -0
tools/run_text_generation_server.py       +6  -0
megatron/text_generation/api.py
@@ -22,7 +22,8 @@ from megatron import mpu
 from .communication import broadcast_float_list
 from .generation import (
     generate_tokens_probs_and_return_on_first_stage,
-    score_and_return_on_first_stage)
+    score_and_return_on_first_stage,
+    beam_search_and_return_on_first_stage)
 from .tokenization import (
     tokenize_prompts,
     detokenize_generations)
@@ -138,3 +139,40 @@ def generate(model,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
        stop_on_double_eol=stop_on_double_eol,
        stop_on_eol=stop_on_eol)


def beam_search_and_post_process(model,
                                 prompts=None,
                                 tokens_to_generate=0,
                                 beam_size=0,
                                 add_BOS=False):
    """Run beam search and post-process outputs, i.e., detokenize,
    move to cpu and convert to list."""

    # Main inference.
    tokens, scores = beam_search(model,
                                 prompts=prompts,
                                 tokens_to_generate=tokens_to_generate,
                                 beam_size=beam_size,
                                 add_BOS=add_BOS)
    # Only post-process on first stage.
    if mpu.is_pipeline_first_stage():
        lengths = tokens.size(1) * torch.ones(beam_size, dtype=torch.int64,
                                              device=torch.cuda.current_device())
        tokens, prompts_plus_generations, prompts_plus_generations_segments = \
            detokenize_generations(tokens, lengths, True)
        return prompts_plus_generations, prompts_plus_generations_segments, tokens

    return None


def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False):
    # Make sure input params are available to all ranks.
    values = [tokens_to_generate, beam_size, add_BOS]
    values_float_tensor = broadcast_float_list(3, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    beam_size = int(values_float_tensor[1].item())
    add_BOS = bool(values_float_tensor[2].item())

    context_tokens_tensor, context_length_tensor = tokenize_prompts(
        prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

    return beam_search_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor, beam_size)
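
For orientation, here is a hedged usage sketch of the new entry point. The prompt text and sizes are made up, and it assumes model is an already-initialized Megatron GPT model with the call executed on every rank, since generation is collective across the pipeline:

# Illustrative usage only, not part of this commit.
result = beam_search_and_post_process(model,
                                      prompts=["Deep learning is"],
                                      tokens_to_generate=32,
                                      beam_size=4)
if result is not None:  # only the first pipeline stage gets post-processed output
    prompts_plus_generations, segments, tokens = result
    print(prompts_plus_generations[0])

Every other stage returns None, matching the early return above.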
megatron/text_generation/generation.py
@@ -200,6 +200,7 @@ def generate_tokens_probs_and_return_on_first_stage(
                                   top_p=top_p,
                                   temperature=temperature,
                                   vocab_size=tokenizer.vocab_size)

            # If a prompt length is smaller than or equal to the current context
            # length, it means we have started generating tokens.
            started = lengths <= context_length
@@ -281,6 +282,74 @@ def generate_tokens_probs_and_return_on_first_stage(
    return tokens, generated_sequence_lengths, output_log_probs


def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size):
    args = get_args()
    tokenizer = get_tokenizer()

    batch_size = tokens.size(0)
    assert(batch_size == 1)
    prompt_length = lengths.item()
    final_sequence_length = tokens.size(1)
    final_sequence_length = min(final_sequence_length, args.max_position_embeddings)

    # If the context is too big, this happens
    if prompt_length >= final_sequence_length:
        raise ValueError("context length + tokens_to_generate too large")

    # forward step.
    forward_step = ForwardStep(model, beam_size, final_sequence_length)

    if mpu.is_pipeline_last_stage():
        scores = torch.zeros(beam_size,
                             dtype=torch.float32,
                             device=torch.cuda.current_device()).unsqueeze(1)

    # =============
    # Run inference
    # =============
    with torch.no_grad():
        tokens = tokens.repeat(beam_size, 1)
        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
        prev_context_length = 0
        for context_length in range(prompt_length, final_sequence_length):

            # Pick the slice that we need to pass through the network.
            tokens2use = tokens[:, prev_context_length:context_length]
            positions2use = position_ids[:, prev_context_length:context_length]
            attention_mask2use = attention_mask[
                ..., prev_context_length:context_length, :context_length]

            # logits will be meaningful only in the last pipeline stage.
            logits = forward_step(tokens2use, positions2use, attention_mask2use)
            vocab_size = logits.size(2)

            if mpu.is_pipeline_last_stage():
                log_probs = F.log_softmax(logits, dim=2)
                new_scores = log_probs[:, -1, :] + scores

                if context_length == prompt_length:
                    # First step: all beams are still identical copies of the
                    # prompt, so only the first row needs to be sorted.
                    sorted_scores, indices = torch.sort(new_scores[0, :], descending=True)
                else:
                    sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)

                # Map the flat indices back to (source beam, token id) pairs.
                best_batches = torch.div(indices[:beam_size], vocab_size, rounding_mode='floor')
                best_words = indices[:beam_size] % vocab_size

                tokens = tokens[best_batches, :]
                tokens[:, context_length] = best_words
                scores = sorted_scores[:beam_size].unsqueeze(1)

            # Update the tokens on the first stage so the next input to
            # the network is correct.
            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
                                                   tokens[:, context_length])

            # Update the context length for the next token generation.
            prev_context_length = context_length

        copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
                                               scores[:, 0])

    return tokens, scores


def _build_attention_mask_and_position_ids(tokens):
    """Build the attention mask and position ids for the input tokens."""
    ...
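
To make the beam update easier to follow in isolation, here is a minimal, self-contained sketch of the selection step on dummy tensors (beam_size, vocab_size and the random logits are illustrative assumptions, not Megatron code):

# Standalone sketch of one beam-search update step.
import torch
import torch.nn.functional as F

beam_size, vocab_size = 4, 10
scores = torch.zeros(beam_size, 1)                    # running log-prob per beam
logits = torch.randn(beam_size, vocab_size)           # fake last-position logits

# Add the new token log-probs to the running beam scores.
new_scores = F.log_softmax(logits, dim=-1) + scores   # (beam_size, vocab_size)
# Flatten and sort so the best (beam, token) pairs come first.
sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
# Recover which source beam and which token each flat index refers to.
best_beams = torch.div(indices[:beam_size], vocab_size, rounding_mode='floor')
best_words = indices[:beam_size] % vocab_size
scores = sorted_scores[:beam_size].unsqueeze(1)       # scores for the next step
print(best_beams, best_words)

On the very first step the function above sorts only new_scores[0, :], because all beams are still identical copies of the prompt and a flat sort over ties would make every beam pick the same token.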
megatron/text_generation_server.py
@@ -128,6 +128,12 @@ class MegatronGenerate(Resource):
        if not isinstance(no_log, bool):
            return "no_log must be a boolean value"

        beam_search = False
        if "beam_search" in request.get_json():
            beam_search = request.get_json()["beam_search"]
            if not isinstance(beam_search, bool):
                return "beam_search must be a boolean value"

        with lock:  # Need to get lock to keep multiple threads from hitting code
            if not no_log:
                print("request IP: " + str(request.remote_addr))
            ...
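
For completeness, a hedged example of how a client might set the new flag; the endpoint path and port below are assumptions based on how this Flask server is typically queried, not something shown in this diff:

# Hypothetical client call; "beam_search" is the only field this commit adds.
import requests

payload = {"prompts": ["Deep learning is"],
           "tokens_to_generate": 32,
           "beam_search": True}
resp = requests.put("http://localhost:5000/api",
                    json=payload,
                    headers={"Content-Type": "application/json"})
print(resp.json())

If "beam_search" is present but not a boolean, the handler returns the error string instead of running generation.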
tools/run_text_generation_server.py
@@ -28,6 +28,7 @@ from megatron.model import GPTModel
from megatron.training import get_model
from megatron.text_generation_server import MegatronServer
from megatron.text_generation import generate_and_post_process
from megatron.text_generation import beam_search_and_post_process
import torch


def model_provider(pre_process=True, post_process=True):
    ...
@@ -82,3 +83,8 @@ if __name__ == "__main__":
                generate_and_post_process(model)
            except ValueError as ve:
                pass
        elif choice[0].item() == 1:
            try:
                beam_search_and_post_process(model)
            except ValueError as ve:
                pass
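
The elif above only covers the worker side of the dispatch. A hedged sketch of the rank-0 counterpart it relies on (the helper name is an assumption, not shown in this diff) would broadcast the choice code 1 before entering the same collective call:

# Assumed rank-0 side: signal the other ranks, then run the same collective call.
import torch
import torch.distributed as dist

def send_do_beam_search(model):
    choice = torch.cuda.LongTensor([1])   # 1 selects the beam-search branch
    dist.broadcast(choice, 0)
    return beam_search_and_post_process(model)

Without that broadcast, the non-zero ranks would never enter this branch and the pipeline-parallel communication inside beam search would not line up.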