OpenDAS / Megatron-LM / Commits / 06fc51ce

Commit 06fc51ce, authored Feb 18, 2022 by Jared Casper

    Merge branch 'main' into checkpoint_util

Parents: ec561daa, 0ed2f6ac
Showing 6 changed files with 928 additions and 8 deletions (+928, -8)
tasks/msdp/preprocessing.py                        +595   -0
tasks/msdp/prompt.py                               +322   -0
tasks/vision/classification/classification.py        +0   -0
tasks/vision/classification/eval_utils.py             +0   -0
tasks/vision/finetune_utils.py                        +7   -7
tools/run_text_generation_server.py                   +4   -1
tasks/msdp/preprocessing.py  (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
import torch
import argparse
from nltk import word_tokenize
from tqdm import tqdm
import numpy as np
import json
def get_args():
    parser = argparse.ArgumentParser(description="Preprocessing")

    parser.add_argument("--func", type=str, default=None,
                        help="choose to run which function")
    parser.add_argument("--raw_file", type=str, default=None,
                        help="path of the input file")
    parser.add_argument("--processed_file", type=str, default=None,
                        help="path of the output file")
    parser.add_argument("--knwl_ref_file", type=str, default=None,
                        help="path of the knowledge reference file")
    parser.add_argument("--resp_ref_file", type=str, default=None,
                        help="path of the response reference file")
    parser.add_argument("--knwl_gen_file", type=str, default=None,
                        help="path of the generated knowledge file")
    parser.add_argument("--test_file", type=str, default=None,
                        help="path of the test file")
    parser.add_argument("--train_file", type=str, default=None,
                        help="path of the train file")
    parser.add_argument("--model_file", type=str, default=None,
                        help="path of the model file")
    parser.add_argument("--data_type", type=str, default=None,
                        help="data types, choose one out of three types: \
                              wow_seen, wow_unseen, and woi")
    parser.add_argument("--seed", type=int, default=1234,
                        help="random seed")

    args = parser.parse_args()
    return args
def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """
    This is a function used for processing the wizard of wikipedia (wow) dataset
    Expected processed format:
    topic \t dialogue context \t golden knowledge \t golden response
    """

    # loading the raw data
    print("> Loading data from %s" % raw_file)
    with open(raw_file, "r") as fr:
        dialog_data = json.load(fr)

    print("> Processing data ...")
    fproc = open(processed_file, "w")
    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
    fresp = open(resp_ref_file, "w") if resp_ref_file else None

    for i, sample in enumerate(tqdm(dialog_data)):
        # get all the dialog data for a single dialog sample
        dialog = sample["dialog"]

        turn_list = []  # collect the dialog history
        # processing for each single dialog sample
        for j, turn in enumerate(dialog):
            # text of each turn
            text = turn["text"]
            if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
                text = text + "."

            if j == 0:
                # first turn
                turn_list.append(text)
                continue

            speaker = turn["speaker"].lower()
            if "wizard" in speaker:
                checked_sentence = list(turn["checked_sentence"].values())  # knowledge
                checked_passage = list(turn["checked_passage"].values())    # topic

                assert len(checked_sentence) <= 1

                # get the ground truth knowledge
                if len(checked_sentence) > 0:
                    checked_sentence = checked_sentence[0]
                else:
                    checked_sentence = "no_passages_used"

                if len(checked_passage) == 1:
                    checked_passage = checked_passage[0]
                else:
                    checked_passage = "no_passages_used"

                # get the topic
                if checked_passage != "no_passages_used":
                    topic = checked_passage
                else:
                    topic = sample["chosen_topic"]

                dialog_context = " [SEP] ".join(turn_list)
                knowledge = checked_sentence
                response = text
                # add the response into the dialog history
                turn_list.append(response)

                # write to the output files
                fproc.write(topic + "\t" + dialog_context + "\t" +
                            knowledge + "\t" + response + "\n")

                if fknwl:
                    fknwl.write(knowledge + "\n")
                if fresp:
                    # tokenize for evaluation
                    response = " ".join(word_tokenize(response))
                    fresp.write(response + "\n")

            else:
                assert "apprentice" in speaker
                turn_list.append(text)

    fproc.close()
    if fknwl:
        fknwl.close()
    if fresp:
        fresp.close()
def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
    """
    This is a function used for processing the wizard of internet (woi) dataset
    Expected processed format:
    topic \t dialogue context \t golden knowledge \t golden response
    """

    print("> Processing %s" % raw_file)
    fproc = open(processed_file, "w")
    fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
    fresp = open(resp_ref_file, "w") if resp_ref_file else None

    with open(raw_file, "r") as fr:
        for i, line in tqdm(enumerate(fr)):
            # read line by line, each line uses json format
            line = line.strip()
            item_dict = json.loads(line)

            # item_dict is a dictionary
            # its key is the data id, and its value contains all the data content
            item_dict = item_dict.values()
            item_dict = list(item_dict)[0]  # len(item_dict) == 1

            # get the whole dialog data for a single dialog sample
            dialog_data = item_dict['dialog_history']
            length = len(dialog_data)

            turn_list = []  # collect the dialog history
            search_text = ""
            for i in range(length):
                item = dialog_data[i]
                action = item['action']

                if action == "Wizard => SearchAgent":
                    search_text = item['text']

                elif action == "Wizard => Apprentice":
                    if len(turn_list) == 0:
                        # first turn
                        turn = item['text']
                        turn_list.append(turn)
                        continue

                    # get the relevant content
                    contents = item["context"]["contents"]
                    selects = item["context"]["selected_contents"]
                    flag = selects[0][0]
                    selects = selects[1:]
                    assert len(selects) == len(contents)

                    # get the topic
                    if flag:
                        # no knowledge sentence is used for the response
                        topic = "no_topic"
                        knwl_sent = "no_passages_used"
                    else:
                        # we consider the search text as the topic
                        topic = search_text
                        # get the knowledge sentence
                        knwl_sent = ""
                        for content, select in zip(contents, selects):
                            content = content['content']
                            assert len(content) == len(select)
                            for c, s in zip(content, select):
                                if s:
                                    knwl_sent = c
                                    break

                    if knwl_sent == "":
                        # no knowledge is used for the response
                        topic = "no_topic"
                        knwl_sent = "no_passages_used"

                    # get dialogue context, knowledge, and response
                    dialog_context = " [SEP] ".join(turn_list)
                    response = item['text']

                    # processing
                    topic = topic.replace("\n", "").replace("\r", "").replace("\t", "")
                    dialog_context = dialog_context.replace("\n", "").replace("\r", "").replace("\t", "")
                    knwl_sent = knwl_sent.replace("\n", "").replace("\r", "").replace("\t", "")
                    response = response.replace("\n", "").replace("\r", "").replace("\t", "")

                    if topic != "no_topic":
                        # write to the output files
                        fproc.write(topic + "\t" + dialog_context + "\t" +
                                    knwl_sent + "\t" + response + "\n")

                        if fknwl:
                            fknwl.write(knwl_sent + "\n")
                        if fresp:
                            # tokenize for evaluation
                            response = " ".join(word_tokenize(response))
                            fresp.write(response + "\n")

                    turn_list.append(response)

                elif action == "Apprentice => Wizard":
                    turn = item['text']
                    turn_list.append(turn)

                else:
                    assert action == "SearchAgent => Wizard", \
                        "Please check whether you have used the correct data!"

    fproc.close()
    if fknwl:
        fknwl.close()
    if fresp:
        fresp.close()
def get_database(test_datapath, train_datapath, data_type):
    """Get the database by topics"""

    assert data_type in ["wow_seen", "wow_unseen", "woi"], \
        "Please input a correct data type!!"

    # get test data topic dictionary
    print("> reading test data from %s" % test_datapath)
    test_topics = {}
    with open(test_datapath, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            topic = splits[0]
            test_topics[topic] = True

    print("> reading data from %s" % train_datapath)
    train_data_by_topic = {}
    dialog_data_by_topic = {}
    dialog_examples = []
    with open(train_datapath, "r") as f:
        for i, line in enumerate(f):
            line = line.strip()
            splits = line.split("\t")
            topic = splits[0]
            turns = splits[1].split(" [SEP] ")[-3:]
            knowledge = splits[2]
            response = splits[3]

            # filtering data samples
            if knowledge == "no_passages_used":
                # when no knowledge is used
                continue
            if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge):
                # when bracket exists in the knowledge
                continue
            if data_type != "wow_seen" and topic not in knowledge:
                # when topic does not exist in the knowledge
                continue

            # get the instance
            last_turn = turns[-1]
            instance = "( " + last_turn + " ) " + topic + " => " + knowledge

            # construct dialog example
            dialog_example = ""
            if data_type != "wow_seen":
                dialog_example += "( " + topic + " ) "
            for i, turn in enumerate(turns):
                if i != 0:
                    dialog_example += " "
                dialog_example += turn

            # check overlaps
            if topic in test_topics:
                if topic not in train_data_by_topic:
                    train_data_by_topic[topic] = [instance]
                else:
                    train_data_by_topic[topic].append(instance)

                if topic not in dialog_data_by_topic:
                    dialog_data_by_topic[topic] = [dialog_example]
                else:
                    dialog_data_by_topic[topic].append(dialog_example)

            else:
                # filtering data samples
                if len(knowledge.split()) > 20:
                    # knowledge is too long
                    continue
                if knowledge.startswith("It") or knowledge.startswith("it") or \
                   knowledge.startswith("This") or knowledge.startswith("this"):
                    continue

            # append all the data into dialogue examples list
            dialog_examples.append((topic, dialog_example, instance))

    return train_data_by_topic, dialog_data_by_topic, dialog_examples
emb_dict = {}

def select_prompts_based_on_similarity(
        query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
    """Select samples based on the similarity"""

    with torch.no_grad():
        # get the query embeddings
        query_ids = tokenizer.encode(query)
        query_ids = torch.LongTensor([query_ids]).cuda()
        query_emb = encoder(input_ids=query_ids).pooler_output
        query_emb = query_emb[0]

        # calculate embeddings for the samples in the database
        if topic in emb_dict:
            example_embeddings = emb_dict[topic]
            example_embeddings = example_embeddings.cuda()
        else:
            for idx, example in enumerate(dialog_list):
                example_ids = tokenizer.encode(example)
                example_ids = torch.LongTensor([example_ids]).cuda()
                example_emb = encoder(input_ids=example_ids).pooler_output
                if idx == 0:
                    example_embeddings = example_emb
                else:
                    example_embeddings = torch.cat(
                        (example_embeddings, example_emb), dim=0)
            emb_dict[topic] = example_embeddings.cpu()

        # compare the similarity and select the topk samples
        similarity_list = example_embeddings.matmul(query_emb)
        _, indices = torch.topk(similarity_list, k=topk)

        indices = indices.tolist()
        indices = indices[::-1]  # reverse the order
        selected_prompts = []
        for index in indices:
            # index = index.item()
            selected_prompts.append(prompt_list[index])

    return selected_prompts
def prompt_selection_for_knowledge_generation(
        test_datapath, train_datapath, model_path, output_prompt_path, data_type):
    """Selecting prompts for the knowledge generation"""

    print("> Selecting prompts for the knowledge generation")

    train_data_by_topic, dialog_data_by_topic, dialog_examples = \
        get_database(test_datapath, train_datapath, data_type)

    from transformers import DPRQuestionEncoderTokenizer
    print("> loading tokenizer and encoder")
    tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
        'facebook/dpr-question_encoder-single-nq-base')
    encoder = torch.load(model_path).cuda()

    print("> getting dialog embeddings")
    with torch.no_grad():
        for idx, example in tqdm(enumerate(dialog_examples)):
            dialog = example[1]
            dialog_ids = tokenizer.encode(dialog)
            dialog_ids = torch.LongTensor([dialog_ids]).cuda()
            dialog_emb = encoder(input_ids=dialog_ids).pooler_output

            if idx == 0:
                dialog_embeddings = dialog_emb
            else:
                dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)

    print("> reading test data from %s" % test_datapath)
    prompt_list_for_each_sample = []
    with open(test_datapath, "r") as f:
        for i, line in tqdm(enumerate(f)):
            line = line.strip()

            splits = line.split("\t")
            topic = splits[0]
            turns = splits[1].split(" [SEP] ")[-3:]

            # get the query sentence
            query_sent = ""
            if data_type != "seen":
                query_sent += "( " + topic + " ) "
            for i, turn in enumerate(turns):
                if i != 0:
                    query_sent += " "
                query_sent += turn

            if topic not in train_data_by_topic:
                # get the query embedding
                query_ids = tokenizer.encode(query_sent)
                query_ids = torch.LongTensor([query_ids]).cuda()
                query_emb = encoder(input_ids=query_ids).pooler_output
                query_emb = query_emb[0]

                # calculate the similarity
                similarity_list = dialog_embeddings.matmul(query_emb)
                _, indices = torch.sort(similarity_list)
                indices = indices.tolist()

                selected_topics = {}
                selected_prompts = []
                num_prompt = 0
                for index in indices:
                    example = dialog_examples[index]
                    topic_temp = example[0]
                    if topic_temp not in selected_topics:
                        selected_topics[topic_temp] = True
                        selected_prompts.append(example[2])
                        num_prompt += 1
                        if num_prompt == 10:
                            break

                # get the selected samples
                example_list = selected_prompts[::-1]
                key = topic + " " + turns[-1]
                prompt_list_for_each_sample.append({key: example_list})

            else:
                num_data_sample = min(len(train_data_by_topic[topic]), 10)
                total_example_list = train_data_by_topic[topic]
                dialog_list = dialog_data_by_topic[topic]
                assert len(dialog_list) == len(train_data_by_topic[topic])

                # calculate the similarity
                example_list = select_prompts_based_on_similarity(
                    query_sent, dialog_list, total_example_list,
                    topic, tokenizer, encoder, topk=num_data_sample)

                key = topic + " " + turns[-1]
                prompt_list_for_each_sample.append({key: example_list})

    print("writing to %s" % output_prompt_path)
    with open(output_prompt_path, "w") as f:
        for instance in tqdm(prompt_list_for_each_sample):
            json.dump(instance, f)
            f.write("\n")
def prompt_selection_for_response_generation(input_path, output_path, seed):
    """Selecting prompts for the response generation"""

    print("> Selecting prompts for the response generation")
    print("> set random seed")
    np.random.seed(seed)

    prompt_example_list = []
    print("> reading data from %s" % input_path)
    with open(input_path, "r") as f:
        for i, line in tqdm(enumerate(f)):
            line = line.strip()
            splits = line.split("\t")

            # get the topic, context, knowledge and response
            topic = splits[0]
            dialog_context = splits[1]
            knowledge = splits[2]
            response = splits[3]
            turns = dialog_context.split(" [SEP] ")[-3:]
            if knowledge == "no_passages_used":
                continue

            # calculate the overlap ratio
            from nltk import word_tokenize
            knowledge_sent_token_list = word_tokenize(knowledge)
            knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
            knowledge_len = len(knowledge_sent_token_list)
            response_token_list = word_tokenize(response)
            response_len = len(response_token_list)
            num_overlap_token = 0
            accumulator = 0
            for token in response_token_list:
                if token in knowledge_sent_token_dict:
                    accumulator += 1
                else:
                    if accumulator >= 10:
                        num_overlap_token += accumulator
                    accumulator = 0
            if accumulator >= 10:
                num_overlap_token += accumulator

            # filtering the data based on the ratio
            if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
                continue
            if num_overlap_token < knowledge_len * 0.8:
                continue

            last_turn = " ".join(word_tokenize(turns[-1]))
            knowledge = " ".join(word_tokenize(knowledge))
            response = " ".join(word_tokenize(response))
            prompt_example = ""
            # add dialog context
            prompt_example += "Topic: " + topic + ". "
            prompt_example += "User says: " + last_turn + " "
            prompt_example += "We know that: " + knowledge + " "
            prompt_example += "System replies: " + response

            prompt_example_list.append(prompt_example)

    # shuffle the prompt examples
    np.random.shuffle(prompt_example_list)

    print("> writing to %s" % output_path)
    with open(output_path, "w") as f:
        # f.write("Generate the System's response based on the knowledge sentence:\n")
        for i in tqdm(range(20)):
            example = prompt_example_list[i]
            f.write(example + "\n")
def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
    """Preparing inputs for the response generation"""

    print("> Reading knowledge file from %s" % knwl_gen_file)
    # get the knowledge list
    with open(knwl_gen_file, "r") as f:
        knowledge_list = f.readlines()

    print("> Processing ...")
    with open(test_file, "r") as fr:
        with open(processed_file, "w") as fw:
            for line_num, line in enumerate(tqdm(fr)):
                line = line.strip()
                splits = line.split("\t")
                # prepare topic, context, knowledge and response
                topic = splits[0]
                dialog_context = splits[1]
                response = splits[3]
                knowledge = knowledge_list[line_num]
                knowledge = knowledge.strip()
                if "<|endoftext|>" in knowledge:
                    knowledge = knowledge.replace("<|endoftext|>", "")

                # write to the output file
                fw.write(topic + "\t" + dialog_context + "\t" +
                         knowledge + "\t" + response + "\n")
if __name__ == "__main__":
    args = get_args()
    if args.func == "process_wow_dataset":
        process_wow_dataset(
            args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)

    elif args.func == "process_woi_dataset":
        process_woi_dataset(
            args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)

    elif args.func == "get_knwl_gen_prompts":
        prompt_selection_for_knowledge_generation(
            args.test_file, args.train_file, args.model_file,
            args.processed_file, args.data_type)

    elif args.func == "get_resp_gen_prompts":
        prompt_selection_for_response_generation(
            args.train_file, args.processed_file, args.seed)

    elif args.func == "prepare_input":
        prepare_input_for_response_generation(
            args.test_file, args.knwl_gen_file, args.processed_file)
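The two process_* functions above write plain-text files with one example per line and four tab-separated fields: topic, dialogue context (turns joined by " [SEP] "), golden knowledge, and golden response, as described in their docstrings. A minimal sketch of reading such a file back; the file path here is hypothetical:

# Read one processed example per line; fields follow the format documented above.
with open("wow_test_processed.txt", "r") as f:
    for line in f:
        topic, dialog_context, knowledge, response = line.rstrip("\n").split("\t")
        turns = dialog_context.split(" [SEP] ")  # individual dialogue turns
        print(topic, turns[-1], knowledge, response)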
tasks/msdp/prompt.py  (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompting the pretrained language model to generate knowledge/response"""
import json
import torch
import requests
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.text_generation import generate_and_post_process
def call_model_api(inputs, tokens_to_generate):
    """Calling the model api to get the output generations"""

    args = get_args()

    # The following is an example of using the Megatron API
    # You can also implement your own API function to place this part
    headers = {'Content-Type': 'application/json; charset=UTF-8'}
    data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
    data_json = json.dumps(data)
    outputs = requests.put(args.megatron_api_url, headers=headers,
                           data=data_json).json()["text"][0]

    input_len = len(inputs)
    outputs = outputs[input_len:]
    outputs = outputs.split("\n")[0].strip()

    return outputs
def read_prompts(prompt_path, prompt_type, n_example):
    """Read prompt data"""

    if prompt_type == "knowledge":
        # prompts for the knowledge generation
        prompt_examples_dict = {}
        # read prompt_path
        with open(prompt_path, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]

                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + "\n"
                    prompt_examples_dict[key] = prompt

        return prompt_examples_dict

    else:
        # prompts for the response generation
        # read prompt_path
        prompt = ""
        with open(prompt_path, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:n_example]
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + "\n"

        return prompt
def generate_samples_by_calling_api():
    """Generate outputs by calling the model API"""

    args = get_args()
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    if args.prompt_type == "knowledge":
        # read knowledge generation prompts
        knwl_gen_prompt_dict = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)

    else:
        resp_gen_prompt = read_prompts(
            args.prompt_file, args.prompt_type, args.num_prompt_examples)

    # read the test data
    fname = open(args.sample_input_file, "r")
    test_sample_list = fname.readlines()
    # create output file
    fname_out = open(args.sample_output_file, "w")

    # call the api to get the output generations
    for test_sample in test_sample_list:
        test_sample = test_sample.strip()
        splits = test_sample.split("\t")
        topic = splits[0]

        # prepare the inputs for the api
        if args.prompt_type == "knowledge":
            ## inputs = prompt + current test
            # get the prompt
            turns = splits[1].split(" [SEP] ")
            last_turn = turns[-1]
            key = topic + " " + last_turn
            inputs = knwl_gen_prompt_dict[key]

            # add current test
            inputs += "( " + last_turn + " ) " + topic + " =>"

        else:
            # inputs = prompt + current test
            # get the prompt
            inputs = resp_gen_prompt

            # add current test
            turns = splits[1].split(" [SEP] ")
            knowledge = splits[2]
            last_turn = turns[-1]
            last_turn = " ".join(word_tokenize(last_turn))
            knowledge = " ".join(word_tokenize(knowledge))
            knowledge = knowledge.strip()
            last_turn = last_turn.strip()
            inputs += "Topic: " + topic + ". "
            inputs += "User says: " + last_turn + " "
            inputs += "We know that: " + knowledge + " "
            inputs += "System replies:"

        # get the output generations from the api,
        # and write to the output file
        generations = call_model_api(inputs, args.out_seq_length)
        fname_out.write(generations)
        fname_out.write("\n")

    fname.close()
    fname_out.close()
def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building GPT model ...')
    model = GPTModel(num_tokentypes=0, parallel_output=True,
                     pre_process=pre_process, post_process=post_process)
    return model
def generate_samples_by_prompting_input_from_file(model):
    """Prompt a pretrained language model to generate knowledge/response"""

    # get tokenizer
    args = get_args()
    tokenizer = get_tokenizer()

    # Read the sample file and open the output file.
    assert args.sample_input_file is not None, \
        'sample input file is not provided.'
    if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
        fname = open(args.sample_input_file, "r")
        all_raw_text = fname.readlines()
        input_count = len(all_raw_text)
        if args.sample_output_file is None:
            sample_output_file = args.sample_input_file + ".out"
            print('`sample-output-file` not specified, setting '
                  'it to {}'.format(sample_output_file))
        else:
            sample_output_file = args.sample_output_file

        fname_out = open(sample_output_file, "w")

    # only two prompt types (i.e., knowledge and response) are allowed
    assert args.prompt_type in ["knowledge", "response"], \
        "Please input a correct prompt type!"

    # Read the prompt file
    if args.prompt_type == "knowledge":
        # read the prompts for the knowledge generation
        prompt_examples_dict = {}
        with open(args.prompt_file, "r") as f:
            for i, line in enumerate(f):
                line = line.strip()
                line_dict = json.loads(line)
                key = list(line_dict.keys())[0]

                # get the prompt examples based on the key
                if key not in prompt_examples_dict:
                    prompt_examples = line_dict[key]
                    prompt = ""
                    for instance in prompt_examples:
                        instance = instance.strip()
                        prompt += instance + "\n"
                    prompt_examples_dict[key] = prompt

    else:
        # read the prompts for the response generation
        # prompts are fixed for all test samples
        with open(args.prompt_file, "r") as f:
            prompt_examples = f.readlines()
            prompt_examples = prompt_examples[:args.num_prompt_examples]

            prompt = ""
            for instance in prompt_examples:
                instance = instance.strip()
                prompt += instance + "\n"

    input_pos = 0
    model.eval()
    # perform prompting
    with torch.no_grad():
        while True:
            raw_text_len = 0
            if mpu.is_pipeline_first_stage() \
               and mpu.get_tensor_model_parallel_rank() == 0:
                input_str = all_raw_text[input_pos]
                input_str = input_str.strip()
                splits = input_str.split("\t")
                topic = splits[0]

                if args.prompt_type == "knowledge":
                    # first add the prompt into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    last_turn = turns[-1]
                    key = topic + " " + last_turn
                    raw_text = prompt_examples_dict[key]

                    # construct inputs for knowledge generation
                    # then add the constructed inputs into the raw_text
                    raw_text += "( " + last_turn + " ) " + topic + " =>"

                else:
                    # first add the prompt into the raw_text
                    raw_text = prompt

                    # construct inputs for response generation
                    # then add the constructed inputs into the raw_text
                    turns = splits[1].split(" [SEP] ")
                    knowledge = splits[2]
                    last_turn = turns[-1]
                    last_turn = " ".join(word_tokenize(last_turn))
                    knowledge = " ".join(word_tokenize(knowledge))
                    knowledge = knowledge.strip()
                    last_turn = last_turn.strip()
                    raw_text += "Topic: " + topic + ". "
                    raw_text += "User says: " + last_turn + " "
                    raw_text += "We know that: " + knowledge + " "
                    raw_text += "System replies:"

                input_pos += 1
                raw_text_len = len(raw_text)

            else:
                raw_text = "EMPTY TEXT"

            if input_pos % 100 == 0:
                print_rank_0("input_pos: %d" % input_pos)

            outputs = generate_and_post_process(
                model=model,
                prompts=[raw_text],
                tokens_to_generate=args.out_seq_length,
                top_k_sampling=1)
            prompts_plus_generations = outputs[0]
            prompts_plus_generations = prompts_plus_generations[0]

            # write the generated output to the output file
            if mpu.get_tensor_model_parallel_rank() == 0:
                if mpu.is_pipeline_first_stage():
                    generations = prompts_plus_generations[raw_text_len:]
                    generations = generations.split("\n")[0]
                    generations = generations.strip()
                    fname_out.write(generations)
                    fname_out.write("\n")

            raw_text = None
            if input_pos == input_count:
                return
def main():
    args = get_args()

    if args.api_prompt:
        # obtain the generations by calling the api
        generate_samples_by_calling_api()
        return

    if args.num_layers_per_virtual_pipeline_stage is not None:
        print("Interleaved pipeline schedule is not yet supported for text generation.")
        exit()

    # Set up model and load checkpoint.
    model = get_model(model_provider, wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)

    assert len(model) == 1, "Above condition should have caught this"
    model = model[0]

    # perform the prompting
    generate_samples_by_prompting_input_from_file(model)
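For the api_prompt path above, call_model_api issues an HTTP PUT whose JSON body carries the prompt, the number of tokens to generate, and top_k, and it reads the generation from the "text" field of the response. A minimal sketch of the same request issued directly; the URL and prompt are hypothetical and assume a generation server (see tools/run_text_generation_server.py below) is already running:

import json
import requests

# Hypothetical endpoint; in prompt.py the URL comes from args.megatron_api_url.
url = "http://localhost:5000/api"
payload = {"prompts": ["( Hello there! ) greetings =>"],
           "tokens_to_generate": 64,
           "top_k": 1}
resp = requests.put(url,
                    headers={"Content-Type": "application/json; charset=UTF-8"},
                    data=json.dumps(payload))
print(resp.json()["text"][0])  # prompt plus generated continuation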
tasks/vision/classification.py → tasks/vision/classification/classification.py  (file moved)

tasks/vision/eval_utils.py → tasks/vision/classification/eval_utils.py  (file moved)
tasks/vision/finetune_utils.py

@@ -135,7 +135,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
 def _train(
     model,
     optimizer,
-    lr_scheduler,
+    opt_param_scheduler,
     forward_step,
     train_dataloader,
     valid_dataloader,
@@ -179,7 +179,7 @@ def _train(
             # Train for one step.
             losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
-                forward_step, batch, model, optimizer, lr_scheduler
+                forward_step, batch, model, optimizer, opt_param_scheduler
             )
             iteration += 1
@@ -206,7 +206,7 @@ def _train(
                 iteration % args.adlr_autoresume_interval == 0
             ):
                 check_adlr_autoresume_termination(
-                    iteration, model, optimizer, lr_scheduler
+                    iteration, model, optimizer, opt_param_scheduler
                 )

             # Checkpointing
@@ -215,7 +215,7 @@ def _train(
                 and args.save_interval
                 and iteration % args.save_interval == 0
             ):
-                save_checkpoint(iteration, model, optimizer, lr_scheduler)
+                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
@@ -231,7 +231,7 @@ def _train(
         # Checkpointing at the end of each epoch.
         if args.save:
-            save_checkpoint(iteration, model, optimizer, lr_scheduler)
+            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)

         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:
@@ -266,7 +266,7 @@ def finetune(
     # Build model, optimizer and learning rate scheduler.
     timers("model and optimizer").start()
-    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
     timers("model and optimizer").stop()

     # If pretrained checkpoint is provided and we have not trained for
@@ -300,7 +300,7 @@ def finetune(
         _train(
             model,
             optimizer,
-            lr_scheduler,
+            opt_param_scheduler,
             forward_step,
             train_dataloader,
             valid_dataloader,
tools/run_text_generation_server.py

@@ -78,4 +78,7 @@ if __name__ == "__main__":
         choice = torch.cuda.LongTensor(1)
         torch.distributed.broadcast(choice, 0)
         if choice[0].item() == 0:
-            generate_and_post_process(model)
+            try:
+                generate_and_post_process(model)
+            except ValueError as ve:
+                pass