OpenDAS / Megatron-LM

Commit bb618c02, authored May 25, 2022 by Jared Casper

Merge branch 'beam_search' into 'main'

Beam search

See merge request ADLR/megatron-lm!396

Parents: d898a899, da11c982
Showing 7 changed files with 309 additions and 10 deletions.
megatron/text_generation/__init__.py      +2   -1
megatron/text_generation/api.py           +53  -1
megatron/text_generation/beam_utils.py    +64  -0
megatron/text_generation/forward_step.py  +12  -1
megatron/text_generation/generation.py    +115 -1
megatron/text_generation_server.py        +57  -6
tools/run_text_generation_server.py       +6   -0
megatron/text_generation/__init__.py

@@ -16,4 +16,5 @@
 from .api import (
     generate,
-    generate_and_post_process)
+    generate_and_post_process,
+    beam_search_and_post_process)
megatron/text_generation/api.py

@@ -22,7 +22,8 @@ from megatron import mpu
 from .communication import broadcast_float_list
 from .generation import (
     generate_tokens_probs_and_return_on_first_stage,
-    score_and_return_on_first_stage)
+    score_and_return_on_first_stage,
+    beam_search_and_return_on_first_stage)
 from .tokenization import (
     tokenize_prompts,
     detokenize_generations)

@@ -138,3 +139,54 @@ def generate(model,
         use_eod_token_for_early_termination=use_eod_token_for_early_termination,
         stop_on_double_eol=stop_on_double_eol,
         stop_on_eol=stop_on_eol)
+
+def beam_search_and_post_process(model,
+                                 prompts=None,
+                                 tokens_to_generate=0,
+                                 beam_size=0,
+                                 add_BOS=False,
+                                 stop_token=50256,
+                                 num_return_gen=1,
+                                 length_penalty=1):
+    """Run beam search and post-process outputs, i.e., detokenize,
+    move to cpu and convert to list."""
+
+    # Main inference.
+    tokens, scores = beam_search(model,
+                                 prompts=prompts,
+                                 tokens_to_generate=tokens_to_generate,
+                                 beam_size=beam_size,
+                                 add_BOS=add_BOS,
+                                 stop_token=stop_token,
+                                 num_return_gen=num_return_gen,
+                                 length_penalty=length_penalty)
+
+    # Only post-process on first stage.
+    if mpu.is_pipeline_first_stage():
+        lengths = tokens.size(1) * torch.ones(beam_size, dtype=torch.int64,
+                                              device=torch.cuda.current_device())
+        tokens, prompts_plus_generations, prompts_plus_generations_segments = \
+            detokenize_generations(tokens, lengths, True)
+        scores = scores.cpu().numpy().tolist()
+        return prompts_plus_generations, prompts_plus_generations_segments, scores
+
+    return None
+
+def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0,
+                add_BOS=False, stop_token=50256, num_return_gen=1,
+                length_penalty=1):
+    # Make sure input params are available to all ranks.
+    values = [tokens_to_generate, beam_size, add_BOS, stop_token,
+              num_return_gen, length_penalty]
+    values_float_tensor = broadcast_float_list(6, float_list=values)
+    tokens_to_generate = int(values_float_tensor[0].item())
+    beam_size = int(values_float_tensor[1].item())
+    add_BOS = bool(values_float_tensor[2].item())
+    stop_token = int(values_float_tensor[3].item())
+    num_return_gen = int(values_float_tensor[4].item())
+    length_penalty = values_float_tensor[5].item()
+
+    context_tokens_tensor, context_length_tensor = tokenize_prompts(
+        prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
+
+    return beam_search_and_return_on_first_stage(
+        model, context_tokens_tensor, context_length_tensor, beam_size,
+        stop_token=stop_token, num_return_gen=num_return_gen,
+        length_penalty=length_penalty)
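For orientation only (not part of the diff), a driver-side call to the new API might look like the sketch below. It assumes `model` has already been built and checkpoint-loaded through Megatron's usual initialization, which is not shown, and it only yields results on the first pipeline stage.

# Hypothetical usage sketch; `model` comes from Megatron's standard setup
# (initialize_megatron + get_model + load_checkpoint), omitted here.
from megatron.text_generation import beam_search_and_post_process

result = beam_search_and_post_process(
    model,
    prompts=["The capital of France is"],   # beam search expects batch size 1
    tokens_to_generate=16,
    beam_size=4,
    stop_token=50256,                       # GPT-2 BPE <|endoftext|>
    num_return_gen=4,                       # return the whole beam
    length_penalty=1.0)

# (texts, segments, scores) on the first pipeline stage, None elsewhere.
if result is not None:
    texts, segments, scores = result
    print(texts[0], scores[0])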
megatron/text_generation/beam_utils.py (new file, mode 100644)

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## from huggingface beam search
class BeamHypotheses(object):
    def __init__(self, num_beams, length_penalty=1.0, early_stopping=False):
        """
        Initialize n-best list of hypotheses.
        """
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp, sum_logprobs, length):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / length ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len):
        """
        If there are enough hypotheses and none of the hypotheses being generated
        can become better than the worst one in the heap, then we are done with this sentence.
        """
        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret
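For illustration only (not part of the commit), the sketch below exercises this container in isolation: it keeps at most num_beams hypotheses ranked by length-normalized log-probability and discards the current worst entry when a better one arrives. The token lists and scores are made-up placeholders, and the import assumes the Megatron-LM repository is on PYTHONPATH.

# Hypothetical standalone usage of BeamHypotheses (toy numbers).
from megatron.text_generation.beam_utils import BeamHypotheses

beam_hyp = BeamHypotheses(num_beams=2, length_penalty=1.0)

# Pretend these are finished beams: (token ids, summed log-prob, length).
candidates = [
    ([5, 11, 2],    -4.5, 3),   # normalized score -1.5
    ([5, 11, 7, 2], -4.0, 4),   # normalized score -1.0
    ([5, 9, 2],     -9.0, 3),   # normalized score -3.0, worse than the worst kept
]
for hyp, sum_logprobs, length in candidates:
    beam_hyp.add(hyp, sum_logprobs, length)

print(len(beam_hyp))                          # 2: only the two best are kept
print(sorted(beam_hyp.beams, reverse=True))   # [(-1.0, [5, 11, 7, 2]), (-1.5, [5, 11, 2])]
print(beam_hyp.is_done(best_sum_logprobs=-8.0, cur_len=4))  # True: -2.0 is below the worst kept (-1.5)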
megatron/text_generation/forward_step.py

@@ -42,7 +42,18 @@ class InferenceParams:
         self.batch_size_offset = 0
         self.key_value_memory_dict = {}

+    def swap_key_value_dict(self, batch_idx):
+        "swap between batches"
+        if len(self.key_value_memory_dict) == 0:
+            raise ValueError("should not swap when dict is empty")
+
+        for layer_number in self.key_value_memory_dict.keys():
+            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
+            assert len(batch_idx) == inference_key_memory.shape[1]  ## make sure batch size is the same
+            new_inference_key_memory = inference_key_memory[:, batch_idx]
+            new_inference_value_memory = inference_value_memory[:, batch_idx]
+            self.key_value_memory_dict[layer_number] = (
+                new_inference_key_memory, new_inference_value_memory)


 class ForwardStep:
     """Forward step function with all the communications.
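To make the reordering concrete (illustration, not part of the diff): the new method reindexes every cached key/value tensor along its batch (beam) dimension, dim 1, so the KV cache follows the beams selected at the current step. A minimal sketch with dummy shapes:

# Minimal sketch of the cache reindexing, with dummy tensors shaped
# [sequence_length, batch (= beam) size, num_heads, head_dim].
import torch

key = torch.arange(2 * 3 * 1 * 4, dtype=torch.float32).reshape(2, 3, 1, 4)
value = key.clone()
kv_dict = {1: (key, value)}          # layer_number -> (key cache, value cache)

batch_idx = torch.tensor([2, 0, 0])  # beam chosen for each slot at this step

for layer_number, (k, v) in kv_dict.items():
    assert len(batch_idx) == k.shape[1]   # beam count must match the cache
    kv_dict[layer_number] = (k[:, batch_idx], v[:, batch_idx])

# Slot 0 now holds what was beam 2's cache; slots 1 and 2 both copy beam 0's.
print(kv_dict[1][0].shape)  # torch.Size([2, 3, 1, 4])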
megatron/text_generation/generation.py

@@ -26,6 +26,7 @@ from .communication import (
     broadcast_from_last_to_first_pipeline_stage)
 from .forward_step import ForwardStep
 from .sampling import sample
+from .beam_utils import BeamHypotheses


 def score_and_return_on_first_stage(model, tokens, lengths):
     """Function for just scoring.

@@ -200,6 +201,7 @@ def generate_tokens_probs_and_return_on_first_stage(
                                top_p=top_p,
                                temperature=temperature,
                                vocab_size=tokenizer.vocab_size)

             # If a prompt length is smaller or equal to the current context
             # length, it means we have started generating tokens
             started = lengths <= context_length

@@ -257,7 +259,7 @@ def generate_tokens_probs_and_return_on_first_stage(
                                                       tensor=done)
             if use_eod_token_for_early_termination and done:
                 break

     # ===================================================
     # Update the length of based on max generated length.
     # ===================================================

@@ -280,6 +282,118 @@ def generate_tokens_probs_and_return_on_first_stage(
     return tokens, generated_sequence_lengths, output_log_probs

+def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty):
+    args = get_args()
+    tokenizer = get_tokenizer()
+
+    batch_size = tokens.size(0)
+    assert(batch_size == 1)
+    prompt_length = lengths.item()
+    final_sequence_length = tokens.size(1)
+    final_sequence_length = min(final_sequence_length, args.max_position_embeddings)
+
+    # If the context is too big, this happens
+    if prompt_length >= final_sequence_length:
+        raise ValueError("context length + tokens_to_generate too large")
+
+    # forward step.
+    forward_step = ForwardStep(model, beam_size, final_sequence_length)
+
+    beam_hyp = BeamHypotheses(beam_size, length_penalty)
+    done = False
+    scores = torch.zeros(beam_size,
+                         dtype=torch.float32,
+                         device=torch.cuda.current_device()).unsqueeze(1)
+
+    # =============
+    # Run inference
+    # =============
+    with torch.no_grad():
+        tokens = tokens.repeat(beam_size, 1)
+        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
+        prev_context_length = 0
+        for context_length in range(prompt_length, final_sequence_length):
+
+            # Pick the slice that we need to pass through the network.
+            tokens2use = tokens[:, prev_context_length:context_length]
+            positions2use = position_ids[:, prev_context_length:context_length]
+            attention_mask2use = attention_mask[
+                ..., prev_context_length:context_length, :context_length]
+
+            # logits will be meaningful only in the last pipeline stage.
+            logits = forward_step(tokens2use, positions2use, attention_mask2use)
+
+            if mpu.is_pipeline_last_stage():
+                vocab_size = logits.size(2)
+                log_probs = F.log_softmax(logits, dim=2)
+                new_scores = log_probs[:, -1, :] + scores
+
+                if context_length == prompt_length:  # if this is the first one
+                    sorted_scores, indices = torch.sort(new_scores[0, :], descending=True)
+                else:
+                    sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
+
+                best_beam_ids = torch.div(indices[:2 * beam_size], vocab_size).trunc().long()
+                best_words = indices[:2 * beam_size] % vocab_size
+                best_scores = sorted_scores[:2 * beam_size]
+
+                next_beams = []
+                for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
+                        zip(best_words, best_scores, best_beam_ids)):
+                    if token_id.item() == stop_token:
+                        # if beam_token does not belong to top num_beams tokens, it should not be added
+                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
+                        if is_beam_token_worse_than_top_num_beams:
+                            continue
+                        beam_hyp.add(tokens[beam_id].clone(),
+                                     beam_score,
+                                     context_length + 1 - prompt_length)
+                    else:
+                        # add next predicted token since it is not eos_token
+                        next_beams.append((token_id, beam_score, beam_id))
+
+                    if len(next_beams) == beam_size:
+                        break
+
+                if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
+                    done = True
+                    break
+
+                best_batches = tokens.new([item[2] for item in next_beams])
+                tokens = tokens[best_batches, :]
+                tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
+                scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
+
+                # set inference key values to make it consistent with best beam index
+                forward_step.inference_params.swap_key_value_dict(best_batches)
+
+            # Update the tokens on the first stage so the next input to
+            # the network is correct.
+            copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
+                                                   tokens[:, context_length])

+            # Update the context length for the next token generation.
+            prev_context_length = context_length
+
+            copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
+                                                   scores[:, 0])
+
+        # if cannot find stop token, add open beams to hyps
+        if not done:
+            for beam_id in range(beam_size):
+                beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
+
+        # rank based on scores
+        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
+        num_return_gen = min(num_return_gen, len(sorted_hyps))
+        scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
+        tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
+        scores = torch.stack(scores, dim=0)
+        tokens = torch.stack(tokens, dim=0)
+
+    return tokens, scores

 def _build_attention_mask_and_position_ids(tokens):
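One detail worth spelling out (illustration, not from the diff): after the first step, the per-beam next-token scores are flattened into a single vector of length beam_size * vocab_size, and each sorted index is split back into a beam id (integer division by vocab_size) and a token id (modulo vocab_size). A toy, self-contained sketch of that decomposition:

import torch
import torch.nn.functional as F

beam_size, vocab_size = 2, 5
# Toy next-token logits for each beam plus the running score per beam.
logits = torch.tensor([[2.0, 0.5, 0.1, 0.0, 1.0],
                       [0.2, 3.0, 0.3, 0.1, 0.4]])
scores = torch.tensor([[-1.0], [-2.5]])            # shape (beam_size, 1)
new_scores = F.log_softmax(logits, dim=-1) + scores

# Flatten, sort, then recover (beam id, token id) for the top 2*beam_size candidates.
sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
best_beam_ids = torch.div(indices[:2 * beam_size], vocab_size).trunc().long()
best_words = indices[:2 * beam_size] % vocab_size

print(best_beam_ids)   # which beam each candidate extends
print(best_words)      # which token id it would append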
megatron/text_generation_server.py

@@ -20,9 +20,11 @@ from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
 from megatron.text_generation import generate_and_post_process
+from megatron.text_generation import beam_search_and_post_process


 GENERATE_NUM = 0
+BEAM_NUM = 1
 lock = threading.Lock()

 class MegatronGenerate(Resource):

@@ -34,6 +36,11 @@ class MegatronGenerate(Resource):
         choice = torch.cuda.LongTensor([GENERATE_NUM])
         torch.distributed.broadcast(choice, 0)

+    @staticmethod
+    def send_do_beam_search():
+        choice = torch.cuda.LongTensor([BEAM_NUM])
+        torch.distributed.broadcast(choice, 0)
+
     def put(self):
         args = get_args()

@@ -128,15 +135,57 @@ class MegatronGenerate(Resource):
         if not isinstance(no_log, bool):
             return "no_log must be a boolean value"

+        beam_width = None
+        if "beam_width" in request.get_json():
+            beam_width = request.get_json()["beam_width"]
+            if not isinstance(beam_width, int):
+                return "beam_width must be integer"
+            if beam_width < 1:
+                return "beam_width must be an integer > 1"
+            if len(prompts) > 1:
+                return "When doing beam_search, batch size must be 1"
+
+        stop_token = 50256
+        if "stop_token" in request.get_json():
+            stop_token = request.get_json()["stop_token"]
+            if not isinstance(stop_token, int):
+                return "stop_token must be an integer"
+
+        length_penalty = 1
+        if "length_penalty" in request.get_json():
+            length_penalty = request.get_json()["length_penalty"]
+            if not isinstance(length_penalty, float):
+                return "length_penalty must be a float"
+
         with lock:  # Need to get lock to keep multiple threads from hitting code
             if not no_log:
                 print("request IP: " + str(request.remote_addr))
                 print(json.dumps(request.get_json()), flush=True)
                 print("start time: ", datetime.datetime.now())

-            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
             try:
-                response, response_seg, response_logprobs, _ = \
-                    generate_and_post_process(
+                if beam_width is not None:
+                    MegatronGenerate.send_do_beam_search()  # Tell other ranks we're doing beam_search
+                    response, response_seg, response_scores = \
+                        beam_search_and_post_process(
+                        self.model,
+                        prompts=prompts,
+                        tokens_to_generate=tokens_to_generate,
+                        beam_size=beam_width,
+                        add_BOS=add_BOS,
+                        stop_token=stop_token,
+                        num_return_gen=beam_width,  # Returning whole beam
+                        length_penalty=length_penalty)
+
+                    return jsonify({"text": response,
+                                    "segments": response_seg,
+                                    "scores": response_scores})
+                else:
+                    MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
+                    response, response_seg, response_logprobs, _ = \
+                        generate_and_post_process(
                         self.model,
                         prompts=prompts,
                         tokens_to_generate=tokens_to_generate,

@@ -149,13 +198,15 @@ class MegatronGenerate(Resource):
                         stop_on_double_eol=stop_on_double_eol,
                         stop_on_eol=stop_on_eol,
                         random_seed=random_seed)

+                    return jsonify({"text": response,
+                                    "segments": response_seg,
+                                    "logprobs": response_logprobs})
+
             except ValueError as ve:
                 return "Length of prompt + tokens_to_generate longer than allowed"

             print("end time: ", datetime.datetime.now())
-            return jsonify({"text": response,
-                            "segments": response_seg,
-                            "logprobs": response_logprobs})

 class MegatronServer(object):
     def __init__(self, model):
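For context only (not part of the diff), a client request that takes the new beam-search path might look like the sketch below. The host, port, and /api route are assumptions about how the Flask server is usually exposed, not something stated in this commit.

# Hypothetical client call; host, port, and route are assumptions.
import json
import requests

payload = {
    "prompts": ["Deep learning is"],
    "tokens_to_generate": 32,
    "beam_width": 4,          # presence of this key selects the beam-search branch
    "stop_token": 50256,
    "length_penalty": 1.0,
}
resp = requests.put("http://localhost:5000/api",
                    data=json.dumps(payload),
                    headers={"Content-Type": "application/json"})
print(resp.json())            # expected keys: "text", "segments", "scores"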
tools/run_text_generation_server.py

@@ -28,6 +28,7 @@ from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.text_generation_server import MegatronServer
 from megatron.text_generation import generate_and_post_process
+from megatron.text_generation import beam_search_and_post_process
 import torch

 def model_provider(pre_process=True, post_process=True):

@@ -82,3 +83,8 @@ if __name__ == "__main__":
                 generate_and_post_process(model)
             except ValueError as ve:
                 pass
+        elif choice[0].item() == 1:
+            try:
+                beam_search_and_post_process(model)
+            except ValueError as ve:
+                pass
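To place the new elif branch in context, here is a paraphrased sketch of the dispatch loop it extends, under the assumption (consistent with the server changes above) that rank 0 runs the Flask server while the remaining ranks wait on the broadcast choice flag; only the two branch bodies shown above appear verbatim in this diff.

# Paraphrased sketch of the worker-rank dispatch loop (not the verbatim file).
while True:
    choice = torch.cuda.LongTensor(1)
    torch.distributed.broadcast(choice, 0)   # set by send_do_generate / send_do_beam_search
    if choice[0].item() == 0:                # GENERATE_NUM
        try:
            generate_and_post_process(model)
        except ValueError:
            pass
    elif choice[0].item() == 1:              # BEAM_NUM
        try:
            beam_search_and_post_process(model)
        except ValueError:
            pass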