OpenDAS / Megatron-LM

Commit a7f882fe (unverified)
Merge branch 'main' into main
Authored Dec 25, 2021 by Satpal Singh Rathore; committed by GitHub on Dec 25, 2021
Parents: 8f241a96, 9d86ca67

Changes: 45 files in the full merge commit; this page shows 5 changed files with 999 additions and 7 deletions (+999, -7).
tasks/msdp/metrics.py                 +77   -0
tasks/msdp/preprocessing.py           +595  -0
tasks/msdp/prompt.py                  +322  -0
tools/run_text_generation_server.py   +2    -4
tools/text_generation_cli.py          +3    -3
tasks/msdp/metrics.py (new file, mode 100644)
# The following code is adapted from
# https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py,
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE.
"""Provides standard metric evaluations for dialog."""

from collections import Counter
from typing import List
import numpy as np
import re

re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')


def normalize_answer(s):
    """
    Lower text and remove punctuation, articles and extra whitespace.
    """
    s = s.lower()
    s = re_punc.sub(' ', s)
    s = re_art.sub(' ', s)
    s = ' '.join(s.split())
    return s


class F1Metric:
    """
    Helper class which computes token-level F1.
    """

    @staticmethod
    def _prec_recall_f1_score(pred_items, gold_items):
        """
        Compute precision, recall and f1 given a set of gold and prediction items.
        :param pred_items: iterable of predicted values
        :param gold_items: iterable of gold values
        :return: tuple (p, r, f1) for precision, recall, f1
        """
        common = Counter(gold_items) & Counter(pred_items)
        num_same = sum(common.values())
        if num_same == 0:
            return 0, 0, 0
        precision = 1.0 * num_same / len(pred_items)
        recall = 1.0 * num_same / len(gold_items)
        f1 = (2 * precision * recall) / (precision + recall)
        return precision, recall, f1

    @staticmethod
    def compute_each_pair(guess: str, answer: str):
        if answer == "":
            return None, None, None
        if guess == "":
            return 0, 0, 0
        g_tokens = normalize_answer(guess).split()
        a_tokens = normalize_answer(answer).split()

        precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
        return precision, recall, f1

    @staticmethod
    def compute_all_pairs(guesses: List[str], answers: List[str]):
        # additional augment:
        assert len(guesses) == len(answers)

        precision_list, recall_list, f1_list = [], [], []
        for guess, answer in zip(guesses, answers):
            precision, recall, f1 = F1Metric.compute_each_pair(guess, answer)
            if precision is None or recall is None or f1 is None:
                continue
            precision_list.append(precision)
            recall_list.append(recall)
            f1_list.append(f1)

        return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list)
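
A minimal usage sketch for the metric class above. The guess/answer strings below are made-up examples, not data from the repository:

    # Hypothetical strings, purely for illustration.
    guesses = ["the cat sat on the mat", "paris is in france"]
    answers = ["a cat sat on a mat", "Paris is the capital of France."]

    # compute_all_pairs() averages token-level precision/recall/F1 over all pairs,
    # after lowercasing and stripping punctuation/articles via normalize_answer().
    precision, recall, f1 = F1Metric.compute_all_pairs(guesses, answers)
    print("P=%.3f R=%.3f F1=%.3f" % (precision, recall, f1))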
tasks/msdp/preprocessing.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""

import torch
import argparse
from nltk import word_tokenize
from tqdm import tqdm
import numpy as np
import json


def get_args():
    parser = argparse.ArgumentParser(description="Preprocessing")

    parser.add_argument("--func", type=str, default=None,
                        help="choose to run which function")
    parser.add_argument("--raw_file", type=str, default=None,
                        help="path of the input file")
    parser.add_argument("--processed_file", type=str, default=None,
                        help="path of the output file")
    parser.add_argument("--knwl_ref_file", type=str, default=None,
                        help="path of the knowledge reference file")
    parser.add_argument("--resp_ref_file", type=str, default=None,
                        help="path of the knowledge reference file")
    parser.add_argument("--knwl_gen_file", type=str, default=None,
                        help="path of the generated knowledge file")
    parser.add_argument("--test_file", type=str, default=None,
                        help="path of the test file")
    parser.add_argument("--train_file", type=str, default=None,
                        help="path of the train file")
    parser.add_argument("--model_file", type=str, default=None,
                        help="path of the model file")
    parser.add_argument("--data_type", type=str, default=None,
                        help="data types, choose one out of three types: \
                              wow_seen, wow_unseen, and woi")
    parser.add_argument("--seed", type=int, default=1234,
                        help="random seed")

    args = parser.parse_args()
    return args


def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """
    This is a function used for processing the wizard of wikipedia (wow) dataset
    Expected processed format:
    topic \t dialogue context \t golden knowledge \t golden response
    """

    # loading the raw data
    print("> Loading data from %s" % raw_file)
    with open(raw_file, "r") as fr:
        dialog_data = json.load(fr)

    print("> Processing data ...")
    fproc = open(processed_file, "w")
    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
    fresp = open(resp_ref_file, "w") if resp_ref_file else None

    for i, sample in enumerate(tqdm(dialog_data)):
        # get all the dialog data for a single dialog sample
        dialog = sample["dialog"]

        turn_list = []  # collect the dialog history
        # processing for each single dialog sample
        for j, turn in enumerate(dialog):
            # text of each turn
            text = turn["text"]
            if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
                text = text + "."

            if j == 0:
                # first turn
                turn_list.append(text)
                continue

            speaker = turn["speaker"].lower()
            if "wizard" in speaker:
                checked_sentence = list(turn["checked_sentence"].values())  # knowledge
                checked_passage = list(turn["checked_passage"].values())    # topic

                assert len(checked_sentence) <= 1

                # get the ground truth knowledge
                if len(checked_sentence) > 0:
                    checked_sentence = checked_sentence[0]
                else:
                    checked_sentence = "no_passages_used"

                if len(checked_passage) == 1:
                    checked_passage = checked_passage[0]
                else:
                    checked_passage = "no_passages_used"

                # get the topic
                if checked_passage != "no_passages_used":
                    topic = checked_passage
                else:
                    topic = sample["chosen_topic"]

                dialog_context = " [SEP] ".join(turn_list)
                knowledge = checked_sentence
                response = text
                # add the response into the dialog history
                turn_list.append(response)

                # write to the output files
                fproc.write(topic + "\t" + dialog_context + "\t" + \
                            knowledge + "\t" + response + "\n")

                if fknwl:
                    fknwl.write(knowledge + "\n")
                if fresp:
                    # tokenize for evaluation
                    response = " ".join(word_tokenize(response))
                    fresp.write(response + "\n")

            else:
                assert "apprentice" in speaker
                turn_list.append(text)

    fproc.close()
    if fknwl:
        fknwl.close()
    if fresp:
        fresp.close()


def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """
    This is a function used for processing the wizard of internet (woi) dataset
    Expected processed format:
    topic \t dialogue context \t golden knowledge \t golden response
    """

    print("> Processing %s" % raw_file)
    fproc = open(processed_file, "w")
    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
    fresp = open(resp_ref_file, "w") if resp_ref_file else None

    with open(raw_file, "r") as fr:
        for i, line in tqdm(enumerate(fr)):
            # read line by line, each line uses json format
            line = line.strip()
            item_dict = json.loads(line)

            # item_dict is a dictionary
            # its key is the data id, and its value contains all the data content
            item_dict = item_dict.values()
            item_dict = list(item_dict)[0]  # len(item_dict) == 1

            # get the whole dialog data for a single dialog sample
            dialog_data = item_dict['dialog_history']
            length = len(dialog_data)

            turn_list = []  # collect the dialog history
            search_text = ""
            for i in range(length):
                item = dialog_data[i]
                action = item['action']

                if action == "Wizard => SearchAgent":
                    search_text = item['text']

                elif action == "Wizard => Apprentice":
                    if len(turn_list) == 0:
                        # first turn
                        turn = item['text']
                        turn_list.append(turn)
                        continue

                    # get the relevant content
                    contents = item["context"]["contents"]
                    selects = item["context"]["selected_contents"]
                    flag = selects[0][0]
                    selects = selects[1:]
                    assert len(selects) == len(contents)

                    # get the topic
                    if flag:
                        # no knowledge sentence is used for the response
                        topic = "no_topic"
                        knwl_sent = "no_passages_used"
                    else:
                        # we consider the search text as the topic
                        topic = search_text
                        # get the knowledge sentence
                        knwl_sent = ""
                        for content, select in zip(contents, selects):
                            content = content['content']
                            assert len(content) == len(select)
                            for c, s in zip(content, select):
                                if s:
                                    knwl_sent = c
                                    break

                    if knwl_sent == "":
                        # no knowledge is used for the response
                        topic = "no_topic"
                        knwl_sent = "no_passages_used"

                    # get dialogue context, knowledge, and response
                    dialog_context = " [SEP] ".join(turn_list)
                    response = item['text']

                    # processing
                    topic = topic.replace("\n", "").replace("\r", \
                                "").replace("\t", "")
                    dialog_context = dialog_context.replace("\n", "").replace("\r", \
                                "").replace("\t", "")
                    knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
                                "").replace("\t", "")
                    response = response.replace("\n", "").replace("\r", \
                                "").replace("\t", "")

                    if topic != "no_topic":
                        # write to the ouput files
                        fproc.write(topic + "\t" + dialog_context + "\t" + \
                                    knwl_sent + "\t" + response + "\n")

                        if fknwl:
                            fknwl.write(knwl_sent + "\n")
                        if fresp:
                            # tokenize for evaluation
                            response = " ".join(word_tokenize(response))
                            fresp.write(response + "\n")

                    turn_list.append(response)

                elif action == "Apprentice => Wizard":
                    turn = item['text']
                    turn_list.append(turn)

                else:
                    assert action == "SearchAgent => Wizard", \
                        "Please check whether you have used the correct data!"

    fproc.close()
    if fknwl:
        fknwl.close()
    if fresp:
        fresp.close()


def get_database(test_datapath, train_datapath, data_type):
    """Get the database by topics"""

    assert data_type in ["wow_seen", "wow_unseen", "woi"], \
        "Please input a correct data type!!"

    # get test data topic dictionary
    print("> reading test data from %s" % test_datapath)
    test_topics = {}
    with open(test_datapath, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            topic = splits[0]
            test_topics[topic] = True

    print("> reading data from %s" % train_datapath)
    train_data_by_topic = {}
    dialog_data_by_topic = {}
    dialog_examples = []
    with open(train_datapath, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            topic = splits[0]
            turns = splits[1].split(" [SEP] ")[-3:]
            knowledge = splits[2]
            response = splits[3]
            # filtering data samples
            if knowledge == "no_passages_used":
                # when no knowledge is used
                continue
            if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge):
                # when bracket exists in the knowledge
                continue
            if data_type != "wow_seen" and topic not in knowledge:
                # when topic does not exist in the knowledge
                continue

            # get the instance
            last_turn = turns[-1]
            instance = "( " + last_turn + " ) " + topic + " => " + knowledge

            # construct dialog example
            dialog_example = ""
            if data_type != "wow_seen":
                dialog_example += "( " + topic + " ) "
            for i, turn in enumerate(turns):
                if i != 0:
                    dialog_example += " "
                dialog_example += turn

            # check overlaps
            if topic in test_topics:
                if topic not in train_data_by_topic:
                    train_data_by_topic[topic] = [instance]
                else:
                    train_data_by_topic[topic].append(instance)

                if topic not in dialog_data_by_topic:
                    dialog_data_by_topic[topic] = [dialog_example]
                else:
                    dialog_data_by_topic[topic].append(dialog_example)

            else:
                # filtering data samples
                if len(knowledge.split()) > 20:
                    # knowledge is too long
                    continue
                if knowledge.startswith("It") or knowledge.startswith("it") or \
                    knowledge.startswith("This") or knowledge.startswith("this"):
                    continue

            # append all the data into dialogue examples list
            dialog_examples.append((topic, dialog_example, instance))

    return train_data_by_topic, dialog_data_by_topic, dialog_examples


emb_dict = {}
def select_prompts_based_on_similarity(
        query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
    """Select samples based on the similarity"""

    with torch.no_grad():
        # get the query embeddings
        query_ids = tokenizer.encode(query)
        query_ids = torch.LongTensor([query_ids]).cuda()
        query_emb = encoder(input_ids=query_ids).pooler_output
        query_emb = query_emb[0]

        # calculate embeddings for the samples in the database
        if topic in emb_dict:
            example_embeddings = emb_dict[topic]
            example_embeddings = example_embeddings.cuda()
        else:
            for idx, example in enumerate(dialog_list):
                example_ids = tokenizer.encode(example)
                example_ids = torch.LongTensor([example_ids]).cuda()
                example_emb = encoder(input_ids=example_ids).pooler_output
                if idx == 0:
                    example_embeddings = example_emb
                else:
                    example_embeddings = torch.cat(
                        (example_embeddings, example_emb), dim=0)
            emb_dict[topic] = example_embeddings.cpu()

        # compare the similarity and select the topk samples
        similarity_list = example_embeddings.matmul(query_emb)
        _, indices = torch.topk(similarity_list, k=topk)

        indices = indices.tolist()
        indices = indices[::-1]  # reverse the order
        selected_prompts = []
        for index in indices:
            # index = index.item()
            selected_prompts.append(prompt_list[index])

    return selected_prompts


def prompt_selection_for_knowledge_generation(
        test_datapath, train_datapath, model_path, output_prompt_path, data_type):
    """Selecting prompts for the knowledge generation"""

    print("> Selecting prompts for the knowledge generation")

    train_data_by_topic, dialog_data_by_topic, dialog_examples = \
        get_database(test_datapath, train_datapath, data_type)

    from transformers import DPRQuestionEncoderTokenizer
    print("> loading tokenizer and encoder")
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
                    'facebook/dpr-question_encoder-single-nq-base')
    encoder = torch.load(model_path).cuda()

    print("> getting dialog embeddings")
    with torch.no_grad():
        for idx, example in tqdm(enumerate(dialog_examples)):
            dialog = example[1]
            dialog_ids = tokenizer.encode(dialog)
            dialog_ids = torch.LongTensor([dialog_ids]).cuda()
            dialog_emb = encoder(input_ids=dialog_ids).pooler_output

            if idx == 0:
                dialog_embeddings = dialog_emb
            else:
                dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)

    print("> reading test data from %s" % test_datapath)
    prompt_list_for_each_sample = []
    with open(test_datapath, "r") as f:
        for i, line in tqdm(enumerate(f)):
            line = line.strip()

            splits = line.split("\t")
            topic = splits[0]
            turns = splits[1].split(" [SEP] ")[-3:]

            # get the query sentence
            query_sent = ""
            if data_type != "seen":
                query_sent += "( " + topic + " ) "
            for i, turn in enumerate(turns):
                if i != 0:
                    query_sent += " "
                query_sent += turn

            if topic not in train_data_by_topic:
                # get the query embedding
                query_ids = tokenizer.encode(query_sent)
                query_ids = torch.LongTensor([query_ids]).cuda()
                query_emb = encoder(input_ids=query_ids).pooler_output
                query_emb = query_emb[0]

                # calculate the similarity
                similarity_list = dialog_embeddings.matmul(query_emb)
                _, indices = torch.sort(similarity_list)
                indices = indices.tolist()
                selected_topics = {}
                selected_prompts = []
                num_prompt = 0
                for index in indices:
                    example = dialog_examples[index]
                    topic_temp = example[0]
                    if topic_temp not in selected_topics:
                        selected_topics[topic_temp] = True
                        selected_prompts.append(example[2])
                        num_prompt += 1
                        if num_prompt == 10:
                            break

                # get the selected samples
                example_list = selected_prompts[::-1]
                key = topic + " " + turns[-1]
                prompt_list_for_each_sample.append({key: example_list})

            else:
                num_data_sample = min(len(train_data_by_topic[topic]), 10)
                total_example_list = train_data_by_topic[topic]

                dialog_list = dialog_data_by_topic[topic]
                assert len(dialog_list) == len(train_data_by_topic[topic])

                # calculate the similarity
                example_list = select_prompts_based_on_similarity(
                                query_sent, dialog_list, total_example_list,
                                topic, tokenizer, encoder, topk=num_data_sample)

                key = topic + " " + turns[-1]
                prompt_list_for_each_sample.append({key: example_list})

    print("writing to %s" % output_prompt_path)
    with open(output_prompt_path, "w") as f:
        for instance in tqdm(prompt_list_for_each_sample):
            json.dump(instance, f)
            f.write("\n")


def prompt_selection_for_response_generation(input_path, output_path, seed):
    """Selecting prompts for the response generation"""

    print("> Selecting prompts for the response generation")
    print("> set random seed")
    np.random.seed(seed)

    prompt_example_list = []
    print("> reading data from %s" % input_path)
    with open(input_path, "r") as f:
        for i, line in tqdm(enumerate(f)):
            line = line.strip()
            splits = line.split("\t")

            # get the topic, context, knowledge and response
            topic = splits[0]
            dialog_context = splits[1]
            knowledge = splits[2]
            response = splits[3]
            turns = dialog_context.split(" [SEP] ")[-3:]
            if knowledge == "no_passages_used":
                continue

            # calculate the overlap ratio
            from nltk import word_tokenize
            knowledge_sent_token_list = word_tokenize(knowledge)
            knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
            knowledge_len = len(knowledge_sent_token_list)
            response_token_list = word_tokenize(response)
            response_len = len(response_token_list)
            num_overlap_token = 0
            accumulator = 0
            for token in response_token_list:
                if token in knowledge_sent_token_dict:
                    accumulator += 1
                else:
                    if accumulator >= 10:
                        num_overlap_token += accumulator
                    accumulator = 0
            if accumulator >= 10:
                num_overlap_token += accumulator

            # filtering the data based on the ratio
            if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
                continue
            if num_overlap_token < knowledge_len * 0.8:
                continue

            last_turn = " ".join(word_tokenize(turns[-1]))
            knowledge = " ".join(word_tokenize(knowledge))
            response = " ".join(word_tokenize(response))
            prompt_example = ""
            # add dialog context
            prompt_example += "Topic: " + topic + ". "
            prompt_example += "User says: " + last_turn + " "
            prompt_example += "We know that: " + knowledge + " "
            prompt_example += "System replies: " + response

            prompt_example_list.append(prompt_example)

    # shuffle the prompt examples
    np.random.shuffle(prompt_example_list)

    print("> writing to %s" % output_path)
    with open(output_path, "w") as f:
        # f.write("Generate the System's response based on the knowledge sentence:\n")
        for i in tqdm(range(20)):
            example = prompt_example_list[i]
            f.write(example + "\n")


def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
    """Preparing inputs for the response generation"""

    print("> Reading knowledge file from %s" % knwl_gen_file)
    # get the knowledge list
    with open(knwl_gen_file, "r") as f:
        knowledge_list = f.readlines()

    print("> Processing ...")
    with open(test_file, "r") as fr:
        with open(processed_file, "w") as fw:
            for line_num, line in enumerate(tqdm(fr)):
                line = line.strip()
                splits = line.split("\t")
                # prepare topic, context, knowledge and response
                topic = splits[0]
                dialog_context = splits[1]
                response = splits[3]
                knowledge = knowledge_list[line_num]
                knowledge = knowledge.strip()
                if "<|endoftext|>" in knowledge:
                    knowledge = knowledge.replace("<|endoftext|>", "")

                # write to the output file
                fw.write(topic + "\t" + dialog_context + "\t" \
                         + knowledge + "\t" + response + "\n")


if __name__ == "__main__":

    args = get_args()

    if args.func == "process_wow_dataset":
        process_wow_dataset(args.raw_file, args.processed_file,
                            args.knwl_ref_file, args.resp_ref_file)

    elif args.func == "process_woi_dataset":
        process_woi_dataset(args.raw_file, args.processed_file,
                            args.knwl_ref_file, args.resp_ref_file)

    elif args.func == "get_knwl_gen_prompts":
        prompt_selection_for_knowledge_generation(
            args.test_file, args.train_file, args.model_file,
            args.processed_file, args.data_type)

    elif args.func == "get_resp_gen_prompts":
        prompt_selection_for_response_generation(
            args.train_file, args.processed_file, args.seed)

    elif args.func == "prepare_input":
        prepare_input_for_response_generation(
            args.test_file, args.knwl_gen_file, args.processed_file)
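
For orientation, a sketch of how the Wizard of Wikipedia entry point above might be driven directly from Python; in the repository the script is normally invoked through its command-line flags (--func, --raw_file, --processed_file, ...) defined in get_args(), and the file paths below are placeholders, not paths shipped with the project:

    # Hypothetical paths, shown only to illustrate the expected call pattern.
    # process_wow_dataset() reads the raw Wizard of Wikipedia JSON and writes
    # tab-separated lines: topic \t dialogue context \t knowledge \t response.
    process_wow_dataset(
        raw_file="data/wow/train.json",
        processed_file="data/wow/train_processed.txt",
        knwl_ref_file=None,   # optional knowledge reference output
        resp_ref_file=None)   # optional response reference output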
tasks/msdp/prompt.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompting the pretrained language model to generate knowledge/response"""

import json
import torch
import requests
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.text_generation import generate_and_post_process


def call_model_api(inputs, tokens_to_generate):
    """Calling the model api to get the output generations"""

    args = get_args()

    # The following is an example of using the Megatron API
    # You can also implement your own API function to place this part
    headers = {'Content-Type': 'application/json; charset=UTF-8'}
    data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
    data_json = json.dumps(data)
    outputs = requests.put(args.megatron_api_url,
                           headers=headers, data=data_json).json()["text"][0]

    input_len = len(inputs)
    outputs = outputs[input_len:]
    outputs = outputs.split("\n")[0].strip()

    return outputs


def read_prompts(prompt_path, prompt_type, n_example):
    """Read prompt data"""

    if prompt_type == "knowledge":
        # prompts for the knowledge generation
        prompt_examples_dict = {}
        # read prompt_path
        with open(prompt_path, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]

                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + "\n"
                    prompt_examples_dict[key] = prompt

        return prompt_examples_dict

    else:
        # prompts for the response generation
        # read prompt_path
        prompt = ""
        with open(prompt_path, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:n_example]
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + "\n"

        return prompt


def generate_samples_by_calling_api():
    """ Generate outputs by calling"""
    args = get_args()
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    if args.prompt_type == "knowledge":
        # read knowledge generation prompts
        knwl_gen_prompt_dict = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)

    else:
        resp_gen_prompt = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)

    # read the test data
    fname = open(args.sample_input_file, "r")
    test_sample_list = fname.readlines()
    # create output file
    fname_out = open(args.sample_output_file, "w")

    # call the api to get the output generations
    for test_sample in test_sample_list:
        test_sample = test_sample.strip()
        splits = test_sample.split("\t")
        topic = splits[0]

        # prepare the inputs for the api
        if args.prompt_type == "knowledge":
            ## inputs = prompt + current test
            # get the prompt
            turns = splits[1].split(" [SEP] ")
            last_turn = turns[-1]
            key = topic + " " + last_turn
            inputs = knwl_gen_prompt_dict[key]

            # add current test
            inputs += "( " + last_turn + " ) " + topic + " =>"

        else:
            # inputs = prompt + current test
            # get the prompt
            inputs = resp_gen_prompt

            # add current test
            turns = splits[1].split(" [SEP] ")
            knowledge = splits[2]
            last_turn = turns[-1]
            last_turn = " ".join(word_tokenize(last_turn))
            knowledge = " ".join(word_tokenize(knowledge))
            knowledge = knowledge.strip()
            last_turn = last_turn.strip()
            inputs += "Topic: " + topic + ". "
            inputs += "User says: " + last_turn + " "
            inputs += "We know that: " + knowledge + " "
            inputs += "System replies:"

        # get the output generations from the api,
        # and write to the output file
        generations = call_model_api(inputs, args.out_seq_length)
        fname_out.write(generations)
        fname_out.write("\n")

    fname.close()
    fname_out.close()


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=True,
                     pre_process=pre_process, post_process=post_process)
    return model


def generate_samples_by_prompting_input_from_file(model):
    """Prompt a pretrained language model to generate knowledge/response"""

    # get tokenizer
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    # only two prompt types (i.e., knowledge and response) are allowed
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    # Read the prompt file
    if args.prompt_type == "knowledge":
        # read the prompts for the knowledge generation
        prompt_examples_dict = {}
        with open(args.prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]

                # get the prompt examples based on the key
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + "\n"
                    prompt_examples_dict[key] = prompt

    else:
        # read the prompts for the response generation
        # prompts are fixed for all test samples
        with open(args.prompt_file, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:args.num_prompt_examples]

            prompt = ""
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + "\n"

    input_pos = 0
    model.eval()
    # perform prompting
    with torch.no_grad():
        while True:
            raw_text_len = 0
            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                input_str = all_raw_text[input_pos]
                input_str = input_str.strip()
                splits = input_str.split("\t")
                topic = splits[0]

                if args.prompt_type == "knowledge":
                    # first add the prompt into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    last_turn = turns[-1]
                    key = topic + " " + last_turn
                    raw_text = prompt_examples_dict[key]

                    # construct inputs for knowledge generation
                    # then add the constructed inputs into the raw_text
                    raw_text += "( " + last_turn + " ) " + topic + " =>"

                else:
                    # first add the prompt into the raw_text
                    raw_text = prompt

                    # construct inputs for response generation
                    # then add the constructed inputs into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    knowledge = splits[2]
                    last_turn = turns[-1]
                    last_turn = " ".join(word_tokenize(last_turn))
                    knowledge = " ".join(word_tokenize(knowledge))
                    knowledge = knowledge.strip()
                    last_turn = last_turn.strip()

                    raw_text += "Topic: " + topic + ". "
                    raw_text += "User says: " + last_turn + " "
                    raw_text += "We know that: " + knowledge + " "
                    raw_text += "System replies:"

                input_pos += 1
                raw_text_len = len(raw_text)

            else:
                raw_text = "EMPTY TEXT"

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            outputs = generate_and_post_process(
                        model=model,
                        prompts=[raw_text],
                        tokens_to_generate=args.out_seq_length,
                        top_k_sampling=1)
            prompts_plus_generations = outputs[0]
            prompts_plus_generations = prompts_plus_generations[0]

            # write the generated output to the output file
            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():

                    generations = prompts_plus_generations[raw_text_len:]
                    generations = generations.split("\n")[0]
                    generations = generations.strip()
                    fname_out.write(generations)
                    fname_out.write("\n")

            raw_text = None
            if input_pos == input_count:
                return


def main():

    args = get_args()

    if args.api_prompt:
        # obtain the generations by calling the api
        generate_samples_by_calling_api()
        return

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider, wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # perform the prompting
    generate_samples_by_prompting_input_from_file(model)
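
The request that call_model_api() issues can be summarized as follows. This is a sketch assuming a Megatron text-generation server is already running; the URL and prompt string are placeholders (in the script the URL comes from args.megatron_api_url):

    import json
    import requests

    url = "http://localhost:5000/api"   # placeholder endpoint
    # For knowledge generation the input ends with "( <last turn> ) <topic> =>",
    # mirroring how generate_samples_by_calling_api() builds its inputs.
    prompt = "( Do you like jazz? ) Jazz =>"

    data = {"prompts": [prompt], "tokens_to_generate": 64, "top_k": 1}
    resp = requests.put(url,
                        headers={'Content-Type': 'application/json; charset=UTF-8'},
                        data=json.dumps(data))
    # Strip the echoed prompt and keep only the first generated line,
    # as call_model_api() does.
    generation = resp.json()["text"][0][len(prompt):].split("\n")[0].strip()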
tools/run_text_generation_server.py (modified, +2 -4)
@@ -27,7 +27,7 @@ from megatron.initialize import initialize_megatron
 from megatron.model import GPTModel
 from megatron.training import get_model
 from megatron.text_generation_server import MegatronServer
-from megatron.text_generation_utils import generate
+from megatron.text_generation import generate_and_post_process
 import torch

 def model_provider(pre_process=True, post_process=True):

@@ -43,8 +43,6 @@ def add_text_generate_args(parser):
     group.add_argument("--temperature", type=float, default=1.0,
                        help='Sampling temperature.')
-    group.add_argument("--greedy", action='store_true', default=False,
-                       help='Use greedy sampling.')
     group.add_argument("--top_p", type=float, default=0.0,
                        help='Top p sampling.')
     group.add_argument("--top_k", type=int, default=0,

@@ -80,4 +78,4 @@ if __name__ == "__main__":
         choice = torch.cuda.LongTensor(1)
         torch.distributed.broadcast(choice, 0)
         if choice[0].item() == 0:
-            generate(model)
+            generate_and_post_process(model)
tools/text_generation_cli.py (modified, +3 -3)
@@ -25,10 +25,10 @@ if __name__ == "__main__":
     url = sys.argv[1]
     while True:
         sentence = raw_input("Enter prompt: ")
-        max_len = int(input("Enter number tokens output: "))
-        data = json.dumps({"sentences": [sentence], "max_len": max_len})
+        tokens_to_generate = int(input("Enter number of tokens to generate: "))
+        data = json.dumps({"prompts": [sentence], "tokens_to_generate": tokens_to_generate})
         req = PutRequest(url, data, {'Content-Type': 'application/json'})
         response = urllib2.urlopen(req)
         resp_sentences = json.load(response)
         print("Megatron Response: ")
-        print(resp_sentences["sentences"][0])
+        print(resp_sentences["text"][0])
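
The CLI change mirrors the server-side rename: requests now carry "prompts" and "tokens_to_generate" instead of "sentences" and "max_len", and the reply is read from the "text" field. A rough Python 3 equivalent of the updated loop body, with a placeholder URL (the original file still targets Python 2's urllib2 and raw_input):

    import json
    import urllib.request

    url = "http://localhost:5000/api"   # placeholder; the CLI takes it from sys.argv[1]
    sentence = input("Enter prompt: ")
    tokens_to_generate = int(input("Enter number of tokens to generate: "))

    # Build the PUT request with the renamed fields.
    data = json.dumps({"prompts": [sentence],
                       "tokens_to_generate": tokens_to_generate}).encode("utf-8")
    req = urllib.request.Request(url, data=data,
                                 headers={'Content-Type': 'application/json'},
                                 method='PUT')
    with urllib.request.urlopen(req) as response:
        resp_sentences = json.load(response)
    print("Megatron Response: ")
    print(resp_sentences["text"][0])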