OpenDAS / Megatron-LM · Commit de84b2af

Authored Feb 18, 2022 by rprenger
Parent: fa75238b

    Possibly working beam search

Showing 3 changed files with 44 additions and 9 deletions.
  megatron/text_generation/__init__.py  (+2, -1)
  megatron/text_generation/api.py       (+3, -2)
  megatron/text_generation_server.py    (+39, -6)
megatron/text_generation/__init__.py
@@ -16,4 +16,5 @@
 from .api import (
     generate,
-    generate_and_post_process)
+    generate_and_post_process,
+    beam_search_and_post_process)
megatron/text_generation/api.py
@@ -158,11 +158,12 @@ def beam_search_and_post_process(model,
     if mpu.is_pipeline_first_stage():
         lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())
         tokens, prompts_plus_generations, prompts_plus_generations_segments = detokenize_generations(tokens, lengths, True)
-        return prompts_plus_generations, prompts_plus_generations_segments, tokens
+        scores = scores.cpu().numpy().tolist()
+        return prompts_plus_generations, prompts_plus_generations_segments, scores
 
     return None
 
-def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False)
+def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False):
     # Make sure input params are avaialble to all ranks.
     values = [tokens_to_generate,
               beam_size,
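The new return path hands back detokenized text plus per-beam scores on the first pipeline stage only; every other rank gets None. A minimal usage sketch under that assumption (the model object and the prompt are hypothetical stand-ins; in practice the text generation server below is the caller):

# Sketch: driving the new API directly. Assumes a Megatron model is
# already initialized across all ranks; "model" and the prompt are
# hypothetical stand-ins, not part of this commit.
from megatron import mpu
from megatron.text_generation import beam_search_and_post_process

result = beam_search_and_post_process(model,
                                      prompts=["Deep learning is"],
                                      tokens_to_generate=32,
                                      beam_size=4,
                                      add_BOS=False)

# Only the first pipeline stage receives the detokenized generations
# and their scores; all other ranks get None.
if mpu.is_pipeline_first_stage():
    texts, segments, scores = result
    print(texts[0], scores[0])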
megatron/text_generation_server.py
@@ -20,9 +20,11 @@ from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
 from megatron.text_generation import generate_and_post_process
+from megatron.text_generation import beam_search_and_post_process
 
 GENERATE_NUM = 0
+BEAM_NUM = 1
 lock = threading.Lock()
 
 class MegatronGenerate(Resource):
@@ -34,6 +36,11 @@ class MegatronGenerate(Resource):
         choice = torch.cuda.LongTensor([GENERATE_NUM])
         torch.distributed.broadcast(choice, 0)
 
+    @staticmethod
+    def send_do_beam_search():
+        choice = torch.cuda.LongTensor([BEAM_NUM])
+        torch.distributed.broadcast(choice, 0)
+
     def put(self):
         args = get_args()
@@ -134,15 +141,39 @@ class MegatronGenerate(Resource):
             if not isinstance(no_log, bool):
                 return "beam_search must be a boolean value"
 
+        beam_size = 4
+        if "beam_size" in request.get_json():
+            beam_size = request.get_json()["beam_size"]
+            if not isinstance(beam_size, int):
+                return "beam_size must be integer"
+            if beam_size < 1:
+                return "beam_size must be an integer > 1"
+
         with lock:  # Need to get lock to keep multiple threads from hitting code
             if not no_log:
                 print("request IP: " + str(request.remote_addr))
                 print(json.dumps(request.get_json()), flush=True)
                 print("start time: ", datetime.datetime.now())
-            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            try:
-                response, response_seg, response_logprobs, _ = \
-                    generate_and_post_process(
+            if beam_search:
+                MegatronGenerate.send_do_beam_search()  # Tell other ranks we're doing beam_search
+                response, response_seg, response_scores = \
+                    beam_search_and_post_process(
+                        self.model,
+                        prompts=prompts,
+                        tokens_to_generate=tokens_to_generate,
+                        beam_size=beam_size,
+                        add_BOS=add_BOS)
+                return jsonify({"text": response,
+                                "segments": response_seg,
+                                "scores": response_scores})
+            else:
+                MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
+                response, response_seg, response_logprobs, _ = \
+                    generate_and_post_process(
                         self.model,
                         prompts=prompts,
                         tokens_to_generate=tokens_to_generate,
@@ -155,13 +186,15 @@ class MegatronGenerate(Resource):
                         stop_on_double_eol=stop_on_double_eol,
                         stop_on_eol=stop_on_eol,
                         random_seed=random_seed)
-            except ValueError as ve:
-                return "Length of prompt + tokens_to_generate longer than allowed"
-            print("end time: ", datetime.datetime.now())
-            return jsonify({"text": response,
-                            "segments": response_seg,
-                            "logprobs": response_logprobs})
+                return jsonify({"text": response,
+                                "segments": response_seg,
+                                "logprobs": response_logprobs})
 
 class MegatronServer(object):
     def __init__(self, model):
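GENERATE_NUM and BEAM_NUM form a small dispatch protocol: rank 0, which runs the Flask server, broadcasts one of the two constants, and the remaining ranks, blocked in a listening loop, enter the matching collective code path. The two constants must differ or the other ranks cannot tell the requests apart. That loop is not part of this diff; the sketch below is only a hypothetical illustration of the convention, with run_generate and run_beam_search standing in for the real worker-side handlers:

# Hypothetical listening loop for ranks other than 0. The real loop
# lives elsewhere in megatron and is not shown in this commit;
# run_generate/run_beam_search are illustrative stand-ins.
import torch

GENERATE_NUM = 0
BEAM_NUM = 1

def worker_loop(model):
    while True:
        choice = torch.cuda.LongTensor(1)
        # Pairs with the broadcast in send_do_generate / send_do_beam_search.
        torch.distributed.broadcast(choice, 0)
        if choice[0].item() == GENERATE_NUM:
            run_generate(model)       # hypothetical handler
        elif choice[0].item() == BEAM_NUM:
            run_beam_search(model)    # hypothetical handler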
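With the server change in place, a client can request beam search over HTTP through the PUT handler above. A minimal sketch, assuming the server listens on localhost port 5000 and that MegatronServer registers the resource at /api; the route and port are configured outside this diff and are assumptions here:

# Hypothetical client; the URL, port, and route are assumptions.
import requests

resp = requests.put("http://localhost:5000/api",
                    json={"prompts": ["Deep learning is"],
                          "tokens_to_generate": 32,
                          "beam_search": True,  # select the new code path
                          "beam_size": 4})      # server default is 4 if omitted

# Expected shape of the reply for the beam-search path:
# {"text": [...], "segments": [...], "scores": [...]}
print(resp.json())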