Commit e57a8f74 authored by zihanl

remove tracked dialogctrl folder and train_gpt_conv.py

parent f777e6b3
@@ -8,4 +8,7 @@ tensorboard
 commands
 *.log
 logs
-*.so
\ No newline at end of file
+*.so
+*.out
+train_gpt_conv.py
+dialogctrl/
\ No newline at end of file
"""Build Dataset for Controllable Coversational Model"""
import os
import torch
import numpy as np
from megatron import get_tokenizer
from megatron import print_rank_0
def read_data(tokenizer, data_path, train_module):
"""read and tokenize dialog data"""
data_list = []
with open(data_path, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
length_split = len(splits)
assert length_split == 2 or length_split == 3 or length_split == 4
if train_module == "dialog":
# if length_split == 2:
# continue
dialog_context = splits[0]
if length_split > 2:
ctrl_sent = splits[-2]
response = splits[-1]
# only take the last three turns in the dialog context
turns = dialog_context.split(" [SEP] ")
turns = turns[-3:]
# input_ids
for idx, turn in enumerate(turns):
if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
turn = turn + " ."
if idx == 0:
input_ids = tokenizer.tokenize(turn)
else:
# input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
input_ids.extend(tokenizer.tokenize(turn))
                if length_split > 2:
                    # when there is a control sentence, add it to the input_ids
                    # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(ctrl_sent))
                    input_ids.extend(tokenizer.tokenize("( " + ctrl_sent + " ) ."))
# output_ids
output_ids = tokenizer.tokenize(response)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
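                # A hypothetical worked example: with turns ["hi .", "do you like jazz ?"]
                # and ctrl_sent "jazz is great", input_ids tokenizes
                # "hi . do you like jazz ? ( jazz is great ) ." and output_ids
                # tokenizes the response that follows it.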
elif train_module == "control":
if length_split == 2:
continue
dialog_context = splits[0]
ctrl_sent = splits[-2]
ctrl_code = splits[1] if length_split == 4 else None
turns = dialog_context.split(" [SEP] ")
# last_turn = turns[-1]
# turns = turns[-3:]
# for idx, turn in enumerate(turns):
# if idx == 0:
# input_ids = tokenizer.tokenize(turn)
# else:
# # input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
# input_ids.extend(tokenizer.tokenize(turn))
# # input_ids
# if ctrl_code:
# ctrl_code_list = ctrl_code.split(" [CTRL] ")
# for code in ctrl_code_list:
# # input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
# input_ids.extend(tokenizer.tokenize(code + " ."))
                # put control codes at the beginning
input_ids = []
if ctrl_code:
ctrl_code_list = ctrl_code.split(" [CTRL] ")
for code in ctrl_code_list:
input_ids.extend(tokenizer.tokenize("( " + code + " )"))
turns = turns[-3:]
for turn in turns:
if not (turn.endswith("?") or turn.endswith(".") or turn.endswith("!")):
turn = turn + " ."
input_ids.extend(tokenizer.tokenize(turn))
# output_ids
outputs = ctrl_sent
output_ids = tokenizer.tokenize(outputs)
data_list.append({"input_ids": input_ids, "output_ids": output_ids})
else:
        raise ValueError("Please input a correct train-module name! (either dialog or control)")
return data_list
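# A hypothetical input line (tab-separated fields, length 4):
#   "hi . [SEP] do you like jazz ?\tLouis Armstrong\tLouis Armstrong was a jazz trumpeter .\tI love Louis Armstrong !"
# Length-2 lines carry only dialog context and response; length-3 lines add a
# control sentence; length-4 lines also carry control codes joined by " [CTRL] ".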
def data_shuffle(data, seed):
# set random seed to make the shuffling reproducible
np.random.seed(seed)
np.random.shuffle(data)
return data
class ControlDialogDataset(torch.utils.data.Dataset):
def __init__(self, data, max_seq_len, sep_id, pad_id, eod_id):
# need to deal with padding, label masking
self.data = data
self.max_seq_len = max_seq_len
self.sep_id = sep_id
self.pad_id = pad_id
self.eod_id = eod_id
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data_dict = self.data[idx]
input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
# length_of_loss_mask == length_of_text - 1
# text = input_ids + [self.sep_id] + output_ids + [self.eod_id]
text = input_ids + output_ids + [self.eod_id]
loss_mask = [0]*(len(input_ids)-1) + [1]*(len(output_ids)+1)
text_len = len(text)
if text_len > self.max_seq_len+1:
text = text[:self.max_seq_len+1]
loss_mask = loss_mask[:self.max_seq_len]
else:
text += [self.pad_id] * (self.max_seq_len+1 - text_len)
loss_mask += [0] * (self.max_seq_len+1 - text_len)
return {"text": np.array(text, dtype=np.int64), "loss_mask": np.array(loss_mask, dtype=np.int64)}
def build_train_valid_test_datasets(data_folder, dataset_name, train_module, max_seq_len, seed):
"""Build train, valid, and test datasets."""
dataname_dict = {"wizard_of_wikipedia": {"train": "train_entity_based_control.txt", "valid": "valid_random_split_entity_based_control.txt", "test": "test_random_split_entity_based_control.txt"}}
train_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["train"])
valid_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["valid"])
test_data_path = os.path.join(data_folder, dataset_name+"/processed/"+dataname_dict[dataset_name]["test"])
tokenizer = get_tokenizer()
train_data_list = read_data(tokenizer, train_data_path, train_module)
valid_data_list = read_data(tokenizer, valid_data_path, train_module)
test_data_list = read_data(tokenizer, test_data_path, train_module)
# shuffle the training data
train_data_list = data_shuffle(train_data_list, seed)
# build train, valid, and test datasets
train_dataset = ControlDialogDataset(train_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
test_dataset = ControlDialogDataset(test_data_list, max_seq_len, sep_id=tokenizer.sep_id, pad_id=tokenizer.pad_id, eod_id=tokenizer.eod_id)
return train_dataset, valid_dataset, test_dataset
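# A hypothetical call, assuming megatron has been initialized so get_tokenizer() works:
# train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
#     data_folder="/path/to/dialog_datasets", dataset_name="wizard_of_wikipedia",
#     train_module="dialog", max_seq_len=512, seed=1234)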
from src.config import get_params
from transformers import AutoTokenizer
import torch
import numpy as np
from tqdm import tqdm
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import string
import os
wn_lemma = WordNetLemmatizer()
stop_words = stopwords.words('english')
stop_words.append("n't")
stop_words.append("'s")
punctuations = list(string.punctuation)
punctuations.append("``")
punctuations.append("''")
stopwords_table = {word: True for word in stop_words}
punctuations_table = {punc: True for punc in punctuations}
# stop_words_and_punctuations = stop_words + punctuations
# stop_words_and_punctuations_table = {word: True for word in stop_words_and_punctuations}
label_set = ["O", "B", "I"]
def read_data(input_datapath):
data = []
print("Reading data from %s" % input_datapath)
with open(input_datapath, "r") as f:
for i, line in enumerate(f):
line = line.strip()
splits = line.split("\t")
length = len(splits)
assert length == 2 or length == 4
# length is 2: dialog context + response
# length is 4: dialog context + topic + control sentence + response
if length == 2:
# dialog context + response
data.append(line)
else:
# only need dialog context + control sentence + response
data.append(splits[0] + "\t" + splits[2] + "\t" + splits[3])
return data
def write_data(output_datapath, output_data):
print("Writing data to %s" % output_datapath)
with open(output_datapath, "w") as fw:
for data_sample in output_data:
fw.write(data_sample + "\n")
def detect_entities(tokenizer, ner_model, sentence):
tokens = sentence.split()
token_ids, first_tok_masks = [tokenizer.cls_token_id], [0]
for token in tokens:
subs_ = tokenizer.tokenize(token)
assert len(subs_) > 0
token_ids.extend(tokenizer.convert_tokens_to_ids(subs_))
first_tok_masks.extend([1] + [0] * (len(subs_) - 1))
token_ids.append(tokenizer.sep_token_id)
first_tok_masks.append(0)
token_ids = torch.LongTensor([token_ids]).cuda()
predictions = ner_model(token_ids)
predictions = predictions[0].data.cpu().numpy() # (seq_len, 3)
pred_ids = list(np.argmax(predictions, axis=1))
assert len(pred_ids) == len(first_tok_masks)
preds_for_each_word = []
for pred_id, mask in zip(pred_ids, first_tok_masks):
if mask == 1:
preds_for_each_word.append(label_set[pred_id])
assert len(preds_for_each_word) == len(tokens)
# extract entities
entity_list = []
temp = []
for i, (token, pred) in enumerate(zip(tokens, preds_for_each_word)):
if pred == "O":
if len(temp) > 0:
entity_list.append(" ".join(temp))
temp = []
else:
# pred == "B" or pred == "I"
temp.append(token)
    if len(temp) > 0:
        # flush a trailing entity that runs to the end of the sentence
        entity_list.append(" ".join(temp))
    return entity_list
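# A hypothetical worked example: tokens ["Louis", "Armstrong", "plays", "jazz"]
# with predictions ["B", "I", "O", "O"] yield entity_list == ["Louis Armstrong"];
# consecutive B/I predictions are merged into one entity and an "O" closes the span.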
def generate_entity_control_data(tokenizer, ner_model, input_data):
# aim to generate:
# dialog context + entity control code (optional) + relevant control sentence (contain entity) + response
output_data = []
n_skip, n_skip_no_overlap, n_skip_one_contain_another = 0, 0, 0
n_control, n_entity_control, n_overlap_control, n_control_without_code = 0, 0, 0, 0
total_num_control_code = 0
for sample_idx, data_item in enumerate(tqdm(input_data)):
# # Debug only
# if sample_idx > 1000:
# break
# 1. detect entities for dialog context, control sentence and response
splits = data_item.split("\t")
if len(splits) == 2:
output_data.append(data_item)
continue
assert len(splits) == 3
last_turn = splits[0].split(" [SEP] ")[-1]
control_sent = splits[1]
response = splits[2]
if control_sent in response or response in control_sent:
            # if the whole control_sent is part of the response or vice versa, skip this data sample
n_skip += 1
n_skip_one_contain_another += 1
continue
last_turn_entities = detect_entities(tokenizer, ner_model, last_turn)
control_sent_entities = detect_entities(tokenizer, ner_model, control_sent)
response_entities = detect_entities(tokenizer, ner_model, response)
# 2. generate control code:
        # 2.1 If one or more entities are shared by the last turn, the control sentence, and the response, there is no need to use an entity as control.
        # 2.2 If an entity only exists in the control sentence and the response, use it as the control code.
        # 2.3 If there is no overlapping entity or word between the control sentence and the response, skip this data sample.
        # 2.4 If there is no overlapping entity but there are overlapping words, add an entity from the control sentence (if any) as the control code when it is not in the dialog context.
# TODO
# In general, need to trim the control sentence when it is too long.
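        # A hypothetical worked example of rule 2.2: with last_turn "do you like
        # music ?", control_sent "Louis Armstrong was a jazz trumpeter ." and
        # response "I love Louis Armstrong !", the entity "Louis Armstrong" is
        # shared by the control sentence and the response but absent from the
        # last turn, so it becomes the control code.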
# calculate common entity between control sentence and response
common_entity_list = []
for ctrl_entity in control_sent_entities:
for resp_entity in response_entities:
if resp_entity in ctrl_entity:
common_entity_list.append(ctrl_entity)
break
elif ctrl_entity in resp_entity:
common_entity_list.append(resp_entity)
break
if len(common_entity_list) == 0:
# calculate overlap between control sentence and response
control_word_list = control_sent.split()
response_word_list = response.split()
# response_word_table = {wn_lemma.lemmatize(word): True for word in response_word_list}
response_word_table = {}
for word in response_word_list:
response_word_table[wn_lemma.lemmatize(word)] = True
if "/" in word and len(word) > 0:
tokens = word.split("/")
for tok in tokens:
if len(tok) > 0:
response_word_table[wn_lemma.lemmatize(tok)] = True
overlap_phrases = []
temp = []
for word in control_word_list:
if word in punctuations_table:
continue
if word.lower() in stopwords_table and len(temp) == 0:
continue
if wn_lemma.lemmatize(word) in response_word_table:
temp.append(word)
else:
if len(temp) > 0:
if len(temp) > 5:
temp = temp[:5]
overlap_phrases.append(" ".join(temp))
                        temp = []
            if len(temp) > 0:
                # flush a trailing overlap phrase so it is not lost when the
                # control sentence ends with matching words
                overlap_phrases.append(" ".join(temp[:5]))
if len(overlap_phrases) == 0:
# skip this data sample
n_skip += 1
n_skip_no_overlap += 1
continue
n_control += 1
control_code_list = []
if len(control_sent_entities) > 0:
n_entity_control += 1
                # sort control_sent_entities by entity length, longest first
control_sent_entities = sorted(control_sent_entities, key=len, reverse=True)[:3]
for entity in control_sent_entities:
if entity not in last_turn:
add_flag = True
for code in control_code_list:
if entity in code:
add_flag = False
break
if add_flag:
control_code_list.append(entity)
else:
n_overlap_control += 1
                # sort overlap_phrases by phrase length, longest first
overlap_phrases = sorted(overlap_phrases, key=len, reverse=True)[:3]
for phrase in overlap_phrases:
if phrase not in last_turn:
add_flag = True
for code in control_code_list:
if phrase in code:
                                # this phrase repeats an existing control code
add_flag = False
break
if add_flag:
control_code_list.append(phrase)
else:
n_entity_control += 1
n_control += 1
control_code_list = []
            # sort common_entity_list by entity length, longest first
common_entity_list = sorted(common_entity_list, key=len, reverse=True)
for entity in common_entity_list:
if entity not in last_turn:
add_flag = True
for code in control_code_list:
if entity in code:
add_flag = False
break
if add_flag:
control_code_list.append(entity)
total_num_control_code += len(control_code_list)
if len(control_code_list) > 0:
output_data.append(splits[0] + "\t" + " [CTRL] ".join(control_code_list) + "\t" + control_sent + "\t" + response)
else:
n_control_without_code += 1
output_data.append(splits[0] + "\t" + control_sent + "\t" + response)
    avg_num_control_code = total_num_control_code * 1.0 / max(n_control, 1)
print("number of skip sentences: %d (one contain another: %d + no overlap: %d)" % (n_skip, n_skip_one_contain_another, n_skip_no_overlap))
print("Total data size: %d. Number of control case: %d (entity control: %d + overlap control: %d)" % (len(output_data), n_control, n_entity_control, n_overlap_control))
print("Number of control code: %d; number of control case: %d; number of control case without control code: %d (averaged control code per case: %.4f)" % (total_num_control_code, n_control, n_control_without_code, avg_num_control_code))
return output_data
def main(params):
# load model and tokenizer
model_saved_path = os.path.join(params.saved_folder, params.model_name+".pt")
ner_model = torch.load(model_saved_path)["model"]
ner_model.cuda()
ner_model.eval()
tokenizer = AutoTokenizer.from_pretrained(params.model_name)
# load data
datafolder = os.path.join(params.default_folder, params.infer_datafolder)
input_datapath = os.path.join(datafolder, params.infer_dataname)
output_datapath = os.path.join(datafolder, params.output_dataname)
# read input data
input_data = read_data(input_datapath)
# process data (generate entity control data)
output_data = generate_entity_control_data(tokenizer, ner_model, input_data)
# write output data
write_data(output_datapath, output_data)
if __name__ == "__main__":
params = get_params()
main(params)
\ No newline at end of file
import torch
import numpy as np
from transformers import AutoTokenizer
from tabulate import tabulate
tokenizer = AutoTokenizer.from_pretrained("roberta-large")
ner_model = torch.load("/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/checkpoints/ner_model/roberta-large.pt")["model"]
ner_model.cuda()
ner_model.eval()
label_set = ["O", "B", "I"]
for step in range(100):
print("===========================================================================")
input_sent = input(">> Input:")
tokens = input_sent.split()
token_ids, first_tok_masks = [tokenizer.cls_token_id], [0]
for token in tokens:
subs_ = tokenizer.tokenize(token)
assert len(subs_) > 0
token_ids.extend(tokenizer.convert_tokens_to_ids(subs_))
first_tok_masks.extend([1] + [0] * (len(subs_) - 1))
token_ids.append(tokenizer.sep_token_id)
first_tok_masks.append(0)
token_ids = torch.LongTensor([token_ids]).cuda()
predictions = ner_model(token_ids) # (1, seq_len, 3)
predictions = predictions[0].data.cpu().numpy() # (seq_len, 3)
pred_ids = list(np.argmax(predictions, axis=1))
assert len(pred_ids) == len(first_tok_masks)
preds_for_each_word = []
for pred, mask in zip(pred_ids, first_tok_masks):
if mask == 1:
preds_for_each_word.append(label_set[pred])
assert len(preds_for_each_word) == len(tokens)
table = [tokens, preds_for_each_word]
print(tabulate(table))
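# A hypothetical session: entering "Louis Armstrong plays jazz ." prints two rows,
# the whitespace-split tokens and their predicted tags (e.g. B I O O O).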
# train_ner.py command
CUDA_VISIBLE_DEVICES=0 python train_ner.py --exp_name conll2003 --exp_id 1 --model_name roberta-large --lr 3e-5 --seed 111
# gen_entityctrl_data.py command (by default is to process training data)
CUDA_VISIBLE_DEVICES=0 python gen_entityctrl_data.py
CUDA_VISIBLE_DEVICES=0 python gen_entityctrl_data.py --infer_dataname valid_random_split.txt --output_dataname valid_random_split_entity_based_control.txt
CUDA_VISIBLE_DEVICES=0 python gen_entityctrl_data.py --infer_dataname valid_topic_split.txt --output_dataname valid_topic_split_entity_based_control.txt
CUDA_VISIBLE_DEVICES=0 python gen_entityctrl_data.py --infer_dataname test_random_split_seen.txt --output_dataname test_random_split_entity_based_control.txt
CUDA_VISIBLE_DEVICES=0 python gen_entityctrl_data.py --infer_dataname test_topic_split_unseen.txt --output_dataname test_topic_split_entity_based_control.txt
import argparse
def get_params():
parser = argparse.ArgumentParser(description="NER Task")
parser.add_argument("--exp_name", type=str, default="conll2003", help="Experiment name")
parser.add_argument("--logger_filename", type=str, default="train.log")
parser.add_argument("--dump_path", type=str, default="logs", help="Experiment saved root path")
parser.add_argument("--exp_id", type=str, default="1", help="Experiment id")
parser.add_argument("--model_name", type=str, default="roberta-large", help="model name")
parser.add_argument("--seed", type=int, default=111, help="random seed")
# train parameters
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument("--epoch", type=int, default=300, help="Number of epoch")
parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate")
parser.add_argument("--early_stop", type=int, default=3, help="No improvement after several epoch, we stop training")
parser.add_argument("--num_tag", type=int, default=3, help="Number of entity in the dataset")
parser.add_argument("--dropout", type=float, default=0.1, help="dropout rate")
parser.add_argument("--hidden_dim", type=int, default=1024, help="Hidden layer dimension")
parser.add_argument("--data_folder", type=str, default="/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/conll2003", help="NER data folder")
parser.add_argument("--saved_folder", type=str, default="/gpfs/fs1/projects/gpu_adlr/datasets/zihanl/checkpoints/ner_model", help="NER data folder")
parser.add_argument("--default_folder", type=str, default="/gpfs/fs1/projects/gpu_adlr/datasets/zihanl")
parser.add_argument("--infer_datafolder", type=str, default="dialog_datasets/wizard_of_wikipedia/processed")
parser.add_argument("--infer_dataname", type=str, default="train.txt")
parser.add_argument("--output_dataname", type=str, default="train_entity_based_control.txt")
params = parser.parse_args()
return params
import torch
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
import os
from tqdm import tqdm
import logging
logger = logging.getLogger()
pad_token_label_id = nn.CrossEntropyLoss().ignore_index
label_set = ["O", "B-ENTITY", "I-ENTITY"]
def read_ner(tokenizer, datapath):
inputs, labels = [], []
with open(datapath, "r") as fr:
token_list, label_list = [], []
for i, line in enumerate(fr):
line = line.strip()
if line == "":
if len(token_list) > 0:
assert len(token_list) == len(label_list)
inputs.append([tokenizer.cls_token_id] + token_list + [tokenizer.sep_token_id])
labels.append([pad_token_label_id] + label_list + [pad_token_label_id])
token_list, label_list = [], []
continue
splits = line.split("\t")
token = splits[0]
label = splits[1]
if label.startswith("B-"):
label = "B-ENTITY"
elif label.startswith("I-"):
label = "I-ENTITY"
subs_ = tokenizer.tokenize(token)
if len(subs_) > 0:
label_list.extend([label_set.index(label)] + [pad_token_label_id] * (len(subs_) - 1))
token_list.extend(tokenizer.convert_tokens_to_ids(subs_))
else:
print("length of subwords for %s is zero; its label is %s" % (token, label))
return inputs, labels
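# A hypothetical worked example: the CoNLL-style line "Armstrong\tB-PER" is mapped
# to label "B-ENTITY"; if the tokenizer splits "Armstrong" into three subwords,
# the label ids become [label_set.index("B-ENTITY"), pad_token_label_id,
# pad_token_label_id], so the loss is only computed on first subword positions.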
class Dataset(data.Dataset):
def __init__(self, tokenizer, inputs, labels):
self.X = inputs
self.y = labels
self.tokenizer = tokenizer
def __getitem__(self, index):
return self.X[index], self.y[index]
def __len__(self):
return len(self.X)
def collate_fn(self, data):
X, y = zip(*data)
lengths = [len(bs_x) for bs_x in X]
max_lengths = max(lengths)
padded_seqs = torch.LongTensor(len(X), max_lengths).fill_(self.tokenizer.pad_token_id)
padded_y = torch.LongTensor(len(X), max_lengths).fill_(pad_token_label_id)
for i, (seq, y_) in enumerate(zip(X, y)):
length = lengths[i]
padded_seqs[i, :length] = torch.LongTensor(seq)
padded_y[i, :length] = torch.LongTensor(y_)
return padded_seqs, padded_y
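# A hypothetical worked example: two sequences of lengths 3 and 5 are collated
# into tensors of shape (2, 5); padded positions hold tokenizer.pad_token_id in
# the inputs and pad_token_label_id in the labels, so padding never contributes
# to the cross-entropy loss.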
def get_dataloader(model_name, batch_size, data_folder):
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs_train, labels_train = read_ner(tokenizer, os.path.join(data_folder, "train.txt"))
inputs_dev, labels_dev = read_ner(tokenizer, os.path.join(data_folder, "dev.txt"))
inputs_test, labels_test = read_ner(tokenizer, os.path.join(data_folder, "test.txt"))
logger.info("conll2003 dataset: train size: %d; dev size %d; test size: %d" % (len(inputs_train), len(inputs_dev), len(inputs_test)))
dataset_train = Dataset(tokenizer, inputs_train, labels_train)
dataset_dev = Dataset(tokenizer, inputs_dev, labels_dev)
dataset_test = Dataset(tokenizer, inputs_test, labels_test)
dataloader_train = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True, collate_fn=dataset_train.collate_fn)
dataloader_dev = DataLoader(dataset=dataset_dev, batch_size=batch_size, shuffle=False, collate_fn=dataset_dev.collate_fn)
dataloader_test = DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False, collate_fn=dataset_test.collate_fn)
return dataloader_train, dataloader_dev, dataloader_test
#!/usr/bin/env python
# Python version of the evaluation script from CoNLL'00-
# Intentional differences:
# - accept any space as delimiter by default
# - optional file argument (default STDIN)
# - option to set boundary (-b argument)
# - LaTeX output (-l argument) not supported
# - raw tags (-r argument) not supported
import sys
import re
from collections import defaultdict, namedtuple
ANY_SPACE = '<SPACE>'
class FormatError(Exception):
pass
Metrics = namedtuple('Metrics', 'tp fp fn prec rec fscore')
class EvalCounts(object):
def __init__(self):
self.correct_chunk = 0 # number of correctly identified chunks
self.correct_tags = 0 # number of correct chunk tags
self.found_correct = 0 # number of chunks in corpus
self.found_guessed = 0 # number of identified chunks
self.token_counter = 0 # token counter (ignores sentence breaks)
# counts by type
self.t_correct_chunk = defaultdict(int)
self.t_found_correct = defaultdict(int)
self.t_found_guessed = defaultdict(int)
def parse_args(argv):
import argparse
parser = argparse.ArgumentParser(
description='evaluate tagging results using CoNLL criteria',
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
arg = parser.add_argument
arg('-b', '--boundary', metavar='STR', default='-X-',
help='sentence boundary')
arg('-d', '--delimiter', metavar='CHAR', default=ANY_SPACE,
help='character delimiting items in input')
arg('-o', '--otag', metavar='CHAR', default='O',
help='alternative outside tag')
arg('file', nargs='?', default=None)
return parser.parse_args(argv)
def parse_tag(t):
m = re.match(r'^([^-]*)-(.*)$', t)
return m.groups() if m else (t, '')
def evaluate(lines, options=None):
if options is None:
options = parse_args([]) # use defaults
counts = EvalCounts()
num_features = None # number of features per line
    in_correct = False        # whether the chunk currently being processed is correct so far
    last_correct = 'O'        # previous chunk tag in corpus
    last_correct_type = ''    # type of previous chunk tag in corpus
    last_guessed = 'O'        # previously identified chunk tag
    last_guessed_type = ''    # type of previously identified chunk tag
for line in lines:
line = line.rstrip('\r\n')
if options.delimiter == ANY_SPACE:
features = line.split()
else:
features = line.split(options.delimiter)
if num_features is None:
num_features = len(features)
elif num_features != len(features) and len(features) != 0:
raise FormatError('unexpected number of features: %d (%d)' %
(len(features), num_features))
if len(features) == 0 or features[0] == options.boundary:
features = [options.boundary, 'O', 'O']
if len(features) < 3:
raise FormatError('unexpected number of features in line %s' % line)
guessed, guessed_type = parse_tag(features.pop())
correct, correct_type = parse_tag(features.pop())
first_item = features.pop(0)
if first_item == options.boundary:
guessed = 'O'
end_correct = end_of_chunk(last_correct, correct,
last_correct_type, correct_type)
end_guessed = end_of_chunk(last_guessed, guessed,
last_guessed_type, guessed_type)
start_correct = start_of_chunk(last_correct, correct,
last_correct_type, correct_type)
start_guessed = start_of_chunk(last_guessed, guessed,
last_guessed_type, guessed_type)
if in_correct:
if (end_correct and end_guessed and
last_guessed_type == last_correct_type):
in_correct = False
counts.correct_chunk += 1
counts.t_correct_chunk[last_correct_type] += 1
elif (end_correct != end_guessed or guessed_type != correct_type):
in_correct = False
if start_correct and start_guessed and guessed_type == correct_type:
in_correct = True
if start_correct:
counts.found_correct += 1
counts.t_found_correct[correct_type] += 1
if start_guessed:
counts.found_guessed += 1
counts.t_found_guessed[guessed_type] += 1
if first_item != options.boundary:
if correct == guessed and guessed_type == correct_type:
counts.correct_tags += 1
counts.token_counter += 1
last_guessed = guessed
last_correct = correct
last_guessed_type = guessed_type
last_correct_type = correct_type
if in_correct:
counts.correct_chunk += 1
counts.t_correct_chunk[last_correct_type] += 1
return counts
def uniq(iterable):
seen = set()
return [i for i in iterable if not (i in seen or seen.add(i))]
def calculate_metrics(correct, guessed, total):
tp, fp, fn = correct, guessed-correct, total-correct
p = 0 if tp + fp == 0 else 1.*tp / (tp + fp)
r = 0 if tp + fn == 0 else 1.*tp / (tp + fn)
f = 0 if p + r == 0 else 2 * p * r / (p + r)
return Metrics(tp, fp, fn, p, r, f)
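# A hypothetical worked example: calculate_metrics(8, 10, 12) gives tp=8, fp=2,
# fn=4, precision = 8/10 = 0.8, recall = 8/12 ~ 0.667 and
# fscore = 2 * 0.8 * 0.667 / (0.8 + 0.667) ~ 0.727.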
def metrics(counts):
c = counts
overall = calculate_metrics(
c.correct_chunk, c.found_guessed, c.found_correct
)
by_type = {}
for t in uniq(list(c.t_found_correct) + list(c.t_found_guessed)):
by_type[t] = calculate_metrics(
c.t_correct_chunk[t], c.t_found_guessed[t], c.t_found_correct[t]
)
return overall, by_type
def report(counts, out=None):
if out is None:
out = sys.stdout
overall, by_type = metrics(counts)
c = counts
# out.write('processed %d tokens with %d phrases; ' %
# (c.token_counter, c.found_correct))
# out.write('found: %d phrases; correct: %d.\n' %
# (c.found_guessed, c.correct_chunk))
results = {}
if c.token_counter > 0:
results["fb1"] = 100.*overall.fscore
# comment it to not print details
# for i, m in sorted(by_type.items()):
# print('%17s: ' % i)
# print('precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f %d\n' % (100.*m.prec, 100.*m.rec, 100.*m.fscore, c.t_found_guessed[i]))
return results
def end_of_chunk(prev_tag, tag, prev_type, type_):
# check if a chunk ended between the previous and current word
# arguments: previous and current chunk tags, previous and current types
chunk_end = False
if prev_tag == 'E': chunk_end = True
if prev_tag == 'S': chunk_end = True
if prev_tag == 'B' and tag == 'B': chunk_end = True
if prev_tag == 'B' and tag == 'S': chunk_end = True
if prev_tag == 'B' and tag == 'O': chunk_end = True
if prev_tag == 'I' and tag == 'B': chunk_end = True
if prev_tag == 'I' and tag == 'S': chunk_end = True
if prev_tag == 'I' and tag == 'O': chunk_end = True
if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
chunk_end = True
# these chunks are assumed to have length 1
if prev_tag == ']': chunk_end = True
if prev_tag == '[': chunk_end = True
return chunk_end
def start_of_chunk(prev_tag, tag, prev_type, type_):
# check if a chunk started between the previous and current word
# arguments: previous and current chunk tags, previous and current types
chunk_start = False
if tag == 'B': chunk_start = True
if tag == 'S': chunk_start = True
if prev_tag == 'E' and tag == 'E': chunk_start = True
if prev_tag == 'E' and tag == 'I': chunk_start = True
if prev_tag == 'S' and tag == 'E': chunk_start = True
if prev_tag == 'S' and tag == 'I': chunk_start = True
if prev_tag == 'O' and tag == 'E': chunk_start = True
if prev_tag == 'O' and tag == 'I': chunk_start = True
if tag != 'O' and tag != '.' and prev_type != type_:
chunk_start = True
# these chunks are assumed to have length 1
if tag == '[': chunk_start = True
if tag == ']': chunk_start = True
return chunk_start
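# A hypothetical worked example with untyped IOB2 tags: for the tag sequence
# O B I B O, start_of_chunk fires at both "B" tags and end_of_chunk fires at the
# "I"->"B" and "B"->"O" transitions, giving two chunks ("B I" and "B").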
def main(argv):
args = parse_args(argv[1:])
if args.file is None:
counts = evaluate(sys.stdin, args)
else:
with open(args.file) as f:
counts = evaluate(f, args)
report(counts)
def conll2002_measure(lines, verbose=False):
counts = evaluate(lines, None)
return report(counts)
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import AutoModel
class EntityTagger(nn.Module):
def __init__(self, params):
super(EntityTagger, self).__init__()
self.num_tag = params.num_tag
self.hidden_dim = params.hidden_dim
self.model = AutoModel.from_pretrained(params.model_name)
self.dropout = nn.Dropout(params.dropout)
self.linear = nn.Linear(self.hidden_dim, self.num_tag)
def forward(self, X):
outputs = self.model(X) # a tuple ((bsz,seq_len,hidden_dim), (bsz, hidden_dim))
outputs = outputs[0] # (bsz, seq_len, hidden_dim)
outputs = self.dropout(outputs)
prediction = self.linear(outputs)
return prediction
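# A hypothetical shape check: for input ids X of shape (2, 16), AutoModel returns
# hidden states of shape (2, 16, hidden_dim) and the linear head maps them to
# per-token logits of shape (2, 16, num_tag).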
import torch
import torch.nn as nn
from src.metrics import *
from src.dataloader import label_set, pad_token_label_id
import os
import numpy as np
from tqdm import tqdm
import logging
logger = logging.getLogger()
class NERTrainer(object):
def __init__(self, params, model):
self.params = params
self.model = model
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=params.lr)
self.loss_fn = nn.CrossEntropyLoss()
self.early_stop = params.early_stop
self.no_improvement_num = 0
self.best_dev_f1 = 0
def train_step(self, X, y):
self.model.train()
preds = self.model(X)
y = y.view(y.size(0)*y.size(1))
preds = preds.view(preds.size(0)*preds.size(1), preds.size(2))
self.optimizer.zero_grad()
loss = self.loss_fn(preds, y)
loss.backward()
self.optimizer.step()
return loss.item()
def train(self, dataloader_train, dataloader_dev, dataloader_test):
logger.info("Start NER training ...")
for e in range(self.params.epoch):
logger.info("============== epoch %d ==============" % e)
loss_list = []
pbar = tqdm(enumerate(dataloader_train), total=len(dataloader_train))
for i, (X, y) in pbar:
X, y = X.cuda(), y.cuda()
loss = self.train_step(X, y)
loss_list.append(loss)
pbar.set_description("(Epoch {}) LOSS:{:.4f}".format(e, np.mean(loss_list)))
logger.info("Finish training epoch %d. loss: %.4f" % (e, np.mean(loss_list)))
logger.info("============== Evaluate epoch %d on Dev Set ==============" % e)
f1_dev = self.evaluate(dataloader_dev)
logger.info("Evaluate on Dev Set. F1: %.4f." % f1_dev)
if f1_dev > self.best_dev_f1:
logger.info("Found better model!!")
self.best_dev_f1 = f1_dev
self.no_improvement_num = 0
self.save_model()
else:
self.no_improvement_num += 1
logger.info("No better model found (%d/%d)" % (self.no_improvement_num, self.early_stop))
if self.no_improvement_num >= self.early_stop:
break
logger.info("============== Evaluate on Test Set ==============")
f1_test = self.evaluate(dataloader_test)
logger.info("Evaluate on Test Set. F1: %.4f." % f1_test)
def evaluate(self, dataloader):
self.model.eval()
pred_list = []
y_list = []
pbar = tqdm(enumerate(dataloader), total=len(dataloader))
for i, (X, y) in pbar:
            y_list.extend(y.data.numpy()) # one gold-label array per sequence in the batch
X = X.cuda()
preds = self.model(X)
pred_list.extend(preds.data.cpu().numpy())
# concatenation
pred_list = np.concatenate(pred_list, axis=0) # (length, num_tag)
pred_list = np.argmax(pred_list, axis=1)
y_list = np.concatenate(y_list, axis=0)
        # calculate f1 score
pred_list = list(pred_list)
y_list = list(y_list)
lines = []
for pred_index, gold_index in zip(pred_list, y_list):
gold_index = int(gold_index)
if gold_index != pad_token_label_id:
pred_token = label_set[pred_index]
gold_token = label_set[gold_index]
lines.append("w" + " " + pred_token + " " + gold_token)
results = conll2002_measure(lines)
f1 = results["fb1"]
return f1
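    # Each line fed to conll2002_measure has the form "w <tag> <tag>"; the dummy
    # token "w" suffices because the measure only reads the two tag columns, and
    # F1 is symmetric in which tag column is gold and which is predicted.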
def save_model(self):
"""
save the best model
"""
saved_path = os.path.join(self.params.saved_folder, self.params.model_name+".pt")
torch.save({
"model": self.model,
}, saved_path)
logger.info("Best model has been saved to %s" % saved_path)
import os
import subprocess
import pickle
import logging
import time
import random
from datetime import timedelta
import numpy as np
def init_experiment(params, logger_filename):
"""
Initialize the experiment:
- save parameters
- create a logger
"""
# save parameters
get_saved_path(params)
pickle.dump(params, open(os.path.join(params.dump_path, "params.pkl"), "wb"))
# create a logger
logger = create_logger(os.path.join(params.dump_path, logger_filename))
logger.info('============ Initialized logger ============')
logger.info('\n'.join('%s: %s' % (k, str(v))
for k, v in sorted(dict(vars(params)).items())))
logger.info('The experiment will be stored in %s\n' % params.dump_path)
return logger
class LogFormatter():
def __init__(self):
self.start_time = time.time()
def format(self, record):
elapsed_seconds = round(record.created - self.start_time)
prefix = "%s - %s - %s" % (
record.levelname,
time.strftime('%x %X'),
timedelta(seconds=elapsed_seconds)
)
message = record.getMessage()
message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3))
return "%s - %s" % (prefix, message) if message else ''
def create_logger(filepath):
# create log formatter
log_formatter = LogFormatter()
# create file handler and set level to debug
if filepath is not None:
file_handler = logging.FileHandler(filepath, "a")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(log_formatter)
# create console handler and set level to info
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(log_formatter)
# create logger and set level to debug
logger = logging.getLogger()
logger.handlers = []
logger.setLevel(logging.DEBUG)
logger.propagate = False
if filepath is not None:
logger.addHandler(file_handler)
logger.addHandler(console_handler)
# reset logger elapsed time
def reset_time():
log_formatter.start_time = time.time()
logger.reset_time = reset_time
return logger
def get_saved_path(params):
"""
create a directory to store the experiment
"""
dump_path = "./" if params.dump_path == "" else params.dump_path
if not os.path.isdir(dump_path):
subprocess.Popen("mkdir -p %s" % dump_path, shell=True).wait()
assert os.path.isdir(dump_path)
# create experiment path if it does not exist
exp_path = os.path.join(dump_path, params.exp_name)
if not os.path.exists(exp_path):
subprocess.Popen("mkdir -p %s" % exp_path, shell=True).wait()
# generate id for this experiment
if params.exp_id == "":
chars = "0123456789"
while True:
exp_id = "".join(random.choice(chars) for _ in range(0, 3))
if not os.path.isdir(os.path.join(exp_path, exp_id)):
break
else:
exp_id = params.exp_id
# update dump_path
params.dump_path = os.path.join(exp_path, exp_id)
if not os.path.isdir(params.dump_path):
subprocess.Popen("mkdir -p %s" % params.dump_path, shell=True).wait()
assert os.path.isdir(params.dump_path)
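# A hypothetical result: with dump_path="logs", exp_name="conll2003" and
# exp_id="1", params.dump_path becomes "logs/conll2003/1", created on demand.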
from src.config import get_params
from src.utils import init_experiment
from src.dataloader import get_dataloader
from src.model import EntityTagger
from src.trainer import NERTrainer
import torch
import numpy as np
import random
def random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
def train_ner(params):
# initialize experiment
logger = init_experiment(params, logger_filename=params.logger_filename)
# dataloader
dataloader_train, dataloader_dev, dataloader_test = get_dataloader(params.model_name, params.batch_size, params.data_folder)
# BERT-based NER Tagger
model = EntityTagger(params)
model.cuda()
# trainer
trainer = NERTrainer(params, model)
trainer.train(dataloader_train, dataloader_dev, dataloader_test)
if __name__ == "__main__":
params = get_params()
random_seed(params.seed)
train_ner(params)
import torch
from megatron import print_rank_0
def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
"""Build attention masks and position id for left to right model."""
micro_batch_size, seq_length = data.size()
# Attention mask
attention_mask = torch.tril(torch.ones((micro_batch_size, seq_length, seq_length), device=data.device)).view(micro_batch_size, 1, seq_length, seq_length)
# mask padded tokens
for b in range(micro_batch_size):
for idx in range(seq_length-1):
if data[b, idx] == eod_token_id:
                # positions after the eod token (padding) must not attend to anything
attention_mask[b, 0, idx+1:, :] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
    # # reset attention mask and position ids
    # # Loop through the batches:
    # for b in range(micro_batch_size):
    #     # Find indices where EOD token is.
    #     eod_index = position_ids[b, data[b] == eod_token_id]
    #     eod_index = eod_index.clone()
    #     # Loop through EOD indices:
# prev_index = 0
# for j in range(eod_index.size()[0]):
# i = eod_index[j]
# # Mask attention loss.
# attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
# # Reset positions.
# position_ids[b, (i + 1):] -= (i + 1 - prev_index)
# prev_index = i + 1
# Convert attention mask to binary:
attention_mask = (attention_mask < 0.5)
return attention_mask, position_ids
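# A hypothetical worked example: for data = [[5, 9, eod, pad]] the lower-triangular
# mask lets position 1 attend to positions 0 and 1; every row after the eod index
# is zeroed, and (attention_mask < 0.5) flips the convention so that True marks
# positions that must NOT be attended to.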
\ No newline at end of file
"""Train dialogue model based on GPT"""
import torch
from functools import partial
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
# from megatron.data.gpt_dataset import build_train_valid_test_datasets
from dialogctrl.dialog_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel
from megatron.training import pretrain
# from megatron.utils import get_ltor_masks_and_position_ids
from dialogctrl.utils import get_ltor_attention_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
def model_provider(pre_process=True, post_process=True):
"""Build the model."""
print_rank_0('building GPT model ...')
model = GPTModel(
num_tokentypes=0,
parallel_output=True,
pre_process=pre_process,
post_process=post_process
)
return model
def get_batch(data_iterator):
"""Generate a batch"""
args = get_args()
tokenizer = get_tokenizer()
# Items and their type.
keys = ['text', 'loss_mask']
datatype = torch.int64
# Broadcast data.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = mpu.broadcast_data(keys, data, datatype)
tokens_ = data_b['text'].long()
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()
loss_mask = data_b['loss_mask'].float()
    # Get the attention_mask and position ids.
attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)
return tokens, labels, loss_mask, attention_mask, position_ids
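# A hypothetical shift example: for text ids [5, 6, 7, 2], tokens = [5, 6, 7] and
# labels = [6, 7, 2], so the model at position i is trained to predict the id at
# position i+1.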
def loss_func(loss_mask, output_tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {'lm loss': averaged_loss[0]}
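# A hypothetical worked example: with per-token losses [0.5, 1.5, 2.5] and
# loss_mask [0, 1, 1], the masked mean is (1.5 + 2.5) / 2 = 2.0, so tokens with
# mask 0 (the dialog context) never contribute to the training loss.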
def forward_step(data_iterator, model):
"""Forward step."""
args = get_args()
timers = get_timers()
# Get the batch.
timers('batch-generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
timers('batch-generator').stop()
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
def train_valid_test_datasets_provider():
"""Build train, valid, and test datasets for dialog/control module"""
args = get_args()
print_rank_0('> building train, validation, and test datasets for %s module ...' % args.train_module)
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_folder=args.data_folder,
dataset_name=args.dataset_name,
train_module=args.train_module,
max_seq_len=args.max_seq_len,
seed=args.seed)
print_rank_0("> finished creating datasets for %s module ..." % args.train_module)
return train_ds, valid_ds, test_ds
if __name__ == "__main__":
pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})