OpenDAS / Megatron-LM / Commits / 91a80bd1

Commit 91a80bd1 authored Dec 04, 2021 by root

update commands

parent 49907915
Showing 6 changed files with 136 additions and 79 deletions (+136 -79)
tasks/knwl_dialo/metrics.py                  +0  -11
tasks/knwl_dialo/preprocessing.py            +96 -63
tasks/knwl_dialo/prompt.py                   +2  -1
tasks/knwl_dialo/scripts/commands.sh         +31 -0
tasks/knwl_dialo/scripts/data_processing.sh  +4  -4
tasks/knwl_dialo/utils.py                    +3  -0
tasks/knwl_dialo/metrics.py
...
@@ -10,14 +10,10 @@ from collections import Counter
 from typing import List
 import numpy as np
 import re
-from nltk.corpus import stopwords

 re_art = re.compile(r'\b(a|an|the)\b')
 re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
-stopword_list = stopwords.words('english')
-stopword_list = stopword_list + ["n's", "'s"]
-stopword_dict = {token: True for token in stopword_list}

 def normalize_answer(s):
     """
...
@@ -29,13 +25,6 @@ def normalize_answer(s):
     s = ' '.join(s.split())
     return s

-def remove_stopwords(token_list):
-    new_list = []
-    for token in token_list:
-        if token in stopword_dict:
-            continue
-        new_list.append(token)
-    return new_list

 class F1Metric:
     """
...
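Note: the body of F1Metric is collapsed in this view; this commit only deletes the NLTK stopword filtering, so every token now counts toward the metric. For context, token-level F1 in this style of evaluation is typically computed from bag-of-token overlap. A minimal sketch, not the file's actual code (the function name and structure are mine):

# Hypothetical sketch of token-level F1; the real F1Metric body is collapsed above.
from collections import Counter

def token_f1(pred_tokens, gold_tokens):
    common = Counter(pred_tokens) & Counter(gold_tokens)  # multiset intersection
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)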
tasks/knwl_dialo/preprocessing.py
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
import
torch
import
argparse
import
argparse
from
nltk
import
word_tokenize
from
nltk
import
word_tokenize
from
tqdm
import
tqdm
from
tqdm
import
tqdm
import
numpy
as
np
import
numpy
as
np
import
json
import
json
-def get_params():
+def get_args():
     parser = argparse.ArgumentParser(description="Preprocessing")

-    parser.add_argument("--func", type=str, default="",
+    parser.add_argument("--func", type=str, default=None,
                         help="choose to run which function")
-    parser.add_argument("--input_file", type=str, default="",
+    parser.add_argument("--input_file", type=str, default=None,
                         help="path of the input file")
-    parser.add_argument("--knowledge_file", type=str, default="",
+    parser.add_argument("--knowledge_file", type=str, default=None,
                         help="path of the knowledge file")
-    parser.add_argument("--test_file", type=str, default="",
+    parser.add_argument("--test_file", type=str, default=None,
                         help="path of the test file")
-    parser.add_argument("--train_file", type=str, default="",
+    parser.add_argument("--train_file", type=str, default=None,
                         help="path of the train file")
-    parser.add_argument("--output_file", type=str, default="",
+    parser.add_argument("--output_file", type=str, default=None,
                         help="path of the output file")
-    parser.add_argument("--model_file", type=str, default="",
+    parser.add_argument("--model_file", type=str, default=None,
                         help="path of the model file")
-    parser.add_argument("--seed", type=int, default=123456,
+    parser.add_argument("--data_type", type=str, default=None,
+                        help="data types (wow_seen, wow_unseen, or woi)")
+    parser.add_argument("--seed", type=int, default=1234,
                         help="random seed")

-    params = parser.parse_args()
-    return params
+    args = parser.parse_args()
+    return args


 def process_wow_dataset(input_file, output_file):
...
@@ -38,9 +41,11 @@ def process_wow_dataset(input_file, output_file):
     topic \t dialogue context \t golden knowledge \t golden response
     """
+    print("> Loading data from %s" % input_file)
     with open(input_file, "r") as fr:
         dialog_data = json.load(fr)

+    print("> Processing data ...")
     with open(output_file, "w") as fw:
         for i, sample in enumerate(tqdm(dialog_data)):
             # get all the dialog data for a single sample
...
@@ -50,8 +55,7 @@ def process_wow_dataset(input_file, output_file):
             for j, turn in enumerate(dialog):
                 text = turn["text"]
                 if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
-                    text = text + " ."
-                text = " ".join(word_tokenize(text))
+                    text = text + "."
                 if j == 0:
                     # first turn
@@ -99,8 +103,9 @@ def process_woi_dataset(input_file, output_file):
     topic \t dialogue context \t golden knowledge \t golden response
     """
-    with open(output_path, "w") as fw:
-        with open(input_path, "r") as fr:
+    print("> Processing %s" % input_file)
+    with open(output_file, "w") as fw:
+        with open(input_file, "r") as fr:
             for i, line in tqdm(enumerate(fr)):
                 line = line.strip()
                 item_dict = json.loads(line)
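Note: both processing functions emit the same four-field, tab-separated row format described in their docstrings. As a reading aid, a sketch of one such row (field values are invented):

# Illustrative only: one output row, tab-separated as in the docstrings above.
fields = [
    "Blue cheese",                                      # topic
    "hi ! [SEP] i love blue cheese .",                  # dialogue context, turns joined by [SEP]
    "Blue cheese is made with Penicillium cultures .",  # golden knowledge
    "me too , it has such a strong flavor .",           # golden response
]
row = "\t".join(fields) + "\n"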
...
@@ -183,8 +188,8 @@ def process_woi_dataset(input_file, output_file):
             assert action == "SearchAgent => Wizard"

-def get_database(test_datapath, train_datapath):
-    """Get the database sorted by topics"""
+def get_database(test_datapath, train_datapath, data_type):
+    """Get the database by topics"""
     # get test data topic list
     print("> reading test data from %s" % test_datapath)
...
@@ -208,17 +213,27 @@ def get_database(test_datapath, train_datapath):
             turns = splits[1].split(" [SEP] ")[-3:]
             knowledge = splits[2]
             response = splits[3]

+            # filtering data samples
             if knowledge == "no_passages_used":
                 continue
+            if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge):
+                continue
+            if data_type != "wow_seen" and topic not in knowledge:
+                continue

             # get the instance
             last_turn = turns[-1]
-            instance = "( " + last_turn + " ) " + topic + " => " + knowledge
+            if data_type == "woi":
+                instance = "( " + last_turn + " ) " + topic + " -> " + knowledge
+            else:
+                instance = "( " + last_turn + " ) " + topic + " => " + knowledge

             # construct dialog example
             dialog_example = ""
-            dialog_example += "( " + topic + " )"
-            for turn in turns:
-                dialog_example += " "
+            if data_type != "wow_seen":
+                dialog_example += "( " + topic + " ) "
+            for i, turn in enumerate(turns):
+                if i != 0:
+                    dialog_example += " "
                 dialog_example += turn
...
@@ -234,6 +249,15 @@ def get_database(test_datapath, train_datapath):
                 else:
                     dialog_data_by_topic[topic].append(dialog_example)
+            else:
+                # filtering data samples
+                if len(knowledge.split()) > 20:
+                    # knowledge is too long
+                    continue
+                if knowledge.startswith("It") or knowledge.startswith("it") or \
+                   knowledge.startswith("This") or knowledge.startswith("this"):
+                    continue

             # append all the data into dialogue examples list
             dialog_examples.append((topic, dialog_example, instance))
...
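Note: the instance string now depends on data_type, with "->" as the knowledge separator for Wizard of Internet and "=>" otherwise. A sketch of the strings this produces (values invented, names mirror the diff):

# Illustrative only: the data_type-dependent instance string from get_database.
topic = "Blue cheese"
last_turn = "i love blue cheese ."
knowledge = "Blue cheese is made with Penicillium cultures ."

for data_type in ("wow_seen", "wow_unseen", "woi"):
    sep = " -> " if data_type == "woi" else " => "
    instance = "( " + last_turn + " ) " + topic + sep + knowledge
    print(f"{data_type}: {instance}")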
...
@@ -283,13 +307,13 @@ def select_prompts_based_on_similarity(

 def prompt_selection_for_knowledge_generation(
-        test_datapath, train_datapath, model_path, output_prompt_path):
+        test_datapath, train_datapath, model_path, output_prompt_path, data_type):
     """Selecting prompts for the knowledge generation"""

     print("> Selecting prompts for the knowledge generation")
     train_data_by_topic, dialog_data_by_topic, dialog_examples = \
-        get_database(test_datapath, train_datapath)
+        get_database(test_datapath, train_datapath, data_type)

     from transformers import DPRQuestionEncoderTokenizer
     print("> loading tokenizer and encoder")
...
@@ -311,7 +335,6 @@ def prompt_selection_for_knowledge_generation(
             dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)

     print("> reading test data from %s" % test_datapath)
-    count_out_of_list = 0
     prompt_list_for_each_sample = []
     with open(test_datapath, "r") as f:
         for i, line in tqdm(enumerate(f)):
...
@@ -321,16 +344,17 @@ def prompt_selection_for_knowledge_generation(
             topic = splits[0]
             turns = splits[1].split(" [SEP] ")[-3:]

-            if topic not in train_data_by_topic:
-                count_out_of_list += 1
-                # calculate similarity
-                query_sent = ""
-                query_sent += "( " + topic + " )"
-                for turn in turns:
-                    query_sent += " "
-                    query_sent += turn
+            # get the query sentence
+            query_sent = ""
+            if data_type != "seen":
+                query_sent += "( " + topic + " ) "
+            for i, turn in enumerate(turns):
+                if i != 0:
+                    query_sent += " "
+                query_sent += turn

+            if topic not in train_data_by_topic:
                 # get the query embedding
                 query_ids = tokenizer.encode(query_sent)
                 query_ids = torch.LongTensor([query_ids]).cuda()
                 query_emb = encoder(input_ids=query_ids).pooler_output
...
@@ -361,21 +385,14 @@ def prompt_selection_for_knowledge_generation(
             else:
                 num_data_sample = min(len(train_data_by_topic[topic]), 10)
                 total_example_list = train_data_by_topic[topic]
-                # query_sent
-                query_sent = ""
-                query_sent += "( " + topic + " )"
-                for turn in turns:
-                    query_sent += " "
-                    query_sent += turn
                 dialog_list = dialog_data_by_topic[topic]
-                assert len(dialog_list) == num_data_sample
+                assert len(dialog_list) == len(train_data_by_topic[topic])
                 # calculate the similarity
-                selected_examples = select_prompts_based_on_similarity(
+                example_list = select_prompts_based_on_similarity(
                     query_sent, dialog_list, total_example_list,
                     topic, tokenizer, encoder, topk=num_data_sample)
-                example_list = selected_examples

             key = topic + " " + turns[-1]
             prompt_list_for_each_sample.append({key: example_list})
...
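Note: for readers unfamiliar with the retrieval step, this function embeds a query and candidate prompts with a DPR question encoder (.pooler_output) and ranks candidates by dot product. A self-contained sketch of that pattern, assuming the public Hugging Face checkpoint rather than the fine-tuned best_question_encoder.pt the scripts load, and running on CPU:

# Sketch only: DPR-based prompt ranking; model name is an assumption.
import torch
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer

name = "facebook/dpr-question_encoder-single-nq-base"  # assumed public checkpoint
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(name)
encoder = DPRQuestionEncoder.from_pretrained(name).eval()

def embed(text):
    ids = torch.LongTensor([tokenizer.encode(text)])
    with torch.no_grad():
        return encoder(input_ids=ids).pooler_output  # shape (1, hidden_size)

query_emb = embed("( Blue cheese ) i love blue cheese .")
candidates = ["( ... ) Blue cheese => ...", "( ... ) Cheddar => ..."]
cand_embs = torch.cat([embed(c) for c in candidates], dim=0)
scores = query_emb @ cand_embs.t()             # dot-product similarity, (1, n)
best = torch.topk(scores, k=1, dim=1).indices  # index of the closest candidate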
@@ -414,32 +431,43 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):
             from nltk import word_tokenize
             knowledge_sent_token_list = word_tokenize(knowledge)
             knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
-            response_token_list = response.split()
+            knowledge_len = len(knowledge_sent_token_list)
+            response_token_list = word_tokenize(response)
             response_len = len(response_token_list)
             num_overlap_token = 0
+            accumulator = 0
             for token in response_token_list:
                 if token in knowledge_sent_token_dict:
-                    num_overlap_token += 1
+                    accumulator += 1
+                else:
+                    if accumulator >= 10:
+                        num_overlap_token += accumulator
+                    accumulator = 0
+            if accumulator >= 10:
+                num_overlap_token += accumulator

             # filtering the data based on the ratio
             if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
                 continue
+            if num_overlap_token < knowledge_len * 0.8:
+                continue

+            last_turn = " ".join(word_tokenize(turns[-1]))
+            knowledge = " ".join(word_tokenize(knowledge))
+            response = " ".join(word_tokenize(response))
             prompt_example = ""
             # add dialog context
             prompt_example += "Topic: " + topic + ". "
-            prompt_example += "User says: " + turns[-1] + " "
+            prompt_example += "User says: " + last_turn + " "
             prompt_example += "We know that: " + knowledge + " "
             prompt_example += "System replies: " + response

             prompt_example_list.append(prompt_example)
-    print("> shuffle the prompt examples (total %d)" % len(prompt_example_list))
+    # shuffle the prompt examples
+    print("length: %d" % len(prompt_example_list))
     np.random.shuffle(prompt_example_list)
+    print("> Prompt example:")
+    print(prompt_example_list[0])

     print("> writing to %s" % output_path)
     with open(output_path, "w") as f:
         # f.write("Generate the System's response based on the knowledge sentence:\n")
...
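Note: the new counting rule is subtle. Instead of counting every response token that appears in the knowledge, only runs of at least 10 consecutive matching tokens contribute, and a sample survives only if that count falls between 60% and 90% of the response length and covers at least 80% of the knowledge length. The run-counting logic pulled out as a standalone function (same logic as the diff; the function and parameter names are mine):

def contiguous_overlap(response_tokens, knowledge_token_dict, min_run=10):
    """Count response tokens found in the knowledge, but only when they
    form a run of at least `min_run` consecutive matching tokens."""
    num_overlap, run = 0, 0
    for token in response_tokens:
        if token in knowledge_token_dict:
            run += 1
        else:
            if run >= min_run:
                num_overlap += run
            run = 0
    if run >= min_run:  # flush a run that reaches the end of the response
        num_overlap += run
    return num_overlap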
@@ -451,10 +479,12 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):

 def prepare_input_for_response_generation(test_file, knowledge_file, output_file):
     """Preparing inputs for the response generation"""

+    print("> Reading knowledge file from %s" % knowledge_file)
     # get the knowledge list
     with open(knowledge_file, "r") as f:
         knowledge_list = f.readlines()

+    print("> Processing ...")
     with open(test_file, "r") as fr:
         with open(output_file, "w") as fw:
             for line_num, line in enumerate(tqdm(fr)):
...
@@ -476,19 +506,22 @@ def prepare_input_for_response_generation(test_file, knowledge_file, output_file

 if __name__ == "__main__":
-    params = get_params()
+    args = get_args()

-    if params.func == "process_wow_dataset":
-        process_wow_dataset(params.input_file, params.output_file)
-    elif params.func == "process_woi_dataset":
-        process_woi_dataset(params.input_file, params.output_file)
-    elif params.func == "get_prompts":
+    if args.func == "process_wow_dataset":
+        process_wow_dataset(args.input_file, args.output_file)
+    elif args.func == "process_woi_dataset":
+        process_woi_dataset(args.input_file, args.output_file)
+    elif args.func == "get_knwl_gen_prompts":
         prompt_selection_for_knowledge_generation(
-            params.test_file, params.train_file,
-            params.model_file, params.output_file)
+            args.test_file, args.train_file,
+            args.model_file, args.output_file, args.data_type)
+    elif args.func == "get_resp_gen_prompts":
         prompt_selection_for_response_generation(
-            params.train_file, params.output_file, params.seed)
-    elif params.func == "prepare_input":
+            args.train_file, args.output_file, args.seed)
+    elif args.func == "prepare_input":
         prepare_input_for_response_generation(
-            params.test_file, params.knowledge_file, params.output_file)
+            args.test_file, args.knowledge_file, args.output_file)
tasks/knwl_dialo/prompt.py
...
@@ -120,8 +120,9 @@ def generate_samples_by_prompting_input_from_file(model):
             # args.prompt_type == "response"
             turns = splits[1].split(" [SEP] ")
             knowledge = splits[2]
-            knowledge = " ".join(word_tokenize(knowledge))
             last_turn = turns[-1]
+            last_turn = " ".join(word_tokenize(last_turn))
+            knowledge = " ".join(word_tokenize(knowledge))
             knowledge = knowledge.strip()
             last_turn = last_turn.strip()
             raw_text += "Topic: " + topic + ". "
...
tasks/knwl_dialo/scripts/commands.sh (new file, mode 100644)
# process WoW train
python tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --input_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.json --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt

# process WoW test
python tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --input_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_random_split.json --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen.txt
python tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --input_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_topic_split.json --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen.txt

# process WoI test
python tasks/knwl_dialo/preprocessing.py --func process_woi_dataset --input_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/test.jsonl --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/test.txt

# get knowledge generation prompts
# WoW seen
python tasks/knwl_dialo/preprocessing.py --func get_knwl_gen_prompts --test_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen.txt --train_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt --model_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/checkpoints/dpr_wow/best_question_encoder.pt --data_type wow_seen --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/knowledge_prompts_test_seen.json
# WoW unseen
python tasks/knwl_dialo/preprocessing.py --func get_knwl_gen_prompts --test_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen.txt --train_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt --model_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/checkpoints/dpr_wow_ctrl/best_question_encoder.pt --data_type wow_unseen --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/knowledge_prompts_test_unseen.json
# WoI
python tasks/knwl_dialo/preprocessing.py --func get_knwl_gen_prompts --test_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/test.txt --train_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt --model_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/checkpoints/dpr_wow_ctrl/best_question_encoder.pt --data_type woi --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/knowledge_prompts_test.json

# get response generation prompts --seed 147
python tasks/knwl_dialo/preprocessing.py --func get_resp_gen_prompts --train_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/response_generation_prompts_temp.txt --seed 1234

# prepare response generation inputs
# WoW seen
python tasks/knwl_dialo/preprocessing.py --func prepare_input --test_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen.txt --knowledge_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/output_testseen_knowledge_357m.txt --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen_resp_gen_input.txt
# WoW unseen
python tasks/knwl_dialo/preprocessing.py --func prepare_input --test_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen.txt --knowledge_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/output_testunseen_knowledge_357m.txt --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen_resp_gen_input.txt
tasks/knwl_dialo/scripts/data_processing.sh
...
@@ -14,9 +14,9 @@ python ${DIR}/tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --inp
 # We provide the following script to process the raw data from Wizard of Internet
 python ${DIR}/tasks/knwl_dialo/preprocessing.py --func process_woi_dataset --input_file <PATH_OF_THE_INPUT_DATA> --output_file <PATH_OF_THE_OUTPUT_DATA>

-# Obtain the knowledge generation prompts and response generation prompts
-python ${DIR}/tasks/knwl_dialo/preprocessing.py --func get_prompts --test_file <PATH_OF_THE_PROCESSED_TEST_DATA> --train_file <PATH_OF_THE_PROCESSED_TRAIN_DATA> --model_file <PATH_OF_THE_DPR_MODEL> --output_file <PATH_OF_THE_OUTPUT_FILE>
+# Obtain the knowledge generation prompts
+python ${DIR}/tasks/knwl_dialo/preprocessing.py --func get_knwl_gen_prompts --test_file <PATH_OF_THE_PROCESSED_TEST_DATA> --train_file <PATH_OF_THE_PROCESSED_TRAIN_DATA> --model_file <PATH_OF_THE_DPR_MODEL> --output_file <PATH_OF_THE_OUTPUT_FILE> --data_type <DATA_TYPE_OF_THE_INPUT_FILE>

-# Alternatively, we recommend you to directly download the already processed file through:
-wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1vP0eGxhkbWfeJ2dUUOEAflbOZq-Jlde_' -O data.gz
+# Obtain the response generation prompts
+python ${DIR}/tasks/knwl_dialo/preprocessing.py --func get_resp_gen_prompts --train_file <PATH_OF_THE_PROCESSED_TRAIN_DATA> --output_file <PATH_OF_THE_OUTPUT_FILE>
tasks/knwl_dialo/utils.py
...
@@ -142,6 +142,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
     tokens = context_tokens
     if maxlen is None:
         maxlen = args.seq_length - 1
+        if maxlen > (org_context_length + args.out_seq_length):
+            maxlen = org_context_length + args.out_seq_length

     lengths = torch.ones([batch_size]).long().cuda() * maxlen

     # start the generation process
...
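Note: the added lines cap generation at the context length plus args.out_seq_length instead of always running to seq_length - 1, so short prompts no longer generate to the full model window. The clamp in isolation (parameter names follow the diff; the wrapper function is mine):

def effective_maxlen(maxlen, org_context_length, out_seq_length, seq_length):
    # Mirrors the clamp added to sample_sequence_batch above.
    if maxlen is None:
        maxlen = seq_length - 1
        if maxlen > org_context_length + out_seq_length:
            maxlen = org_context_length + out_seq_length
    return maxlen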