Commit 5833553a authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'multistage-prompting/main-multistage' into 'main'

Multistage prompting/main multistage

See merge request ADLR/megatron-lm!371
parents 945ece94 b8707ee2
# Multi-Stage Prompting for Knowledgeable Dialogue Generation
This directory contains all the scripts of multi-stage prompting for knowledgeable dialogue generation that includes data preparation, and knowledge and response generations. More details are available on [`knowledgeable task directory`](../../tasks/msdp).
#!/bin/bash
# Data preparation for our framework: preprocessing the WoW and WoI datasets
# The datasets can be downloaded through the following links:
# WoW: https://parl.ai/projects/wizard_of_wikipedia/
# WoI: https://parl.ai/projects/sea/
DIR=`pwd`
# Before running the preprocessing, please download
# the wizard of wikipedia and wizard datasets
WOW_DATA_FOLDER=<PATH_OF_WIZARD_OF_WIKIPEDIA_DATA_FOLDER>
WOI_DATA_FOLDER=<PATH_OF_WIZARD_OF_INTERNET_DATA_FOLDER>
# We provide examples for processing the raw data from Wizard of Wikipedia
# Processing the train dataset (train.json)
python ${DIR}/tasks/msdp/preprocessing.py \
--func process_wow_dataset \
--raw_file ${WOW_DATA_FOLDER}/train.json \
--processed_file ${WOW_DATA_FOLDER}/train_processed.txt
# Processing test seen dataset (test_random_split.json)
python ${DIR}/tasks/msdp/preprocessing.py \
--func process_wow_dataset \
--raw_file ${WOW_DATA_FOLDER}/test_random_split.json \
--processed_file ${WOW_DATA_FOLDER}/testseen_processed.txt \
--knwl_ref_file ${WOW_DATA_FOLDER}/output_testseen_knowledge_reference.txt \
--resp_ref_file ${WOW_DATA_FOLDER}/output_testseen_response_reference.txt
# processing test unseen dataset (test_topic_split.json)
python ${DIR}/tasks/msdp/preprocessing.py \
--func process_wow_dataset \
--raw_file ${WOW_DATA_FOLDER}/test_topic_split.json \
--processed_file ${WOW_DATA_FOLDER}/testunseen_processed.txt \
--knwl_ref_file ${WOW_DATA_FOLDER}/output_testunseen_knowledge_reference.txt \
--resp_ref_file ${WOW_DATA_FOLDER}/output_testunseen_response_reference.txt
# We provide the following script to process the raw data from Wizard of Internet
# Processing the test dataset (test.jsonl)
python ${DIR}/tasks/msdp/preprocessing.py \
--func process_woi_dataset \
--raw_file ${WOI_DATA_FOLDER}/test.jsonl \
--processed_file ${WOI_DATA_FOLDER}/test_processed.txt \
--knwl_ref_file ${WOI_DATA_FOLDER}/output_test_knowledge_reference.txt \
--resp_ref_file ${WOI_DATA_FOLDER}/output_test_response_reference.txt
# Get the knowledge generation prompts for the each test dataset in WoW and WoI
MODEL_FILE=<PATH_OF_THE_FINETUNED_DPR_MODEL>
# WoW test seen
python ${DIR}/tasks/msdp/preprocessing.py \
--func get_knwl_gen_prompts \
--test_file ${WOW_DATA_FOLDER}/testseen_processed.txt \
--train_file ${WOW_DATA_FOLDER}/train_processed.txt \
--model_file ${MODEL_FILE} \
--processed_file ${WOW_DATA_FOLDER}/output_testseen_knowledge_prompts.json \
--data_type wow_seen
# WoW test unseen
python ${DIR}/tasks/msdp/preprocessing.py \
--func get_knwl_gen_prompts \
--test_file ${WOW_DATA_FOLDER}/testunseen_processed.txt \
--train_file ${WOW_DATA_FOLDER}/train_processed.txt \
--model_file ${MODEL_FILE} \
--processed_file ${WOW_DATA_FOLDER}/output_testunseen_knowledge_prompts.json \
--data_type wow_unseen
# WoI
python ${DIR}/tasks/msdp/preprocessing.py \
--func get_knwl_gen_prompts \
--test_file ${WOI_DATA_FOLDER}/test_processed.txt \
--train_file ${WOW_DATA_FOLDER}/train_processed.txt \
--model_file ${MODEL_FILE} \
--processed_file ${WOI_DATA_FOLDER}/output_test_knowledge_prompts.json \
--data_type woi
# Get the response generation prompts (can be applied for all the test datasets)
python ${DIR}/tasks/msdp/preprocessing.py \
--func get_resp_gen_prompts \
--train_file ${WOW_DATA_FOLDER}/train_processed.txt \
--processed_file ${WOW_DATA_FOLDER}/output_response_prompts.txt
#!/bin/bash
#########################
# Evaluate the F1 scores.
#########################
WORLD_SIZE=1
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
MODEL_GEN_PATH=<PATH_OF_THE_KNOWLEDGE_GENERATION> \
(e.g., /testseen_knowledge_generations.txt)
GROUND_TRUTH_PATH=<PATH_OF_THE_GROUND_TRUTH_KNOWLEDGE> \
(e.g., /testseen_knowledge_reference.txt)
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/msdp/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--task MSDP-EVAL-F1 \
--guess-file ${MODEL_GEN_PATH} \
--answer-file ${GROUND_TRUTH_PATH}
############################################
# Evaluate BLEU, METEOR, and ROUGE-L scores.
############################################
# We follow the nlg-eval (https://github.com/Maluuba/nlg-eval) to
# evaluate the BLEU, METEOR, and ROUGE-L scores.
# To evaluate on these metrics, please setup the environments based on
# the nlg-eval github, and run the corresponding evaluation commands.
nlg-eval \
--hypothesis=<PATH_OF_THE_KNOWLEDGE_GENERATION> \
--references=<PATH_OF_THE_GROUND_TRUTH_KNOWLEDGE>
#!/bin/bash
#########################
# Evaluate the F1 scores.
#########################
WORLD_SIZE=1
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
MODEL_GEN_PATH=<PATH_OF_THE_RESPONSE_GENERATION> \
(e.g., /testseen_response_generations.txt)
GROUND_TRUTH_PATH=<PATH_OF_THE_GROUND_TRUTH_RESPONSE> \
(e.g., /testseen_response_reference.txt)
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/msdp/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--task MSDP-EVAL-F1 \
--guess-file ${MODEL_GEN_PATH} \
--answer-file ${GROUND_TRUTH_PATH}
##########################
# Evaluate the KF1 scores.
##########################
MODEL_GEN_PATH=<PATH_OF_THE_RESPONSE_GENERATION> \
(e.g., /testseen_response_generations.txt)
GROUND_TRUTH_PATH=<PATH_OF_THE_GROUND_TRUTH_KNOWLEDGE> \
(e.g., /testseen_knowledge_reference.txt)
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/msdp/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 4 \
--task MSDP-EVAL-F1 \
--guess-file ${MODEL_GEN_PATH} \
--answer-file ${GROUND_TRUTH_PATH}
############################################
# Evaluate BLEU, METEOR, and ROUGE-L scores.
############################################
# We follow the nlg-eval (https://github.com/Maluuba/nlg-eval) to
# evaluate the BLEU, METEOR, and ROUGE-L scores.
# To evaluate on these metrics, please setup the environments based on
# the nlg-eval github, and run the corresponding evaluation commands.
nlg-eval \
--hypothesis=<PATH_OF_THE_RESPONSE_GENERATION> \
--references=<PATH_OF_THE_GROUND_TRUTH_RESPONSE>
#!/bin/bash
# Preparing the input file for the response generation (second-stage prompting)
DIR=`pwd`
TEST_FILE=<PATH_OF_PROCESSED_TEST_DATA> \
(e.g., /testseen_processed.txt)
KNOWLEDGE_FILE=<PATH_OF_GENERATED_KNOWLEDGE_DATA> \
(e.g., /testseen_knowledge_generations.txt)
PROCESSED_FILE=<PATH_OF_INPUT_FILE_FOR_RESPONSE_GENERATION> \
(e.g., /testseen_processed_with_generated_knowledge.txt)
python ${DIR}/tasks/msdp/preprocessing.py \
--func prepare_input \
--test_file ${TEST_FILE} \
--knowledge_gen_file ${KNOWLEDGE_FILE} \
--processed_file ${PROCESSED_FILE}
#!/bin/bash
# Stage-1: Prompt a pretrained language model to generate the context-relevant knowledge
# The input contains prompts and current dialogue context, the output is the relevant knowledge
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
CHECKPOINT_PATH=<PATH_OF_LANGUAGE_MODEL> (e.g., /357m)
VOCAB_PATH=<PATH_OF_VOCAB_FILE> (e.g., /gpt2-vocab.json)
MERGE_PATH=<PATH_OF_MERGE_FILE> (e.g., /gpt2-merges.txt)
INPUT_PATH=<PATH_OF_PROCESSED_TEST_DATA_FILE> \
(e.g., /testseen_processed.txt)
PROMPT_PATH=<PATH_OF_KNOWLEDGE_GENERATION_PROMPTS> \
(e.g., /testseen_knowledge_prompts.json)
OUTPUT_PATH=<PATH_OF_OUTPUT_GENERATION_FILE> \
(e.g., /testseen_knowledge_generations.txt)
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/msdp/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 1 \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--load ${CHECKPOINT_PATH} \
--fp16 \
--DDP-impl torch \
--tokenizer-type GPT2BPETokenizer \
--sample-input-file ${INPUT_PATH} \
--sample-output-file ${OUTPUT_PATH} \
--prompt-file ${PROMPT_PATH} \
--prompt-type knowledge \
--num-prompt-examples 10 \
--task MSDP-PROMPT
# NOTE: If you use api for the model generation, please use
# the "--api-prompt" flag (setting this value as True).
#!/bin/bash
# Stage-2: Prompt a pretrained language model to generate the corresponding response
# The input contains prompts, current dialogue context, and generated knowledge in Stage-1
# The output is the corresponding response.
# The size of the pretrained language model is 357M
WORLD_SIZE=8
DISTRIBUTED_ARGS="--nproc_per_node $WORLD_SIZE \
--nnodes 1 \
--node_rank 0 \
--master_addr localhost \
--master_port 6000"
CHECKPOINT_PATH=<PATH_OF_LANGUAGE_MODEL> (e.g., /357m)
VOCAB_PATH=<PATH_OF_VOCAB_FILE> (e.g., /gpt2-vocab.json)
MERGE_PATH=<PATH_OF_MERGE_FILE> (e.g., /gpt2-merges.txt)
INPUT_PATH=<PATH_OF_INPUT_TEST_DATA_FILE> (e.g., /testseen_processed.txt)
PROMPT_PATH=<PATH_OF_RESPONSE_GENERATION_PROMPTS> \
(e.g., /response_prompts.txt)
OUTPUT_PATH=<PATH_OF_OUTPUT_GENERATION_FILE> \
(e.g., /output_testseen_response_generations.txt)
python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/msdp/main.py \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--max-position-embeddings 2048 \
--micro-batch-size 1 \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--load ${CHECKPOINT_PATH} \
--fp16 \
--DDP-impl torch \
--tokenizer-type GPT2BPETokenizer \
--sample-input-file ${INPUT_PATH} \
--sample-output-file ${OUTPUT_PATH} \
--prompt-file ${PROMPT_PATH} \
--prompt-type response \
--num-prompt-examples 20 \
--task MSDP-PROMPT
# NOTE: If you use api for the model generation, please use
# the "--api-prompt" flag (setting this value as True).
# Multi-Stage Prompting for Knowledgeable Dialogue Generation
Below we present the steps to run our multi-stage dialogue prompting (MSDP) framework.
## Multi-Stage Dialogue Prompting
### Data Preparation
1. Dataset Download: [Wizard of Wikipedia](https://parl.ai/projects/wizard_of_wikipedia/) and [Wizard of Internet](https://parl.ai/projects/sea/)
2. Data Processing: We provide the script to run the [`data processing`](../../examples/msdp/data_processing.sh) of the datatsets.
### Stage-1: Prompting for Knowledge Generation
1. We provide the script to perform the [`first-stage prompting`](../../examples/msdp/prompt_knwl_gen.sh) for the knowledge generation.
2. We provide the [`evaluation script`](../../examples/msdp/eval_knwl_generation.sh) for the automatic evaluation (i.e., F1, BLEU, METEOR, and ROUGE-L) of the knowledge generation.
### Stage-2: Prompting for Response Generation
1. We provide the script to [`prepare the input file`](../../examples/msdp/prep_resp_gen.sh) for the response generation (based on the previously generated knowledge file).
2. We provide the script to perform the [`second-stage prompting`](../../examples/msdp/prompt_resp_gen.sh) for the response generation.
3. We provide the [`evaluation script`](../../examples/msdp/eval_resp_generation.sh) for the automatic evaluation (i.e., F1, KF1, BLEU, METEOR, and ROUGE-L) of the response generation.
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model evaluation"""
from megatron import get_args
from megatron import print_rank_0
from tasks.msdp.metrics import F1Metric
from tqdm import tqdm
def evaluate_f1(guess_file, answer_file):
"""Evaluating F1 Score"""
guess_list = []
print_rank_0('reading %s' % guess_file)
with open(guess_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if "<|endoftext|>" in line:
line = line.replace("<|endoftext|>", "")
guess_list.append(line)
answer_list = []
print_rank_0('reading %s' % answer_file)
with open(answer_file, "r") as f:
for i, line in enumerate(tqdm(f)):
line = line.strip()
if line == "no_passages_used":
line = ""
answer_list.append(line)
assert len(guess_list) == len(answer_list), \
"lengths of guess and answer are different!"
precision, recall, f1 = F1Metric.compute_all_pairs(guess_list, answer_list)
print_rank_0('Precision: %.4f; recall: %.4f; f1: %.4f' % (precision, recall, f1))
print_rank_0('done :-)')
def main():
args = get_args()
evaluate_f1(args.guess_file, args.answer_file)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run multi-stage dialogue prompting (MSDP)."""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(
os.path.join(os.path.dirname(__file__), os.path.pardir), os.path.pardir)))
from megatron import get_args
from megatron.initialize import initialize_megatron
def get_tasks_args(parser):
"""Provide extra arguments required for tasks."""
group = parser.add_argument_group(title='tasks')
# parameters for the knowledgeable dialogue generation
group.add_argument('--task', type=str, required=True,
help='Task name.')
group.add_argument("--sample-input-file", type=str, default=None,
help='Get input from file instead of interactive mode, '
'each line is an input.')
group.add_argument("--sample-output-file", type=str, default=None,
help='Output file got from --sample-input-file')
group.add_argument('--prompt-file', type=str, default=None,
help='prompting file')
group.add_argument('--prompt-type', type=str, default=None,
choices=['knowledge', 'response'],
help='prompt type (knowledge or response)')
group.add_argument('--num-prompt-examples', type=int, default=10,
help='number of prompt examples')
group.add_argument('--guess-file', type=str, default=None,
help='datapath for generated sentences')
group.add_argument('--answer-file', type=str, default=None,
help='datapath for golden sentences')
group.add_argument('--out-seq-length', type=int, default=100,
help='output sequence length')
group.add_argument('--api-prompt', default=False, action="store_true",
help='setup model api for prompting')
group.add_argument('--megatron-api-url', type=str, default=None,
help='url of the megatron api')
return parser
if __name__ == '__main__':
initialize_megatron(extra_args_provider=get_tasks_args)
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for downstream tasks.")
exit()
if args.task == 'MSDP-PROMPT':
from tasks.msdp.prompt import main
elif args.task == 'MSDP-EVAL-F1':
from tasks.msdp.evaluate import main
else:
raise NotImplementedError('Task {} is not implemented.'.format(
args.task))
main()
# The following code is adapted from
# https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/metrics.py,
# which is licensed under the MIT license. More details on the license can be
# found at https://github.com/facebookresearch/ParlAI/blob/master/LICENSE.
"""Provides standard metric evaluations for dialog."""
from collections import Counter
from typing import List
import numpy as np
import re
re_art = re.compile(r'\b(a|an|the)\b')
re_punc = re.compile(r'[!"#$%&()*+,-./:;<=>?@\[\]\\^`{|}~_\']')
def normalize_answer(s):
"""
Lower text and remove punctuation, articles and extra whitespace.
"""
s = s.lower()
s = re_punc.sub(' ', s)
s = re_art.sub(' ', s)
s = ' '.join(s.split())
return s
class F1Metric:
"""
Helper class which computes token-level F1.
"""
@staticmethod
def _prec_recall_f1_score(pred_items, gold_items):
"""
Compute precision, recall and f1 given a set of gold and prediction items.
:param pred_items: iterable of predicted values
:param gold_items: iterable of gold values
:return: tuple (p, r, f1) for precision, recall, f1
"""
common = Counter(gold_items) & Counter(pred_items)
num_same = sum(common.values())
if num_same == 0:
return 0, 0, 0
precision = 1.0 * num_same / len(pred_items)
recall = 1.0 * num_same / len(gold_items)
f1 = (2 * precision * recall) / (precision + recall)
return precision, recall, f1
@staticmethod
def compute_each_pair(guess: str, answer: str):
if answer == "":
return None, None, None
if guess == "":
return 0, 0, 0
g_tokens = normalize_answer(guess).split()
a_tokens = normalize_answer(answer).split()
precision, recall, f1 = F1Metric._prec_recall_f1_score(g_tokens, a_tokens)
return precision, recall, f1
@staticmethod
def compute_all_pairs(guesses: List[str], answers: List[str]):
# additional augment:
assert len(guesses) == len(answers)
precision_list, recall_list, f1_list = [], [], []
for guess, answer in zip(guesses, answers):
precision, recall, f1 = F1Metric.compute_each_pair(guess, answer)
if precision is None or recall is None or f1 is None:
continue
precision_list.append(precision)
recall_list.append(recall)
f1_list.append(f1)
return np.mean(precision_list), np.mean(recall_list), np.mean(f1_list)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing for Wizard of Wikipedia and Wizard of Internet datasets"""
import torch
import argparse
from nltk import word_tokenize
from tqdm import tqdm
import numpy as np
import json
def get_args():
parser = argparse.ArgumentParser(description="Preprocessing")
parser.add_argument("--func", type=str, default=None,
help="choose to run which function")
parser.add_argument("--raw_file", type=str, default=None,
help="path of the input file")
parser.add_argument("--processed_file", type=str, default=None,
help="path of the output file")
parser.add_argument("--knwl_ref_file", type=str, default=None,
help="path of the knowledge reference file")
parser.add_argument("--resp_ref_file", type=str, default=None,
help="path of the knowledge reference file")
parser.add_argument("--knwl_gen_file", type=str, default=None,
help="path of the generated knowledge file")
parser.add_argument("--test_file", type=str, default=None,
help="path of the test file")
parser.add_argument("--train_file", type=str, default=None,
help="path of the train file")
parser.add_argument("--model_file", type=str, default=None,
help="path of the model file")
parser.add_argument("--data_type", type=str, default=None,
help="data types, choose one out of three types: \
wow_seen, wow_unseen, and woi")
parser.add_argument("--seed", type=int, default=1234,
help="random seed")
args = parser.parse_args()
return args
def process_wow_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
"""
This is a function used for processing the wizard of wikipedia (wow) dataset
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
# loading the raw data
print("> Loading data from %s" % raw_file)
with open(raw_file, "r") as fr:
dialog_data = json.load(fr)
print("> Processing data ...")
fproc = open(processed_file, "w")
fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
fresp = open(resp_ref_file, "w") if resp_ref_file else None
for i, sample in enumerate(tqdm(dialog_data)):
# get all the dialog data for a single dialog sample
dialog = sample["dialog"]
turn_list = [] # collect the dialog history
# processing for each single dialog sample
for j, turn in enumerate(dialog):
# text of each turn
text = turn["text"]
if not (text.endswith("?") or text.endswith(".") or text.endswith("!")):
text = text + "."
if j == 0:
# first turn
turn_list.append(text)
continue
speaker = turn["speaker"].lower()
if "wizard" in speaker:
checked_sentence = list(turn["checked_sentence"].values()) # knowledge
checked_passage = list(turn["checked_passage"].values()) # topic
assert len(checked_sentence) <= 1
# get the ground truth knowledge
if len(checked_sentence) > 0:
checked_sentence = checked_sentence[0]
else:
checked_sentence = "no_passages_used"
if len(checked_passage) == 1:
checked_passage = checked_passage[0]
else:
checked_passage = "no_passages_used"
# get the topic
if checked_passage != "no_passages_used":
topic = checked_passage
else:
topic = sample["chosen_topic"]
dialog_context = " [SEP] ".join(turn_list)
knowledge = checked_sentence
response = text
# add the response into the dialog history
turn_list.append(response)
# write to the output files
fproc.write(topic + "\t" + dialog_context + "\t" + \
knowledge + "\t" + response + "\n")
if fknwl:
fknwl.write(knowledge + "\n")
if fresp:
# tokenize for evaluation
response = " ".join(word_tokenize(response))
fresp.write(response + "\n")
else:
assert "apprentice" in speaker
turn_list.append(text)
fproc.close()
if fknwl:
fknwl.close()
if fresp:
fresp.close()
def process_woi_dataset(raw_file, processed_file, knwl_ref_file, resp_ref_file):
"""
This is a function used for processing the wizard of internet (woi) dataset
Expected processed format:
topic \t dialogue context \t golden knowledge \t golden response
"""
print("> Processing %s" % raw_file)
fproc = open(processed_file, "w")
fknwl = open(knwl_ref_file, "w") if knwl_ref_file else None
fresp = open(resp_ref_file, "w") if resp_ref_file else None
with open(raw_file, "r") as fr:
for i, line in tqdm(enumerate(fr)):
# read line by line, each line uses json format
line = line.strip()
item_dict = json.loads(line)
# item_dict is a dictionary
# its key is the data id, and its value contains all the data content
item_dict = item_dict.values()
item_dict = list(item_dict)[0] # len(item_dict) == 1
# get the whole dialog data for a single dialog sample
dialog_data = item_dict['dialog_history']
length = len(dialog_data)
turn_list = [] # collect the dialog history
search_text = ""
for i in range(length):
item = dialog_data[i]
action = item['action']
if action == "Wizard => SearchAgent":
search_text = item['text']
elif action == "Wizard => Apprentice":
if len(turn_list) == 0:
# first turn
turn = item['text']
turn_list.append(turn)
continue
# get the relevant content
contents = item["context"]["contents"]
selects = item["context"]["selected_contents"]
flag = selects[0][0]
selects = selects[1:]
assert len(selects) == len(contents)
# get the topic
if flag:
# no knowledge sentence is used for the response
topic = "no_topic"
knwl_sent = "no_passages_used"
else:
# we consider the search text as the topic
topic = search_text
# get the knowledge sentence
knwl_sent = ""
for content, select in zip(contents, selects):
content = content['content']
assert len(content) == len(select)
for c, s in zip(content, select):
if s:
knwl_sent = c
break
if knwl_sent == "":
# no knowledge is used for the response
topic = "no_topic"
knwl_sent = "no_passages_used"
# get dialogue context, knowledge, and response
dialog_context = " [SEP] ".join(turn_list)
response = item['text']
# processing
topic = topic.replace("\n", "").replace("\r", \
"").replace("\t", "")
dialog_context = dialog_context.replace("\n", "").replace("\r", \
"").replace("\t", "")
knwl_sent = knwl_sent.replace("\n", "").replace("\r", \
"").replace("\t", "")
response = response.replace("\n", "").replace("\r", \
"").replace("\t", "")
if topic != "no_topic":
# write to the ouput files
fproc.write(topic + "\t" + dialog_context + "\t" + \
knwl_sent + "\t" + response + "\n")
if fknwl:
fknwl.write(knwl_sent + "\n")
if fresp:
# tokenize for evaluation
response = " ".join(word_tokenize(response))
fresp.write(response + "\n")
turn_list.append(response)
elif action == "Apprentice => Wizard":
turn = item['text']
turn_list.append(turn)
else:
assert action == "SearchAgent => Wizard", \
"Please check whether you have used the correct data!"
fproc.close()
if fknwl:
fknwl.close()
if fresp:
fresp.close()
def get_database(test_datapath, train_datapath, data_type):
"""Get the database by topics"""
assert data_type in ["wow_seen", "wow_unseen", "woi"], \
"Please input a correct data type!!"
# get test data topic dictionary
print("> reading test data from %s" % test_datapath)
test_topics = {}
with open(test_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
test_topics[topic] = True
print("> reading data from %s" % train_datapath)
train_data_by_topic = {}
dialog_data_by_topic = {}
dialog_examples = []
with open(train_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
knowledge = splits[2]
response = splits[3]
# filtering data samples
if knowledge == "no_passages_used":
# when no knowledge is used
continue
if data_type != "wow_seen" and ("(" in knowledge or ")" in knowledge):
# when bracket exists in the knowledge
continue
if data_type != "wow_seen" and topic not in knowledge:
# when topic does not exist in the knowledge
continue
# get the instance
last_turn = turns[-1]
instance = "( " + last_turn + " ) " + topic + " => " + knowledge
# construct dialog example
dialog_example = ""
if data_type != "wow_seen":
dialog_example += "( " + topic + " ) "
for i, turn in enumerate(turns):
if i != 0:
dialog_example += " "
dialog_example += turn
# check overlaps
if topic in test_topics:
if topic not in train_data_by_topic:
train_data_by_topic[topic] = [instance]
else:
train_data_by_topic[topic].append(instance)
if topic not in dialog_data_by_topic:
dialog_data_by_topic[topic] = [dialog_example]
else:
dialog_data_by_topic[topic].append(dialog_example)
else:
# filtering data samples
if len(knowledge.split()) > 20:
# knowledge is too long
continue
if knowledge.startswith("It") or knowledge.startswith("it") or \
knowledge.startswith("This") or knowledge.startswith("this"):
continue
# append all the data into dialogue examples list
dialog_examples.append((topic, dialog_example, instance))
return train_data_by_topic, dialog_data_by_topic, dialog_examples
emb_dict = {}
def select_prompts_based_on_similarity(
query, dialog_list, prompt_list, topic, tokenizer, encoder, topk):
"""Select samples based on the similarity"""
with torch.no_grad():
# get the query embeddings
query_ids = tokenizer.encode(query)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate embeddings for the samples in the database
if topic in emb_dict:
example_embeddings = emb_dict[topic]
example_embeddings = example_embeddings.cuda()
else:
for idx, example in enumerate(dialog_list):
example_ids = tokenizer.encode(example)
example_ids = torch.LongTensor([example_ids]).cuda()
example_emb = encoder(input_ids=example_ids).pooler_output
if idx == 0:
example_embeddings = example_emb
else:
example_embeddings = torch.cat(
(example_embeddings, example_emb), dim=0)
emb_dict[topic] = example_embeddings.cpu()
# compare the similarity and select the topk samples
similarity_list = example_embeddings.matmul(query_emb)
_, indices = torch.topk(similarity_list, k=topk)
indices = indices.tolist()
indices = indices[::-1] # reverse the order
selected_prompts = []
for index in indices:
# index = index.item()
selected_prompts.append(prompt_list[index])
return selected_prompts
def prompt_selection_for_knowledge_generation(
test_datapath, train_datapath, model_path, output_prompt_path, data_type):
"""Selecting prompts for the knowledge generation"""
print("> Selecting prompts for the knowledge generation")
train_data_by_topic, dialog_data_by_topic, dialog_examples = \
get_database(test_datapath, train_datapath, data_type)
from transformers import DPRQuestionEncoderTokenizer
print("> loading tokenizer and encoder")
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
'facebook/dpr-question_encoder-single-nq-base')
encoder = torch.load(model_path).cuda()
print("> getting dialog embeddings")
with torch.no_grad():
for idx, example in tqdm(enumerate(dialog_examples)):
dialog = example[1]
dialog_ids = tokenizer.encode(dialog)
dialog_ids = torch.LongTensor([dialog_ids]).cuda()
dialog_emb = encoder(input_ids=dialog_ids).pooler_output
if idx == 0:
dialog_embeddings = dialog_emb
else:
dialog_embeddings = torch.cat((dialog_embeddings, dialog_emb), dim=0)
print("> reading test data from %s" % test_datapath)
prompt_list_for_each_sample = []
with open(test_datapath, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
topic = splits[0]
turns = splits[1].split(" [SEP] ")[-3:]
# get the query sentence
query_sent = ""
if data_type != "seen":
query_sent += "( " + topic + " ) "
for i, turn in enumerate(turns):
if i != 0:
query_sent += " "
query_sent += turn
if topic not in train_data_by_topic:
# get the query embedding
query_ids = tokenizer.encode(query_sent)
query_ids = torch.LongTensor([query_ids]).cuda()
query_emb = encoder(input_ids=query_ids).pooler_output
query_emb = query_emb[0]
# calculate the similarity
similarity_list = dialog_embeddings.matmul(query_emb)
_, indices = torch.sort(similarity_list)
indices = indices.tolist()
selected_topics = {}
selected_prompts = []
num_prompt = 0
for index in indices:
example = dialog_examples[index]
topic_temp = example[0]
if topic_temp not in selected_topics:
selected_topics[topic_temp] = True
selected_prompts.append(example[2])
num_prompt += 1
if num_prompt == 10:
break
# get the selected samples
example_list = selected_prompts[::-1]
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
else:
num_data_sample = min(len(train_data_by_topic[topic]), 10)
total_example_list = train_data_by_topic[topic]
dialog_list = dialog_data_by_topic[topic]
assert len(dialog_list) == len(train_data_by_topic[topic])
# calculate the similarity
example_list = select_prompts_based_on_similarity(
query_sent, dialog_list, total_example_list,
topic, tokenizer, encoder, topk=num_data_sample)
key = topic + " " + turns[-1]
prompt_list_for_each_sample.append({key: example_list})
print("writing to %s" % output_prompt_path)
with open(output_prompt_path, "w") as f:
for instance in tqdm(prompt_list_for_each_sample):
json.dump(instance, f)
f.write("\n")
def prompt_selection_for_response_generation(input_path, output_path, seed):
"""Selecting prompts for the response generation"""
print("> Selecting prompts for the response generation")
print("> set random seed")
np.random.seed(seed)
prompt_example_list = []
print("> reading data from %s" % input_path)
with open(input_path, "r") as f:
for i, line in tqdm(enumerate(f)):
line = line.strip()
splits = line.split("\t")
# get the topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
knowledge = splits[2]
response = splits[3]
turns = dialog_context.split(" [SEP] ")[-3:]
if knowledge == "no_passages_used":
continue
# calculate the overlap ratio
from nltk import word_tokenize
knowledge_sent_token_list = word_tokenize(knowledge)
knowledge_sent_token_dict = {token: True for token in knowledge_sent_token_list}
knowledge_len = len(knowledge_sent_token_list)
response_token_list = word_tokenize(response)
response_len = len(response_token_list)
num_overlap_token = 0
accumulator = 0
for token in response_token_list:
if token in knowledge_sent_token_dict:
accumulator += 1
else:
if accumulator >= 10:
num_overlap_token += accumulator
accumulator = 0
if accumulator >= 10:
num_overlap_token += accumulator
# filtering the data based on the ratio
if num_overlap_token > response_len * 0.9 or num_overlap_token < response_len * 0.6:
continue
if num_overlap_token < knowledge_len * 0.8:
continue
last_turn = " ".join(word_tokenize(turns[-1]))
knowledge = " ".join(word_tokenize(knowledge))
response = " ".join(word_tokenize(response))
prompt_example = ""
# add dialog context
prompt_example += "Topic: " + topic + ". "
prompt_example += "User says: " + last_turn + " "
prompt_example += "We know that: " + knowledge + " "
prompt_example += "System replies: " + response
prompt_example_list.append(prompt_example)
# shuffle the prompt examples
np.random.shuffle(prompt_example_list)
print("> writing to %s" % output_path)
with open(output_path, "w") as f:
# f.write("Generate the System's response based on the knowledge sentence:\n")
for i in tqdm(range(20)):
example = prompt_example_list[i]
f.write(example + "\n")
def prepare_input_for_response_generation(test_file, knwl_gen_file, processed_file):
"""Preparing inputs for the response generation"""
print("> Reading knowledge file from %s" % knwl_gen_file)
# get the knowledge list
with open(knwl_gen_file, "r") as f:
knowledge_list = f.readlines()
print("> Processing ...")
with open(test_file, "r") as fr:
with open(processed_file, "w") as fw:
for line_num, line in enumerate(tqdm(fr)):
line = line.strip()
splits = line.split("\t")
# prepare topic, context, knowledge and response
topic = splits[0]
dialog_context = splits[1]
response = splits[3]
knowledge = knowledge_list[line_num]
knowledge = knowledge.strip()
if "<|endoftext|>" in knowledge:
knowledge = knowledge.replace("<|endoftext|>", "")
# write to the output file
fw.write(topic + "\t" + dialog_context + "\t" \
+ knowledge + "\t" + response + "\n")
if __name__ == "__main__":
args = get_args()
if args.func == "process_wow_dataset":
process_wow_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
elif args.func == "process_woi_dataset":
process_woi_dataset(args.raw_file, args.processed_file, args.knwl_ref_file, args.resp_ref_file)
elif args.func == "get_knwl_gen_prompts":
prompt_selection_for_knowledge_generation(
args.test_file, args.train_file, args.model_file,
args.processed_file, args.data_type)
elif args.func == "get_resp_gen_prompts":
prompt_selection_for_response_generation(
args.train_file, args.processed_file, args.seed)
elif args.func == "prepare_input":
prepare_input_for_response_generation(
args.test_file, args.knwl_gen_file, args.processed_file)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prompting the pretrained language model to generate knowledge/response"""
import json
import torch
import requests
from nltk import word_tokenize
from megatron import mpu
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron.model import GPTModel
from megatron.training import get_model
from megatron.checkpointing import load_checkpoint
from megatron.initialize import initialize_megatron
from megatron.text_generation import generate_and_post_process
def call_model_api(inputs, tokens_to_generate):
"""Calling the model api to get the output generations"""
args = get_args()
# The following is an example of using the Megatron API
# You can also implement your own API function to place this part
headers = {'Content-Type': 'application/json; charset=UTF-8'}
data = {"prompts": [inputs], "tokens_to_generate": tokens_to_generate, "top_k": 1}
data_json = json.dumps(data)
outputs = requests.put(args.megatron_api_url, headers=headers, data=data_json).json()["text"][0]
input_len = len(inputs)
outputs = outputs[input_len:]
outputs = outputs.split("\n")[0].strip()
return outputs
def read_prompts(prompt_path, prompt_type, n_example):
"""Read prompt data"""
if prompt_type == "knowledge":
# prompts for the knowledge generation
prompt_examples_dict = {}
# read prompt_path
with open(prompt_path, "r") as f:
for i, line in enumerate(f):
line = line.strip()
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt_examples_dict[key] = prompt
return prompt_examples_dict
else:
# prompts for the response generation
# read prompt_path
prompt = ""
with open(prompt_path, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:n_example]
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
return prompt
def generate_samples_by_calling_api():
""" Generate outputs by calling"""
args = get_args()
assert args.prompt_type in ["knowledge", "response"], \
"Please input a correct prompt type!"
if args.prompt_type == "knowledge":
# read knowledge generation prompts
knwl_gen_prompt_dict = read_prompts(
args.prompt_file, args.prompt_type, args.num_prompt_examples)
else:
resp_gen_prompt = read_prompts(
args.prompt_file, args.prompt_type, args.num_prompt_examples)
# read the test data
fname = open(args.sample_input_file, "r")
test_sample_list = fname.readlines()
# create output file
fname_out = open(args.sample_output_file, "w")
# call the api to get the output generations
for test_sample in test_sample_list:
test_sample = test_sample.strip()
splits = test_sample.split("\t")
topic = splits[0]
# prepare the inputs for the api
if args.prompt_type == "knowledge":
## inputs = prompt + current test
# get the prompt
turns = splits[1].split(" [SEP] ")
last_turn = turns[-1]
key = topic + " " + last_turn
inputs = knwl_gen_prompt_dict[key]
# add current test
inputs += "( " + last_turn + " ) " + topic + " =>"
else:
# inputs = prompt + current test
# get the prompt
inputs = resp_gen_prompt
# add current test
turns = splits[1].split(" [SEP] ")
knowledge = splits[2]
last_turn = turns[-1]
last_turn = " ".join(word_tokenize(last_turn))
knowledge = " ".join(word_tokenize(knowledge))
knowledge = knowledge.strip()
last_turn = last_turn.strip()
inputs += "Topic: " + topic + ". "
inputs += "User says: " + last_turn + " "
inputs += "We know that: " + knowledge + " "
inputs += "System replies:"
# get the output generations from the api,
# and write to the output file
generations = call_model_api(inputs, args.out_seq_length)
fname_out.write(generations)
fname_out.write("\n")
fname.close()
fname_out.close()
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
return model
def generate_samples_by_prompting_input_from_file(model):
"""Prompt a pretrained language model to generate knowledge/response"""
# get tokenizer
args = get_args()
tokenizer = get_tokenizer()
# Read the sample file and open the output file.
assert args.sample_input_file is not None, \
'sample input file is not provided.'
if mpu.is_pipeline_first_stage() and mpu.get_tensor_model_parallel_rank() == 0:
fname = open(args.sample_input_file, "r")
all_raw_text = fname.readlines()
input_count = len(all_raw_text)
if args.sample_output_file is None:
sample_output_file = args.sample_input_file + ".out"
print('`sample-output-file` not specified, setting '
'it to {}'.format(sample_output_file))
else:
sample_output_file = args.sample_output_file
fname_out = open(sample_output_file, "w")
# only two prompt types (i.e., knowledge and response) are allowed
assert args.prompt_type in ["knowledge", "response"], \
"Please input a correct prompt type!"
# Read the prompt file
if args.prompt_type == "knowledge":
# read the prompts for the knowledge generation
prompt_examples_dict = {}
with open(args.prompt_file, "r") as f:
for i, line in enumerate(f):
line = line.strip()
line_dict = json.loads(line)
key = list(line_dict.keys())[0]
# get the prompt examples based on the key
if key not in prompt_examples_dict:
prompt_examples = line_dict[key]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
prompt_examples_dict[key] = prompt
else:
# read the prompts for the response generation
# prompts are fixed for all test samples
with open(args.prompt_file, "r") as f:
prompt_examples = f.readlines()
prompt_examples = prompt_examples[:args.num_prompt_examples]
prompt = ""
for instance in prompt_examples:
instance = instance.strip()
prompt += instance + " \n"
input_pos = 0
model.eval()
# perform prompting
with torch.no_grad():
while True:
raw_text_len = 0
if mpu.is_pipeline_first_stage() \
and mpu.get_tensor_model_parallel_rank() == 0:
input_str = all_raw_text[input_pos]
input_str = input_str.strip()
splits = input_str.split("\t")
topic = splits[0]
if args.prompt_type == "knowledge":
# first add the prompt into the raw_text
turns = splits[1].split(" [SEP] ")
last_turn = turns[-1]
key = topic + " " + last_turn
raw_text = prompt_examples_dict[key]
# construct inputs for knowledge generation
# then add the constructed inputs into the raw_text
raw_text += "( " + last_turn + " ) " + topic + " =>"
else:
# first add the prompt into the raw_text
raw_text = prompt
# construct inputs for response generation
# then add the constructed inputs into the raw_text
turns = splits[1].split(" [SEP] ")
knowledge = splits[2]
last_turn = turns[-1]
last_turn = " ".join(word_tokenize(last_turn))
knowledge = " ".join(word_tokenize(knowledge))
knowledge = knowledge.strip()
last_turn = last_turn.strip()
raw_text += "Topic: " + topic + ". "
raw_text += "User says: " + last_turn + " "
raw_text += "We know that: " + knowledge + " "
raw_text += "System replies:"
input_pos += 1
raw_text_len = len(raw_text)
else:
raw_text = "EMPTY TEXT"
if input_pos % 100 == 0:
print_rank_0("input_pos: %d" % input_pos)
outputs = generate_and_post_process(
model=model,
prompts=[raw_text],
tokens_to_generate=args.out_seq_length,
top_k_sampling=1)
prompts_plus_generations = outputs[0]
prompts_plus_generations = prompts_plus_generations[0]
# write the generated output to the output file
if mpu.get_tensor_model_parallel_rank() == 0:
if mpu.is_pipeline_first_stage():
generations = prompts_plus_generations[raw_text_len:]
generations = generations.split("\n")[0]
generations = generations.strip()
fname_out.write(generations)
fname_out.write("\n")
raw_text = None
if input_pos == input_count:
return
def main():
args = get_args()
if args.api_prompt:
# obtain the generations by calling the api
generate_samples_by_calling_api()
return
if args.num_layers_per_virtual_pipeline_stage is not None:
print("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
# Set up model and load checkpoint.
model = get_model(model_provider, wrap_with_ddp=False)
if args.load is not None:
_ = load_checkpoint(model, None, None)
assert len(model) == 1, "Above condition should have caught this"
model = model[0]
# perform the prompting
generate_samples_by_prompting_input_from_file(model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment