OpenDAS / Megatron-LM

Commit 5f433016, authored Dec 09, 2021 by zihanl
Parent: fec9668d

add api option

Showing 2 changed files with 123 additions and 3 deletions:

    tasks/knwl_dialo/prompt.py    +121  -3
    tasks/main.py                 +2    -0

tasks/knwl_dialo/prompt.py
@@ -30,6 +30,121 @@ from tasks.knwl_dialo.utils import get_token_stream

 # from megatron.text_generation import generate_and_post_process

+def call_model_api(inputs):
+    """Calling the model api to get the output generations"""
+
+    # TODO
+    # Implement the model api, and get output generations from the inputs
+    # After that, return the output generations
+    # outputs = call_model_api(inputs)
+    # return outputs
+    pass
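
call_model_api is intentionally left as a stub in this commit. Below is a minimal sketch of how it might be filled in, assuming a hypothetical HTTP service behind MODEL_API_URL that accepts a JSON payload and returns the generation as JSON; the endpoint, request/response schema, and the requests dependency are all assumptions, not part of the commit. (The added code also relies on json, nltk's word_tokenize, and Megatron's get_args, which are presumably imported earlier in the file.)

    import os
    import requests  # assumed dependency, not part of this commit

    # Hypothetical endpoint; override via the environment.
    MODEL_API_URL = os.environ.get("MODEL_API_URL", "http://localhost:5000/api")

    def call_model_api(inputs):
        """Calling the model api to get the output generations"""
        # POST the prompt to the assumed generation service and pull the
        # generated text out of its JSON response (schema is an assumption).
        response = requests.post(MODEL_API_URL, json={"prompts": [inputs]})
        response.raise_for_status()
        return response.json()["text"]
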
+
+
+def read_prompts(prompt_path, prompt_type, n_example):
+    """Read prompt data"""
+
+    if prompt_type == "knowledge":
+        # prompts for the knowledge generation
+        prompt_examples_dict = {}
+        # read prompt_path
+        with open(prompt_path, "r") as f:
+            for i, line in enumerate(f):
+                line = line.strip()
+                line_dict = json.loads(line)
+                key = list(line_dict.keys())[0]
+                if key not in prompt_examples_dict:
+                    prompt_examples = line_dict[key]
+                    prompt = ""
+                    for instance in prompt_examples:
+                        instance = instance.strip()
+                        prompt += instance + " \n"
+                    prompt_examples_dict[key] = prompt
+
+        return prompt_examples_dict
+
+    else:
+        # prompts for the response generation
+        # read prompt_path
+        prompt = ""
+        with open(prompt_path, "r") as f:
+            prompt_examples = f.readlines()
+            prompt_examples = prompt_examples[:n_example]
+            for instance in prompt_examples:
+                instance = instance.strip()
+                prompt += instance + " \n"
+
+        return prompt
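
For reference, read_prompts implies two on-disk formats, inferred from the parsing logic above rather than documented in the commit: knowledge prompts are JSONL, one {key: [example, ...]} object per line, while response prompts are plain text with one example per line and only the first n_example lines used. A hypothetical call for each (file names are illustrative):

    # Hypothetical usage; file names are illustrative.
    knwl_gen_prompt_dict = read_prompts("knwl_prompts.jsonl", "knowledge", 10)
    resp_gen_prompt = read_prompts("resp_prompts.txt", "response", 10)
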
+
+
+def generate_samples_by_calling_api():
+    """Generate outputs by calling the model api"""
+    args = get_args()
+    assert args.prompt_type in ["knowledge", "response"], \
+        "Please input a correct prompt type!"
+
+    if args.prompt_type == "knowledge":
+        # read knowledge generation prompts
+        knwl_gen_prompt_dict = read_prompts(
+            args.prompt_file, args.prompt_type, args.num_prompt_examples)
+    else:
+        # read response generation prompts
+        resp_gen_prompt = read_prompts(
+            args.prompt_file, args.prompt_type, args.num_prompt_examples)
+
+    # read the test data
+    fname = open(args.sample_input_file, "r")
+    test_sample_list = fname.readlines()
+    # create output file
+    fname_out = open(args.sample_output_file, "w")
+
+    # call the api to get the output generations
+    for test_sample in test_sample_list:
+        test_sample = test_sample.strip()
+        splits = test_sample.split("\t")
+        topic = splits[0]
+
+        # prepare the inputs for the api
+        if args.prompt_type == "knowledge":
+            # inputs = prompt + current test
+            # get the prompt
+            turns = splits[1].split(" [SEP] ")
+            last_turn = turns[-1]
+            key = topic + " " + last_turn
+            inputs = knwl_gen_prompt_dict[key]
+
+            # add current test
+            inputs += "( " + last_turn + " ) " + topic + " =>"
+
+        else:
+            # inputs = prompt + current test
+            # get the prompt
+            inputs = resp_gen_prompt
+
+            # add current test
+            turns = splits[1].split(" [SEP] ")
+            knowledge = splits[2]
+            last_turn = turns[-1]
+            last_turn = " ".join(word_tokenize(last_turn))
+            knowledge = " ".join(word_tokenize(knowledge))
+            knowledge = knowledge.strip()
+            last_turn = last_turn.strip()
+            inputs += "Topic: " + topic + ". "
+            inputs += "User says: " + last_turn + " "
+            inputs += "We know that: " + knowledge + " "
+            inputs += "System replies:"
+
+        # get the output generations from the api,
+        # and write to the output file
+        generations = call_model_api(inputs)
+        fname_out.write(generations)
+        fname_out.write("\n")
+
+    fname.close()
+    fname_out.close()
+
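
The loop above implies a tab-separated layout for --sample-input-file, again inferred from the parsing: column 0 is the topic, column 1 the dialogue turns joined by " [SEP] ", and, for response generation, column 2 the grounding knowledge. A made-up example line (with <TAB> standing in for a literal tab):

    Chess<TAB>Hi there! [SEP] Do you play chess?<TAB>Chess is a two-player board game.
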
+
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""

@@ -124,9 +239,7 @@ def generate_samples_by_prompting_input_from_file(model):

             # construct inputs for knowledge generation
             # then add the constructed inputs into the raw_text
             turns = splits[1].split(" [SEP] ")
-            raw_text += "( " + last_turn + " ) " + topic + " =>"
+            context = turns[-1]
+            raw_text += "( " + context + " ) " + topic + " =>"
         else:
             # first add the prompt into the raw_text

@@ -196,6 +309,11 @@ def generate_samples_by_prompting_input_from_file(model):

 def main():
     args = get_args()

+    if args.api_prompt:
+        # obtain the generations by calling the api
+        generate_samples_by_calling_api()
+        return
+
     if args.num_layers_per_virtual_pipeline_stage is not None:
         print("Interleaved pipeline schedule is not yet supported for text generation.")
         exit()

tasks/main.py
@@ -102,6 +102,8 @@ def get_tasks_args(parser):

                        help='datapath for golden sentences')
     group.add_argument('--out-seq-length', type=int, default=100,
                        help='output sequence length')
+    group.add_argument('--api-prompt', default=False, action="store_true",
+                       help='setup model api for prompting')

     return parser
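
Taken together with the new code path in prompt.py, a run that routes prompting through the external api might be launched along these lines. The task name and the exact mix of arguments are assumptions based on the options referenced in the code, not on the commit itself:

    python tasks/main.py \
        --task KNWL-DIALO-PROMPT \
        --api-prompt \
        --prompt-type knowledge \
        --prompt-file knwl_prompts.jsonl \
        --num-prompt-examples 10 \
        --sample-input-file test.tsv \
        --sample-output-file generations.txt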