Commit 91a80bd1 authored by root
Browse files

update commands

parent 49907915
...@@ -10,14 +10,10 @@ from collections import Counter ...@@ -10,14 +10,10 @@ from collections import Counter
from typing import List from typing import List
import numpy as np import numpy as np
import re import re
from nltk.corpus import stopwords
re_art = re.compile(r'\b(a|an|the)\b') re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']') re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
stopword_list = stopwords.words('english')
stopword_list = stopword_list + ["n's", "'s"]
stopword_dict = {token: True for token in stopword_list}
def normalize_answer(s): def normalize_answer(s):
""" """
...@@ -29,13 +25,6 @@ def normalize_answer(s): ...@@ -29,13 +25,6 @@ def normalize_answer(s):
s = ' '.join(s.split()) s = ' '.join(s.split())
return s return s
def remove_stopwords(token_list):
    """Return the tokens of ``token_list`` that are not stopwords.

    A token is dropped when it is a key of the module-level
    ``stopword_dict``; the remaining tokens keep their original order
    and duplicates.
    """
    return [word for word in token_list if word not in stopword_dict]
class F1Metric: class F1Metric:
""" """
......
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets""" """Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
import torch
import argparse import argparse
from nltk import word_tokenize from nltk import word_tokenize
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
import json import json
def get_params(): def get_args():
parser = argparse.ArgumentParser(description="Preprocessing") parser = argparse.ArgumentParser(description="Preprocessing")
parser.add_argument("--func", type=str, default="", parser.add_argument("--func", type=str, default=None,
help="choose to run which function") help="choose to run which function")
parser.add_argument("--input_file", type=str, default="", parser.add_argument("--input_file", type=str, default=None,
help="path of the input file") help="path of the input file")
parser.add_argument("--knowledge_file", type=str, default="", parser.add_argument("--knowledge_file", type=str, default=None,
help="path of the knowledge file") help="path of the knowledge file")
parser.add_argument("--test_file", type=str, default="", parser.add_argument("--test_file", type=str, default=None,
help="path of the test file") help="path of the test file")
parser.add_argument("--train_file", type=str, default="", parser.add_argument("--train_file", type=str, default=None,
help="path of the train file") help="path of the train file")
parser.add_argument("--output_file", type=str, default="", parser.add_argument("--output_file", type=str, default=None,
help="path of the output file") help="path of the output file")
parser.add_argument("--model_file", type=str, default="", parser.add_argument("--model_file", type=str, default=None,
help="path of the model file") help="path of the model file")
parser.add_argument("--seed", type=int, default=123456, parser.add_argument("--data_type", type=str, default=None,
help="data types (wow_seen, wow_unseen, or woi)")
parser.add_argument("--seed", type=int, default=1234,
help="random seed") help="random seed")
params = parser.parse_args() args = parser.parse_args()
return params return args
def process_wow_dataset(input_file, output_file): def process_wow_dataset(input_file, output_file):
...@@ -38,9 +41,11 @@ def process_wow_dataset(input_file, output_file): ...@@ -38,9 +41,11 @@ def process_wow_dataset(input_file, output_file):
topic \t dialogue context \t golden knowledge \t golden response topic \t dialogue context \t golden knowledge \t golden response
""" """
print("> Loading data from %s" % input_file)
with open(input_file, "r") as fr: with open(input_file, "r") as fr:
dialog_data = json.load(fr) dialog_data = json.load(fr)
print("> Processing data ...")
with open(output_file, "w") as fw: with open(output_file, "w") as fw:
for i, sample in enumerate(tqdm(dialog_data)): for i, sample in enumerate(tqdm(dialog_data)):
# get all the dialog data for a single sample # get all the dialog data for a single sample
...@@ -50,8 +55,7 @@ def process_wow_dataset(input_file, output_file): ...@@ -50,8 +55,7 @@ def process_wow_dataset(input_file, output_file):
for j, turn in enumerate(dialog): for j, turn in enumerate(dialog):
text = turn["text"] text = turn["text"]
if not (text.endswith("?") or text.endswith(".") or text.endswith("!")): if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
text = text + " ." text = text + "."
text = " ".join(word_tokenize(text))
if j == 0: if j == 0:
# first turn # first turn
...@@ -99,8 +103,9 @@ def process_woi_dataset(input_file, output_file): ...@@ -99,8 +103,9 @@ def process_woi_dataset(input_file, output_file):
topic \t dialogue context \t golden knowledge \t golden response topic \t dialogue context \t golden knowledge \t golden response
""" """
with open(output_path, "w") as fw: print("> Processing %s" % input_file)
with open(input_path, "r") as fr: with open(output_file, "w") as fw:
with open(input_file, "r") as fr:
for i, line in tqdm(enumerate(fr)): for i, line in tqdm(enumerate(fr)):
line = line.strip() line = line.strip()
item_dict = json.loads(line) item_dict = json.loads(line)
...@@ -183,8 +188,8 @@ def process_woi_dataset(input_file, output_file): ...@@ -183,8 +188,8 @@ def process_woi_dataset(input_file, output_file):
assert action == "SearchAgent => Wizard" assert action == "SearchAgent => Wizard"
def get_database(test_datapath, train_datapath): def get_database(test_datapath, train_datapath, data_type):
"""Get the database sorted by topics""" """Get the database by topics"""
# get test data topic list # get test data topic list
print("> reading test data from %s" % test_datapath) print("> reading test data from %s" % test_datapath)
...@@ -208,20 +213,30 @@ def get_database(test_datapath, train_datapath): ...@@ -208,20 +213,30 @@ def get_database(test_datapath, train_datapath):
turns = splits[1].split(" [SEP] ")[-3:] turns = splits[1].split(" [SEP] ")[-3:]
knowledge = splits[2] knowledge = splits[2]
response = splits[3] response = splits[3]
# filtering data samples
if knowledge == "no_passages_used": if knowledge == "no_passages_used":
continue continue
if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge):
continue
if data_type != "wow_seen" and topic not in knowledge:
continue
# get the instance # get the instance
last_turn = turns[-1] last_turn = turns[-1]
instance = "( " + last_turn + " ) " + topic + " => " + knowledge if data_type == "woi":
instance = "( " + last_turn + " ) " + topic + " -> " + knowledge
else:
instance = "( " + last_turn + " ) " + topic + " => " + knowledge
# construct dialog example # construct dialog example
dialog_example = "" dialog_example = ""
dialog_example += "( " + topic + " )" if data_type != "wow_seen":
for turn in turns: dialog_example += "( " + topic + " ) "
dialog_example += " " for i, turn in enumerate(turns):
if i != 0:
dialog_example += " "
dialog_example += turn dialog_example += turn
# check overlaps # check overlaps
if topic in test_topics: if topic in test_topics:
if topic not in train_data_by_topic: if topic not in train_data_by_topic:
...@@ -233,7 +248,16 @@ def get_database(test_datapath, train_datapath): ...@@ -233,7 +248,16 @@ def get_database(test_datapath, train_datapath):
dialog_data_by_topic[topic] = [dialog_example] dialog_data_by_topic[topic] = [dialog_example]
else: else:
dialog_data_by_topic[topic].append(dialog_example) dialog_data_by_topic[topic].append(dialog_example)
else:
# filtering data samples
if len(knowledge.split()) > 20:
# knowledge is too long
continue
if knowledge.startswith("It") or knowledge.startswith("it") or \
knowledge.startswith("This") or knowledge.startswith("this"):
continue
# append all the data into dialogue examples list # append all the data into dialogue examples list
dialog_examples.append((topic, dialog_example, instance)) dialog_examples.append((topic, dialog_example, instance))
...@@ -283,13 +307,13 @@ def select_prompts_based_on_similarity( ...@@ -283,13 +307,13 @@ def select_prompts_based_on_similarity(
def prompt_selection_for_knowledge_generation( def prompt_selection_for_knowledge_generation(
test_datapath, train_datapath, model_path, output_prompt_path): test_datapath, train_datapath, model_path, output_prompt_path, data_type):
"""Selecting prompts for the knowledge generation""" """Selecting prompts for the knowledge generation"""
print("> Selecting prompts for the knowledge generation") print("> Selecting prompts for the knowledge generation")
train_data_by_topic, dialog_data_by_topic, dialog_examples = \ train_data_by_topic, dialog_data_by_topic, dialog_examples = \
get_database(test_datapath, train_datapath) get_database(test_datapath, train_datapath, data_type)
from transformers import DPRQuestionEncoderTokenizer from transformers import DPRQuestionEncoderTokenizer
print("> loading tokenizer and encoder") print("> loading tokenizer and encoder")
...@@ -311,7 +335,6 @@ def prompt_selection_for_knowledge_generation( ...@@ -311,7 +335,6 @@ def prompt_selection_for_knowledge_generation(
dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0) dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)
print("> reading test data from %s" % test_datapath) print("> reading test data from %s" % test_datapath)
count_out_of_list = 0
prompt_list_for_each_sample = [] prompt_list_for_each_sample = []
with open(test_datapath, "r") as f: with open(test_datapath, "r") as f:
for i, line in tqdm(enumerate(f)): for i, line in tqdm(enumerate(f)):
...@@ -321,16 +344,17 @@ def prompt_selection_for_knowledge_generation( ...@@ -321,16 +344,17 @@ def prompt_selection_for_knowledge_generation(
topic = splits[0] topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:] turns = splits[1].split(" [SEP] ")[-3:]
if topic not in train_data_by_topic: # get the query sentence
count_out_of_list += 1 query_sent = ""
if data_type != "seen":
query_sent += "( " + topic + " ) "
for i, turn in enumerate(turns):
if i != 0:
query_sent += " "
query_sent += turn
# calculate similarity if topic not in train_data_by_topic:
# get the query embedding # get the query embedding
query_sent = ""
query_sent += "( " + topic + " )"
for turn in turns:
query_sent += " "
query_sent += turn
query_ids = tokenizer.encode(query_sent) query_ids = tokenizer.encode(query_sent)
query_ids = torch.LongTensor([query_ids]).cuda() query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output query_emb = encoder(input_ids=query_ids).pooler_output
...@@ -361,21 +385,14 @@ def prompt_selection_for_knowledge_generation( ...@@ -361,21 +385,14 @@ def prompt_selection_for_knowledge_generation(
else: else:
num_data_sample = min(len(train_data_by_topic[topic]), 10) num_data_sample = min(len(train_data_by_topic[topic]), 10)
total_example_list = train_data_by_topic[topic] total_example_list = train_data_by_topic[topic]
# query_sent
query_sent = ""
query_sent += "( " + topic + " )"
for turn in turns:
query_sent += " "
query_sent += turn
dialog_list = dialog_data_by_topic[topic] dialog_list = dialog_data_by_topic[topic]
assert len(dialog_list) == num_data_sample assert len(dialog_list) == len(train_data_by_topic[topic])
# calculate the similarity # calculate the similarity
selected_examples = select_prompts_based_on_similarity( example_list = select_prompts_based_on_similarity(
query_sent, dialog_list, total_example_list, query_sent, dialog_list, total_example_list,
topic, tokenizer, encoder, topk=num_data_sample) topic, tokenizer, encoder, topk=num_data_sample)
example_list = selected_examples
key = topic + " " + turns[-1] key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list}) prompt_list_for_each_sample.append({key: example_list})
...@@ -414,31 +431,42 @@ def prompt_selection_for_response_generation(input_path, output_path, seed): ...@@ -414,31 +431,42 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):
from nltk import word_tokenize from nltk import word_tokenize
knowledge_sent_token_list = word_tokenize(knowledge) knowledge_sent_token_list = word_tokenize(knowledge)
knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list} knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
response_token_list = response.split() knowledge_len = len(knowledge_sent_token_list)
response_token_list = word_tokenize(response)
response_len = len(response_token_list) response_len = len(response_token_list)
num_overlap_token = 0 num_overlap_token = 0
accumulator = 0
for token in response_token_list: for token in response_token_list:
if token in knowledge_sent_token_dict: if token in knowledge_sent_token_dict:
num_overlap_token += 1 accumulator += 1
else:
if accumulator >= 10:
num_overlap_token += accumulator
accumulator = 0
if accumulator >= 10:
num_overlap_token += accumulator
# filtering the data based on the ratio # filtering the data based on the ratio
if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6: if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
continue continue
if num_overlap_token < knowledge_len * 0.8:
continue
last_turn = " ".join(word_tokenize(turns[-1]))
knowledge = " ".join(word_tokenize(knowledge))
response = " ".join(word_tokenize(response))
prompt_example = "" prompt_example = ""
# add dialog context # add dialog context
prompt_example += "Topic: " + topic + ". " prompt_example += "Topic: " + topic + ". "
prompt_example += "User says: " + turns[-1] + " " prompt_example += "User says: " + last_turn + " "
prompt_example += "We know that: " + knowledge + " " prompt_example += "We know that: " + knowledge + " "
prompt_example += "System replies: " + response prompt_example += "System replies: " + response
prompt_example_list.append(prompt_example) prompt_example_list.append(prompt_example)
print("> shuffle the prompt examples (total %d)" % len(prompt_example_list)) # shuffle the prompt examples
print("length: %d" % len(prompt_example_list))
np.random.shuffle(prompt_example_list) np.random.shuffle(prompt_example_list)
print("> Prompt example:")
print(prompt_example_list[0])
print("> writing to %s" % output_path) print("> writing to %s" % output_path)
with open(output_path, "w") as f: with open(output_path, "w") as f:
...@@ -451,10 +479,12 @@ def prompt_selection_for_response_generation(input_path, output_path, seed): ...@@ -451,10 +479,12 @@ def prompt_selection_for_response_generation(input_path, output_path, seed):
def prepare_input_for_response_generation(test_file, knowledge_file, output_file): def prepare_input_for_response_generation(test_file, knowledge_file, output_file):
"""Preparing inputs for the response generation""" """Preparing inputs for the response generation"""
print("> Reading knowledge file from %s" % knowledge_file)
# get the knowledge list # get the knowledge list
with open(knowledge_file, "r") as f: with open(knowledge_file, "r") as f:
knowledge_list = f.readlines() knowledge_list = f.readlines()
print("> Processing ...")
with open(test_file, "r") as fr: with open(test_file, "r") as fr:
with open(output_file, "w") as fw: with open(output_file, "w") as fw:
for line_num, line in enumerate(tqdm(fr)): for line_num, line in enumerate(tqdm(fr)):
...@@ -476,19 +506,22 @@ def prepare_input_for_response_generation(test_file, knowledge_file, output_file ...@@ -476,19 +506,22 @@ def prepare_input_for_response_generation(test_file, knowledge_file, output_file
if __name__ == "__main__": if __name__ == "__main__":
params = get_params() args = get_args()
if params.func == "process_wow_dataset": if args.func == "process_wow_dataset":
process_wow_dataset(params.input_file, params.output_file) process_wow_dataset(args.input_file, args.output_file)
elif params.func == "process_woi_dataset": elif args.func == "process_woi_dataset":
process_woi_dataset(params.input_file, params.output_file) process_woi_dataset(args.input_file, args.output_file)
elif params.func == "get_prompts": elif args.func == "get_knwl_gen_prompts":
prompt_selection_for_knowledge_generation( prompt_selection_for_knowledge_generation(
params.test_file, params.train_file, params.model_file, params.output_file) args.test_file, args.train_file, args.model_file,
args.output_file, args.data_type)
elif args.func == "get_resp_gen_prompts":
prompt_selection_for_response_generation( prompt_selection_for_response_generation(
params.train_file, params.output_file, params.seed) args.train_file, args.output_file, args.seed)
elif params.func == "prepare_input": elif args.func == "prepare_input":
prepare_input_for_response_generation( prepare_input_for_response_generation(
params.test_file, params.knowledge_file, params.output_file) args.test_file, args.knowledge_file, args.output_file)
...@@ -120,8 +120,9 @@ def generate_samples_by_prompting_input_from_file(model): ...@@ -120,8 +120,9 @@ def generate_samples_by_prompting_input_from_file(model):
# args.prompt_type == "response" # args.prompt_type == "response"
turns = splits[1].split(" [SEP] ") turns = splits[1].split(" [SEP] ")
knowledge = splits[2] knowledge = splits[2]
knowledge = " ".join(word_tokenize(knowledge))
last_turn = turns[-1] last_turn = turns[-1]
last_turn = " ".join(word_tokenize(last_turn))
knowledge = " ".join(word_tokenize(knowledge))
knowledge = knowledge.strip() knowledge = knowledge.strip()
last_turn = last_turn.strip() last_turn = last_turn.strip()
raw_text += "Topic: " + topic + ". " raw_text += "Topic: " + topic + ". "
......
# Command list for tasks/knwl_dialo/preprocessing.py. Each command is
# self-contained (absolute paths), so lines can be run one at a time.
# NOTE(review): all paths point at a site-specific /gpfs location -- adjust before reuse.

# Process the WoW (Wizard of Wikipedia) training set: train.json -> train.txt
python tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --input_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.json --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt
# Process the WoW test sets:
#   test_random_split.json -> test_seen.txt
#   test_topic_split.json  -> test_unseen.txt
python tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --input_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_random_split.json --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen.txt
python tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --input_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_topic_split.json --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen.txt
# Process the WoI (Wizard of Internet) test set: test.jsonl -> test.txt
python tasks/knwl_dialo/preprocessing.py --func process_woi_dataset --input_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/test.jsonl --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/test.txt
# Select knowledge generation prompts (--func get_knwl_gen_prompts); one run per --data_type.
# WoW seen: --data_type wow_seen, dpr_wow question encoder checkpoint
python tasks/knwl_dialo/preprocessing.py --func get_knwl_gen_prompts --test_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen.txt --train_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt --model_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/checkpoints/dpr_wow/best_question_encoder.pt --data_type wow_seen --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/knowledge_prompts_test_seen.json
# WoW unseen: --data_type wow_unseen, dpr_wow_ctrl question encoder checkpoint
python tasks/knwl_dialo/preprocessing.py --func get_knwl_gen_prompts --test_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen.txt --train_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt --model_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/checkpoints/dpr_wow_ctrl/best_question_encoder.pt --data_type wow_unseen --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/knowledge_prompts_test_unseen.json
# WoI: --data_type woi; WoI test file paired with the WoW train file, dpr_wow_ctrl checkpoint
python tasks/knwl_dialo/preprocessing.py --func get_knwl_gen_prompts --test_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/test.txt --train_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt --model_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/checkpoints/dpr_wow_ctrl/best_question_encoder.pt --data_type woi --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_internet/data/knowledge_prompts_test.json
# Select response generation prompts (--func get_resp_gen_prompts) from the WoW train file.
# NOTE(review): an earlier note on this step referenced "--seed 147", whereas the
# command below passes --seed 1234 -- confirm which seed is intended.
python tasks/knwl_dialo/preprocessing.py --func get_resp_gen_prompts --train_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/train.txt --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/response_generation_prompts_temp.txt --seed 1234
# Prepare response generation inputs (--func prepare_input) from a test file plus a
# knowledge file (output_test*_knowledge_357m.txt).
# WoW seen
python tasks/knwl_dialo/preprocessing.py --func prepare_input --test_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen.txt --knowledge_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/output_testseen_knowledge_357m.txt --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_seen_resp_gen_input.txt
# WoW unseen
python tasks/knwl_dialo/preprocessing.py --func prepare_input --test_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen.txt --knowledge_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/output_testunseen_knowledge_357m.txt --output_file /gpfs/fs1/projects/gpu_adlr/datasets/zihanl/dialog_datasets/wizard_of_wikipedia/data/test_unseen_resp_gen_input.txt
...@@ -14,9 +14,9 @@ python ${DIR}/tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --inp ...@@ -14,9 +14,9 @@ python ${DIR}/tasks/knwl_dialo/preprocessing.py --func process_wow_dataset --inp
# We provide the following script to process the raw data from Wizard of Internet # We provide the following script to process the raw data from Wizard of Internet
python ${DIR}/tasks/knwl_dialo/preprocessing.py --func process_woi_dataset --input_file <PATH_OF_THE_INPUT_DATA> --output_file <PATH_OF_THE_OUTPUT_DATA> python ${DIR}/tasks/knwl_dialo/preprocessing.py --func process_woi_dataset --input_file <PATH_OF_THE_INPUT_DATA> --output_file <PATH_OF_THE_OUTPUT_DATA>
# Obtain the knowledge generation prompts and response generation prompts # Obtain the knowledge generation prompts
python ${DIR}/tasks/knwl_dialo/preprocessing.py --func get_prompts --test_file <PATH_OF_THE_PROCESSED_TEST_DATA> --train_file <PATH_OF_THE_PROCESSED_TRAIN_DATA> --model_file <PATH_OF_THE_DPR_MODEL> --output_file <PATH_OF_THE_OUTPUT_FILE> python ${DIR}/tasks/knwl_dialo/preprocessing.py --func get_knwl_gen_prompts --test_file <PATH_OF_THE_PROCESSED_TEST_DATA> --train_file <PATH_OF_THE_PROCESSED_TRAIN_DATA> --model_file <PATH_OF_THE_DPR_MODEL> --output_file <PATH_OF_THE_OUTPUT_FILE> --data_type <DATA_TYPE_OF_THE_INPUT_FILE>
# Alternatively, we recommend you to directly download the already processed file through: # Obtain the response generation prompts
wget --no-check-certificate 'https://docs.google.com/uc?export=download&id=1vP0eGxhkbWfeJ2dUUOEAflbOZq-Jlde_' -O data.gz python ${DIR}/tasks/knwl_dialo/preprocessing.py --func get_resp_gen_prompts --train_file <PATH_OF_THE_PROCESSED_TRAIN_DATA> --output_file <PATH_OF_THE_OUTPUT_FILE>
...@@ -142,6 +142,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths, ...@@ -142,6 +142,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
tokens = context_tokens tokens = context_tokens
if maxlen is None: if maxlen is None:
maxlen = args.seq_length - 1 maxlen = args.seq_length - 1
if maxlen > (org_context_length + args.out_seq_length):
maxlen = org_context_length + args.out_seq_length
lengths = torch.ones([batch_size]).long().cuda() * maxlen lengths = torch.ones([batch_size]).long().cuda() * maxlen
# start the generation process # start the generation process
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment