OpenDAS / Megatron-LM

Commit ccf0eae2, authored Dec 09, 2021 by zihanl

use new text generation

parent b3cd8a47
Showing 1 changed file with 32 additions and 41 deletions.

tasks/knwl_dialo/prompt.py  (+32, -41)
@@ -26,19 +26,26 @@ from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.checkpointing import load_checkpoint
 from megatron.initialize import initialize_megatron
-from tasks.knwl_dialo.utils import get_token_stream
-# from megatron.text_generation import generate_and_post_process
+from megatron.text_generation import generate_and_post_process
 
 
-def call_model_api(inputs):
+def call_model_api(inputs, tokens_to_generate):
     """Calling the model api to get the output generations"""
 
-    # TODO
-    # Implement the model api, and get output generations from the inputs
-    # After that, return the output generations
-    # outputs = call_model_api(inputs)
-    # return outputs
-    pass
+    args = get_args()
+
+    # The following is an example of using the Megatron API
+    # You can also implement your own API function to place this part
+    headers = {'Content-Type': 'application/json; charset=UTF-8'}
+    data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
+    data_json = json.dumps(data)
+    outputs = requests.put(args.megatron_api_url, headers=headers, data=data_json).json()["text"][0]
+
+    input_len = len(inputs)
+    outputs = outputs[input_len:]
+    outputs = outputs.split("\n")[0].strip()
+
+    return outputs
 
 
 def read_prompts(prompt_path, prompt_type, n_example):
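Note on this hunk: the new call_model_api body is essentially a small REST client around Megatron's text-generation server. The sketch below restates that logic as a standalone function so it can be tried in isolation; the url parameter stands in for args.megatron_api_url, the endpoint path in the usage comment is a placeholder, and it assumes a server is already listening that accepts a PUT with a {"prompts", "tokens_to_generate", "top_k"} JSON body and replies with {"text": [...]}, which is what the diff implies.

    import json
    import requests

    def query_megatron_api(url, prompt, tokens_to_generate=64):
        """Minimal sketch of the request/response handling in call_model_api.

        Assumes a Megatron text-generation server is listening at `url` and
        answers PUT requests with a JSON body of the form {"text": [...]}.
        """
        headers = {'Content-Type': 'application/json; charset=UTF-8'}
        payload = {"prompts": [prompt],
                   "tokens_to_generate": tokens_to_generate,
                   "top_k": 1}
        response = requests.put(url, headers=headers, data=json.dumps(payload))
        # The server echoes the prompt followed by the generation; keep only
        # the first generated line after the prompt, mirroring the diff.
        text = response.json()["text"][0]
        generation = text[len(prompt):]
        return generation.split("\n")[0].strip()

    # Hypothetical usage (host, port, and path are placeholders):
    # print(query_megatron_api("http://localhost:5000/api", "Hello, how are", 32))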
@@ -107,7 +114,7 @@ def generate_samples_by_calling_api():
         # prepare the inputs for the api
         if args.prompt_type == "knowledge":
-            # inputs = prompt + current test
+            # inputs = prompt + current test
             # get the prompt
             turns = splits[1].split(" [SEP] ")
             last_turn = turns[-1]
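In the "knowledge" branch the API input is the few-shot prompt followed by the last turn of the dialogue context, whose turns are separated by " [SEP] ". A minimal sketch of that assembly using the variable names from the diff; the exact template around last_turn and the tokens_to_generate value are assumptions, since they are not shown here.

    def build_knowledge_input(prompt, dialogue_context):
        # Turns in the test file are separated by " [SEP] "; the few-shot
        # prompt is prepended to the most recent turn (template assumed).
        turns = dialogue_context.split(" [SEP] ")
        last_turn = turns[-1]
        return prompt + last_turn

    # inputs = build_knowledge_input(prompt, splits[1])
    # generations = call_model_api(inputs, tokens_to_generate)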
@@ -216,7 +223,6 @@ def generate_samples_by_prompting_input_from_file(model):
             instance = instance.strip()
             prompt += instance + "\n"
 
-    context_count = 0
     input_pos = 0
     model.eval()
     # perform prompting
@@ -261,47 +267,32 @@ def generate_samples_by_prompting_input_from_file(model):
                 input_pos += 1
                 raw_text_len = len(raw_text)
                 context_tokens = tokenizer.tokenize(raw_text)
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
-                # raw_text = "EMPTY TEXT"
+                raw_text = "EMPTY TEXT"
 
             if input_pos % 100 == 0:
                 print_rank_0("input_pos: %d" % input_pos)
 
-            # get the generation outputs (in decode_tokens)
-            token_stream = get_token_stream(model, [context_tokens])
-            for _, decode_tokens in enumerate(token_stream):
-                pass
-
-            # outputs = generate_and_post_process(
-            #             model=model,
-            #             prompts=[raw_text],
-            #             tokens_to_generate=args.out_seq_length,
-            #             top_k_sampling=1)
-            # prompts_plus_generations = outputs[0]
+            outputs = generate_and_post_process(
+                        model=model,
+                        prompts=[raw_text],
+                        tokens_to_generate=args.out_seq_length,
+                        top_k_sampling=1)
+            prompts_plus_generations = outputs[0]
+            prompts_plus_generations = prompts_plus_generations[0]
 
             # write the generated output to the output file
             if mpu.get_tensor_model_parallel_rank() == 0:
                 if mpu.is_pipeline_first_stage():
 
-                    decode_tokens, _ = decode_tokens
-                    decode_tokens = decode_tokens[0].cpu().numpy().tolist()
-                    trim_decode_tokens = tokenizer.detokenize(decode_tokens)[raw_text_len:]
-                    generated_output = trim_decode_tokens.split("\n")[0]
-                    generated_output = generated_output.strip()
-                    fname_out.write(generated_output)
+                    generations = prompts_plus_generations[raw_text_len:]
+                    generations = generations.split("\n")[0]
+                    generations = generations.strip()
+                    fname_out.write(generations)
                     fname_out.write("\n")
 
-                    # generations = prompts_plus_generations[raw_text_len:]
-                    # generations = generations.split("\n")[0]
-                    # generations = generations.strip()
-                    # fname_out.write(generations)
-                    # fname_out.write("\n")
-
             raw_text = None
-            context_count += 1
             if input_pos == input_count:
                 return
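The substantive change in this hunk: instead of draining get_token_stream and detokenizing the final decode_tokens on rank 0, the function now calls generate_and_post_process, which already returns detokenized prompt-plus-generation strings (hence the outputs[0] followed by [0] indexing), so post-processing reduces to slicing off the prompt and keeping the first generated line. A minimal, self-contained sketch of that trimming step; the example strings are hypothetical.

    def trim_generation(prompt_plus_generation, prompt):
        # Drop the echoed prompt, then keep only the first generated line,
        # mirroring what the new code does with prompts_plus_generations.
        generation = prompt_plus_generation[len(prompt):]
        return generation.split("\n")[0].strip()

    # Hypothetical example:
    # trim_generation("Q: 2+2?\nA: 4\nQ: 3+3?", "Q: 2+2?\nA: ")  ->  "4"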
@@ -309,7 +300,7 @@ def generate_samples_by_prompting_input_from_file(model):
 def main():
 
     args = get_args()
-    if args.api_prompting:
+    if args.api_prompt:
         # obtain the generations by calling the api
         generate_samples_by_calling_api()
         return
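The renamed flag itself is defined outside this file, presumably in the task's argument provider; Megatron task scripts typically register such options through an extra_args_provider passed to initialize_megatron. A hedged sketch of how a boolean --api-prompt flag of this kind is usually declared; only the flag names come from the diff, the provider name, group title, help text, and defaults are assumptions.

    def get_knwl_dialo_args(parser):
        # Assumed shape of the task's extra-args provider; purely illustrative.
        group = parser.add_argument_group(title='knowledge dialogue prompting')
        group.add_argument('--api-prompt', action='store_true', default=False,
                           help='Obtain generations by calling the model API.')
        group.add_argument('--megatron-api-url', type=str, default=None,
                           help='URL of the Megatron text generation server.')
        return parser

    # initialize_megatron(extra_args_provider=get_knwl_dialo_args)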
@@ -319,7 +310,7 @@ def main():
         exit()
 
     # Set up model and load checkpoint.
-    model = get_model(model_provider)
+    model = get_model(model_provider, wrap_with_ddp=False)
 
     if args.load is not None:
         _ = load_checkpoint(model, None, None)
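Building the model with wrap_with_ddp=False hands the generation path a bare module list rather than a DistributedDataParallel wrapper; generation is forward-only, so there is nothing to synchronize. A hedged sketch of the inference-time setup this file now relies on; the single-chunk unwrap and the eval() call are assumptions based on common Megatron usage and are not shown in this hunk.

    def setup_model_for_generation(model_provider):
        args = get_args()
        # No DDP wrapper: generation never computes or reduces gradients.
        model = get_model(model_provider, wrap_with_ddp=False)
        if args.load is not None:
            _ = load_checkpoint(model, None, None)
        # get_model returns a list of model chunks; a single chunk is assumed here.
        model = model[0]
        model.eval()
        return model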