Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
91a80bd1
"git@developer.sourcefind.cn:modelzoo/bert_migraphx.git" did not exist on "35ad452d08368a21ada2199836e1110c42ae018c"
Commit
91a80bd1
authored
Dec 04, 2021
by
root
Browse files
update commands
parent
49907915
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
136 additions
and
79 deletions
+136
-79
tasks/knwl_dialo/metrics.py
tasks/knwl_dialo/metrics.py
+0
-11
tasks/knwl_dialo/preprocessing.py
tasks/knwl_dialo/preprocessing.py
+96
-63
tasks/knwl_dialo/prompt.py
tasks/knwl_dialo/prompt.py
+2
-1
tasks/knwl_dialo/scripts/commands.sh
tasks/knwl_dialo/scripts/commands.sh
+31
-0
tasks/knwl_dialo/scripts/data_processing.sh
tasks/knwl_dialo/scripts/data_processing.sh
+4
-4
tasks/knwl_dialo/utils.py
tasks/knwl_dialo/utils.py
+3
-0
No files found.
tasks/knwl_dialo/metrics.py
View file @
91a80bd1
...
...
@@ -10,14 +10,10 @@ from collections import Counter
from
typing
import
List
import
numpy
as
np
import
re
from
nltk.corpus
import
stopwords
re_art
=
re
.
compile
(
r
'\b(a|an|the)\b'
)
re_punc
=
re
.
compile
(
r
'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']'
)
stopword_list
=
stopwords
.
words
(
'english'
)
stopword_list
=
stopword_list
+
[
"n's"
,
"'s"
]
stopword_dict
=
{
token
:
True
for
token
in
stopword_list
}
def
normalize_answer
(
s
):
"""
...
...
@@ -29,13 +25,6 @@ def normalize_answer(s):
s
=
' '
.
join
(
s
.
split
())
return
s
def
remove_stopwords
(
token_list
):
new_list
=
[]
for
token
in
token_list
:
if
token
in
stopword_dict
:
continue
new_list
.
append
(
token
)
return
new_list
class
F1Metric
:
"""
...
...
tasks/knwl_dialo/preprocessing.py
View file @
91a80bd1
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
import
torch
import
argparse
from
nltk
import
word_tokenize
from
tqdm
import
tqdm
import
numpy
as
np
import
json
def
get_
p
ar
am
s
():
def
get_ar
g
s
():
parser
=
argparse
.
ArgumentParser
(
description
=
"Preprocessing"
)
parser
.
add_argument
(
"--func"
,
type
=
str
,
default
=
""
,
parser
.
add_argument
(
"--func"
,
type
=
str
,
default
=
None
,
help
=
"choose to run which function"
)
parser
.
add_argument
(
"--input_file"
,
type
=
str
,
default
=
""
,
parser
.
add_argument
(
"--input_file"
,
type
=
str
,
default
=
None
,
help
=
"path of the input file"
)
parser
.
add_argument
(
"--knowledge_file"
,
type
=
str
,
default
=
""
,
parser
.
add_argument
(
"--knowledge_file"
,
type
=
str
,
default
=
None
,
help
=
"path of the knowledge file"
)
parser
.
add_argument
(
"--test_file"
,
type
=
str
,
default
=
""
,
parser
.
add_argument
(
"--test_file"
,
type
=
str
,
default
=
None
,
help
=
"path of the test file"
)
parser
.
add_argument
(
"--train_file"
,
type
=
str
,
default
=
""
,
parser
.
add_argument
(
"--train_file"
,
type
=
str
,
default
=
None
,
help
=
"path of the train file"
)
parser
.
add_argument
(
"--output_file"
,
type
=
str
,
default
=
""
,
parser
.
add_argument
(
"--output_file"
,
type
=
str
,
default
=
None
,
help
=
"path of the output file"
)
parser
.
add_argument
(
"--model_file"
,
type
=
str
,
default
=
""
,
parser
.
add_argument
(
"--model_file"
,
type
=
str
,
default
=
None
,
help
=
"path of the model file"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
123456
,
parser
.
add_argument
(
"--data_type"
,
type
=
str
,
default
=
None
,
help
=
"data types (wow_seen, wow_unseen, or woi)"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
1234
,
help
=
"random seed"
)
p
ar
am
s
=
parser
.
parse_args
()
return
p
ar
am
s
ar
g
s
=
parser
.
parse_args
()
return
ar
g
s
def
process_wow_dataset
(
input_file
,
output_file
):
...
...
@@ -38,9 +41,11 @@ def process_wow_dataset(input_file, output_file):
topic
\t
dialogue context
\t
golden knowledge
\t
golden response
"""
print
(
"> Loading data from %s"
%
input_file
)
with
open
(
input_file
,
"r"
)
as
fr
:
dialog_data
=
json
.
load
(
fr
)
print
(
"> Processing data ..."
)
with
open
(
output_file
,
"w"
)
as
fw
:
for
i
,
sample
in
enumerate
(
tqdm
(
dialog_data
)):
# get all the dialog data for a single sample
...
...
@@ -50,8 +55,7 @@ def process_wow_dataset(input_file, output_file):
for
j
,
turn
in
enumerate
(
dialog
):
text
=
turn
[
"text"
]
if
not
(
text
.
endswith
(
"?"
)
or
text
.
endswith
(
"."
)
or
text
.
endswith
(
"!"
)):
text
=
text
+
" ."
text
=
" "
.
join
(
word_tokenize
(
text
))
text
=
text
+
"."
if
j
==
0
:
# first turn
...
...
@@ -99,8 +103,9 @@ def process_woi_dataset(input_file, output_file):
topic
\t
dialogue context
\t
golden knowledge
\t
golden response
"""
with
open
(
output_path
,
"w"
)
as
fw
:
with
open
(
input_path
,
"r"
)
as
fr
:
print
(
"> Processing %s"
%
input_file
)
with
open
(
output_file
,
"w"
)
as
fw
:
with
open
(
input_file
,
"r"
)
as
fr
:
for
i
,
line
in
tqdm
(
enumerate
(
fr
)):
line
=
line
.
strip
()
item_dict
=
json
.
loads
(
line
)
...
...
@@ -183,8 +188,8 @@ def process_woi_dataset(input_file, output_file):
assert
action
==
"SearchAgent => Wizard"
def
get_database
(
test_datapath
,
train_datapath
):
"""Get the database
sorted
by topics"""
def
get_database
(
test_datapath
,
train_datapath
,
data_type
):
"""Get the database by topics"""
# get test data topic list
print
(
"> reading test data from %s"
%
test_datapath
)
...
...
@@ -208,20 +213,30 @@ def get_database(test_datapath, train_datapath):
turns
=
splits
[
1
].
split
(
" [SEP] "
)[
-
3
:]
knowledge
=
splits
[
2
]
response
=
splits
[
3
]
# filtering data samples
if
knowledge
==
"no_passages_used"
:
continue
if
data_type
!=
"wow_seen"
and
(
"("
in
knowledge
or
")"
in
knowledge
):
continue
if
data_type
!=
"wow_seen"
and
topic
not
in
knowledge
:
continue
# get the instance
last_turn
=
turns
[
-
1
]
instance
=
"( "
+
last_turn
+
" ) "
+
topic
+
" => "
+
knowledge
if
data_type
==
"woi"
:
instance
=
"( "
+
last_turn
+
" ) "
+
topic
+
" -> "
+
knowledge
else
:
instance
=
"( "
+
last_turn
+
" ) "
+
topic
+
" => "
+
knowledge
# construct dialog example
dialog_example
=
""
dialog_example
+=
"( "
+
topic
+
" )"
for
turn
in
turns
:
dialog_example
+=
" "
if
data_type
!=
"wow_seen"
:
dialog_example
+=
"( "
+
topic
+
" ) "
for
i
,
turn
in
enumerate
(
turns
):
if
i
!=
0
:
dialog_example
+=
" "
dialog_example
+=
turn
# check overlaps
if
topic
in
test_topics
:
if
topic
not
in
train_data_by_topic
:
...
...
@@ -233,7 +248,16 @@ def get_database(test_datapath, train_datapath):
dialog_data_by_topic
[
topic
]
=
[
dialog_example
]
else
:
dialog_data_by_topic
[
topic
].
append
(
dialog_example
)
else
:
# filtering data samples
if
len
(
knowledge
.
split
())
>
20
:
# knowledge is too long
continue
if
knowledge
.
startswith
(
"It"
)
or
knowledge
.
startswith
(
"it"
)
or
\
knowledge
.
startswith
(
"This"
)
or
knowledge
.
startswith
(
"this"
):
continue
# append all the data into dialogue examples list
dialog_examples
.
append
((
topic
,
dialog_example
,
instance
))
...
...
@@ -283,13 +307,13 @@ def select_prompts_based_on_similarity(
def
prompt_selection_for_knowledge_generation
(
test_datapath
,
train_datapath
,
model_path
,
output_prompt_path
):
test_datapath
,
train_datapath
,
model_path
,
output_prompt_path
,
data_type
):
"""Selecting prompts for the knowledge generation"""
print
(
"> Selecting prompts for the knowledge generation"
)
train_data_by_topic
,
dialog_data_by_topic
,
dialog_examples
=
\
get_database
(
test_datapath
,
train_datapath
)
get_database
(
test_datapath
,
train_datapath
,
data_type
)
from
transformers
import
DPRQuestionEncoderTokenizer
print
(
"> loading tokenizer and encoder"
)
...
...
@@ -311,7 +335,6 @@ def prompt_selection_for_knowledge_generation(
dialog_embeddings
=
torch
.
cat
((
dialog_embeddings
,
dialog_emb
),
dim
=
0
)
print
(
"> reading test data from %s"
%
test_datapath
)
count_out_of_list
=
0
prompt_list_for_each_sample
=
[]
with
open
(
test_datapath
,
"r"
)
as
f
:
for
i
,
line
in
tqdm
(
enumerate
(
f
)):
...
...
@@ -321,16 +344,17 @@ def prompt_selection_for_knowledge_generation(
topic
=
splits
[
0
]
turns
=
splits
[
1
].
split
(
" [SEP] "
)[
-
3
:]
if
topic
not
in
train_data_by_topic
:
count_out_of_list
+=
1
# get the query sentence
query_sent
=
""
if
data_type
!=
"seen"
:
query_sent
+=
"( "
+
topic
+
" ) "
for
i
,
turn
in
enumerate
(
turns
):
if
i
!=
0
:
query_sent
+=
" "
query_sent
+=
turn
# calculate similarity
if
topic
not
in
train_data_by_topic
:
# get the query embedding
query_sent
=
""
query_sent
+=
"( "
+
topic
+
" )"
for
turn
in
turns
:
query_sent
+=
" "
query_sent
+=
turn
query_ids
=
tokenizer
.
encode
(
query_sent
)
query_ids
=
torch
.
LongTensor
([
query_ids
]).
cuda
()
query_emb
=
encoder
(
input_ids
=
query_ids
).
pooler_output
...
...
@@ -361,21 +385,14 @@ def prompt_selection_for_knowledge_generation(
else
:
num_data_sample
=
min
(
len
(
train_data_by_topic
[
topic
]),
10
)
total_example_list
=
train_data_by_topic
[
topic
]
# query_sent
query_sent
=
""
query_sent
+=
"( "
+
topic
+
" )"
for
turn
in
turns
:
query_sent
+=
" "
query_sent
+=
turn
dialog_list
=
dialog_data_by_topic
[
topic
]
assert
len
(
dialog_list
)
==
num_data_sample
assert
len
(
dialog_list
)
==
len
(
train_data_by_topic
[
topic
])
# calculate the similarity
selected_
example
s
=
select_prompts_based_on_similarity
(
example
_list
=
select_prompts_based_on_similarity
(
query_sent
,
dialog_list
,
total_example_list
,
topic
,
tokenizer
,
encoder
,
topk
=
num_data_sample
)
example_list
=
selected_examples
key
=
topic
+
" "
+
turns
[
-
1
]
prompt_list_for_each_sample
.
append
({
key
:
example_list
})
...
...
@@ -414,31 +431,42 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):
from
nltk
import
word_tokenize
knowledge_sent_token_list
=
word_tokenize
(
knowledge
)
knowledge_sent_token_dict
=
{
token
:
True
for
token
in
knowledge_sent_token_list
}
response_token_list
=
response
.
split
()
knowledge_len
=
len
(
knowledge_sent_token_list
)
response_token_list
=
word_tokenize
(
response
)
response_len
=
len
(
response_token_list
)
num_overlap_token
=
0
accumulator
=
0
for
token
in
response_token_list
:
if
token
in
knowledge_sent_token_dict
:
num_overlap_token
+=
1
accumulator
+=
1
else
:
if
accumulator
>=
10
:
num_overlap_token
+=
accumulator
accumulator
=
0
if
accumulator
>=
10
:
num_overlap_token
+=
accumulator
# filtering the data based on the ratio
if
num_overlap_token
>
response_len
*
0.9
or
num_overlap_token
<
response_len
*
0.6
:
continue
if
num_overlap_token
<
knowledge_len
*
0.8
:
continue
last_turn
=
" "
.
join
(
word_tokenize
(
turns
[
-
1
]))
knowledge
=
" "
.
join
(
word_tokenize
(
knowledge
))
response
=
" "
.
join
(
word_tokenize
(
response
))
prompt_example
=
""
# add dialog context
prompt_example
+=
"Topic: "
+
topic
+
". "
prompt_example
+=
"User says: "
+
turns
[
-
1
]
+
" "
prompt_example
+=
"User says: "
+
last_turn
+
" "
prompt_example
+=
"We know that: "
+
knowledge
+
" "
prompt_example
+=
"System replies: "
+
response
prompt_example_list
.
append
(
prompt_example
)
print
(
"> shuffle the prompt examples (total %d)"
%
len
(
prompt_example_list
))
# shuffle the prompt examples
print
(
"length: %d"
%
len
(
prompt_example_list
))
np
.
random
.
shuffle
(
prompt_example_list
)
print
(
"> Prompt example:"
)
print
(
prompt_example_list
[
0
])
print
(
"> writing to %s"
%
output_path
)
with
open
(
output_path
,
"w"
)
as
f
:
...
...
@@ -451,10 +479,12 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):
def
prepare_input_for_response_generation
(
test_file
,
knowledge_file
,
output_file
):
"""Preparing inputs for the response generation"""
print
(
"> Reading knowledge file from %s"
%
knowledge_file
)
# get the knowledge list
with
open
(
knowledge_file
,
"r"
)
as
f
:
knowledge_list
=
f
.
readlines
()
print
(
"> Processing ..."
)
with
open
(
test_file
,
"r"
)
as
fr
:
with
open
(
output_file
,
"w"
)
as
fw
:
for
line_num
,
line
in
enumerate
(
tqdm
(
fr
)):
...
...
@@ -476,19 +506,22 @@ def prepare_input_for_response_generation(test_file, knowledge_file, output_file
if
__name__
==
"__main__"
:
p
ar
am
s
=
get_
p
ar
am
s
()
if
p
ar
am
s
.
func
==
"process_wow_dataset"
:
process_wow_dataset
(
p
ar
am
s
.
input_file
,
p
ar
am
s
.
output_file
)
ar
g
s
=
get_ar
g
s
()
if
ar
g
s
.
func
==
"process_wow_dataset"
:
process_wow_dataset
(
ar
g
s
.
input_file
,
ar
g
s
.
output_file
)
elif
p
ar
am
s
.
func
==
"process_woi_dataset"
:
process_woi_dataset
(
p
ar
am
s
.
input_file
,
p
ar
am
s
.
output_file
)
elif
ar
g
s
.
func
==
"process_woi_dataset"
:
process_woi_dataset
(
ar
g
s
.
input_file
,
ar
g
s
.
output_file
)
elif
p
ar
am
s
.
func
==
"get_prompts"
:
elif
ar
g
s
.
func
==
"get_
knwl_gen_
prompts"
:
prompt_selection_for_knowledge_generation
(
params
.
test_file
,
params
.
train_file
,
params
.
model_file
,
params
.
output_file
)
args
.
test_file
,
args
.
train_file
,
args
.
model_file
,
args
.
output_file
,
args
.
data_type
)
elif
args
.
func
==
"get_resp_gen_prompts"
:
prompt_selection_for_response_generation
(
p
ar
am
s
.
train_file
,
p
ar
am
s
.
output_file
,
p
ar
am
s
.
seed
)
ar
g
s
.
train_file
,
ar
g
s
.
output_file
,
ar
g
s
.
seed
)
elif
p
ar
am
s
.
func
==
"prepare_input"
:
elif
ar
g
s
.
func
==
"prepare_input"
:
prepare_input_for_response_generation
(
p
ar
am
s
.
test_file
,
p
ar
am
s
.
knowledge_file
,
p
ar
am
s
.
output_file
)
ar
g
s
.
test_file
,
ar
g
s
.
knowledge_file
,
ar
g
s
.
output_file
)
tasks/knwl_dialo/prompt.py
View file @
91a80bd1
...
...
@@ -120,8 +120,9 @@ def generate_samples_by_prompting_input_from_file(model):
# args.prompt_type == "response"
turns
=
splits
[
1
].
split
(
" [SEP] "
)
knowledge
=
splits
[
2
]
knowledge
=
" "
.
join
(
word_tokenize
(
knowledge
))
last_turn
=
turns
[
-
1
]
last_turn
=
" "
.
join
(
word_tokenize
(
last_turn
))
knowledge
=
" "
.
join
(
word_tokenize
(
knowledge
))
knowledge
=
knowledge
.
strip
()
last_turn
=
last_turn
.
strip
()
raw_text
+=
"Topic: "
+
topic
+
". "
...
...
tasks/knwl_dialo/scripts/commands.sh
0 → 100644
View file @
91a80bd1
# process WoW train
python tasks/knwl_dialo/preprocessing.py
--func
process_wow_dataset
--input_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.json
--output_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt
# process WoW test
python tasks/knwl_dialo/preprocessing.py
--func
process_wow_dataset
--input_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_random_split.json
--output_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen.txt
python tasks/knwl_dialo/preprocessing.py
--func
process_wow_dataset
--input_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_topic_split.json
--output_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen.txt
# process WoI test
python tasks/knwl_dialo/preprocessing.py
--func
process_woi_dataset
--input_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/test.jsonl
--output_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/test.txt
# get knowledge generation prompts
# WoW seen
python tasks/knwl_dialo/preprocessing.py
--func
get_knwl_gen_prompts
--test_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen.txt
--train_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt
--model_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/checkpoints/dpr_wow/best_question_encoder.pt
--data_type
wow_seen
--output_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/knowledge_prompts_test_seen.json
# WoW unseen
python tasks/knwl_dialo/preprocessing.py
--func
get_knwl_gen_prompts
--test_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen.txt
--train_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt
--model_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/checkpoints/dpr_wow_ctrl/best_question_encoder.pt
--data_type
wow_unseen
--output_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/knowledge_prompts_test_unseen.json
# WoI
python tasks/knwl_dialo/preprocessing.py
--func
get_knwl_gen_prompts
--test_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/test.txt
--train_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt
--model_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/checkpoints/dpr_wow_ctrl/best_question_encoder.pt
--data_type
woi
--output_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/knowledge_prompts_test.json
# get response generation prompts --seed 147
python tasks/knwl_dialo/preprocessing.py
--func
get_resp_gen_prompts
--train_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt
--output_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/response_generation_prompts_temp.txt
--seed
1234
# prepare response generation inputs
# WoW seen
python tasks/knwl_dialo/preprocessing.py
--func
prepare_input
--test_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen.txt
--knowledge_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/output_testseen_knowledge_357m.txt
--output_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen_resp_gen_input.txt
# WoW unseen
python tasks/knwl_dialo/preprocessing.py
--func
prepare_input
--test_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen.txt
--knowledge_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/output_testunseen_knowledge_357m.txt
--output_file
/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen_resp_gen_input.txt
tasks/knwl_dialo/scripts/data_processing.sh
View file @
91a80bd1
...
...
@@ -14,9 +14,9 @@ python ${DIR}/tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --inp
# We provide the following script to process the raw data from Wizard of Internet
python
${
DIR
}
/tasks/knwl_dialo/preprocessing.py
--func
process_woi_dataset
--input_file
<PATH_OF_THE_INPUT_DATA>
--output_file
<PATH_OF_THE_OUTPUT_DATA>
# Obtain the knowledge generation prompts
and response generation prompts
python
${
DIR
}
/tasks/knwl_dialo/preprocessing.py
--func
get_prompts
--test_file
<PATH_OF_THE_PROCESSED_TEST_DATA>
--train_file
<PATH_OF_THE_PROCESSED_TRAIN_DATA>
--model_file
<PATH_OF_THE_DPR_MODEL>
--output_file
<PATH_OF_THE_OUTPUT_FILE>
# Obtain the knowledge generation prompts
python
${
DIR
}
/tasks/knwl_dialo/preprocessing.py
--func
get_
knwl_gen_
prompts
--test_file
<PATH_OF_THE_PROCESSED_TEST_DATA>
--train_file
<PATH_OF_THE_PROCESSED_TRAIN_DATA>
--model_file
<PATH_OF_THE_DPR_MODEL>
--output_file
<PATH_OF_THE_OUTPUT_FILE>
--data_type
<DATA_TYPE_OF_THE_INPUT_FILE>
#
Alternatively, we recommend you to directly download the already processed file through:
wget
--no-check-certificate
'https://docs.google.com/uc?export=download&id=1vP0eGxhkbWfeJ2dUUOEAflbOZq-Jlde_'
-O
data.gz
#
Obtain the response generation prompts
python
${
DIR
}
/tasks/knwl_dialo/preprocessing.py
--func
get_resp_gen_prompts
--train_file
<PATH_OF_THE_PROCESSED_TRAIN_DATA>
--output_file
<PATH_OF_THE_OUTPUT_FILE>
tasks/knwl_dialo/utils.py
View file @
91a80bd1
...
...
@@ -142,6 +142,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
tokens
=
context_tokens
if
maxlen
is
None
:
maxlen
=
args
.
seq_length
-
1
if
maxlen
>
(
org_context_length
+
args
.
out_seq_length
):
maxlen
=
org_context_length
+
args
.
out_seq_length
lengths
=
torch
.
ones
([
batch_size
]).
long
().
cuda
()
*
maxlen
# start the generation process
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment