Commit 611961ad authored by piero, committed by Julien Chaumond

Added tqdm to preprocessing

parent afc7dcd9
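The change below threads tqdm/trange progress bars (with ascii=True) through the preprocessing loops and normalizes string quoting to double quotes. As a quick orientation, here is a minimal, self-contained sketch of the tqdm pattern the commit applies; the data and variable names are made up for illustration and are not part of the commit:

from tqdm import tqdm, trange

# Any iterable (a file handle, a DataLoader, a list) can be wrapped in tqdm
# to get a progress bar; ascii=True keeps the bar plain-ASCII, which renders
# cleanly in log files.
lines = ["label\tsome example text"] * 1000

tokenized = []
for i, line in enumerate(tqdm(lines, ascii=True)):
    tokenized.append(line.split("\t"))

# trange(n, ...) is shorthand for tqdm(range(n), ...).
for i in trange(len(tokenized), ascii=True):
    _ = tokenized[i]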
@@ -18,13 +18,14 @@ import torch.utils.data as data
 from nltk.tokenize.treebank import TreebankWordDetokenizer
 from torchtext import data as torchtext_data
 from torchtext import datasets
+from tqdm import tqdm, trange
 from transformers import GPT2Tokenizer, GPT2LMHeadModel
 torch.manual_seed(0)
 np.random.seed(0)
 EPSILON = 1e-10
-device = 'cpu'
+device = "cpu"
 example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
 max_length_seq = 100
@@ -109,8 +110,8 @@ class Dataset(data.Dataset):
     def __getitem__(self, index):
         """Returns one data pair (source and target)."""
         data = {}
-        data['X'] = self.X[index]
-        data['y'] = self.y[index]
+        data["X"] = self.X[index]
+        data["y"] = self.y[index]
         return data
@@ -133,8 +134,8 @@ def collate_fn(data):
     for key in data[0].keys():
         item_info[key] = [d[key] for d in data]
-    x_batch, _ = pad_sequences(item_info['X'])
-    y_batch = torch.tensor(item_info['y'], dtype=torch.long)
+    x_batch, _ = pad_sequences(item_info["X"])
+    y_batch = torch.tensor(item_info["y"], dtype=torch.long)
     return x_batch, y_batch
@@ -144,8 +145,8 @@ def cached_collate_fn(data):
     for key in data[0].keys():
         item_info[key] = [d[key] for d in data]
-    x_batch = torch.cat(item_info['X'], 0)
-    y_batch = torch.tensor(item_info['y'], dtype=torch.long)
+    x_batch = torch.cat(item_info["X"], 0)
+    y_batch = torch.tensor(item_info["y"], dtype=torch.long)
     return x_batch, y_batch
@@ -168,7 +169,7 @@ def train_epoch(data_loader, discriminator, optimizer,
         if batch_idx % log_interval == 0:
             print(
-                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                     epoch + 1,
                     samples_so_far, len(data_loader.dataset),
                     100 * samples_so_far / len(data_loader.dataset), loss.item()
@@ -185,7 +186,7 @@ def evaluate_performance(data_loader, discriminator):
             input_t, target_t = input_t.to(device), target_t.to(device)
             output_t = discriminator(input_t)
             # sum up batch loss
-            test_loss += F.nll_loss(output_t, target_t, reduction='sum').item()
+            test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
             # get the index of the max log-probability
             pred_t = output_t.argmax(dim=1, keepdim=True)
             correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
@@ -193,8 +194,8 @@ def evaluate_performance(data_loader, discriminator):
     test_loss /= len(data_loader.dataset)
     print(
-        'Performance on test set: '
-        'Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
+        "Performance on test set: "
+        "Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
            test_loss, correct, len(data_loader.dataset),
            100. * correct / len(data_loader.dataset)
        )
@@ -208,8 +209,8 @@ def predict(input_sentence, model, classes, cached=False):
         input_t = model.avg_representation(input_t)
     log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
-    print('Input sentence:', input_sentence)
-    print('Predictions:', ", ".join(
+    print("Input sentence:", input_sentence)
+    print("Predictions:", ", ".join(
         "{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
         zip(classes, log_probs)
     ))
@@ -222,7 +223,7 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
     xs = []
     ys = []
-    for batch_idx, (x, y) in enumerate(data_loader):
+    for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
         with torch.no_grad():
             x = x.to(device)
             avg_rep = discriminator.avg_representation(x).cpu().detach()
@@ -240,16 +241,16 @@ def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
 def train_discriminator(
-        dataset, dataset_fp=None, pretrained_model='gpt2-medium',
+        dataset, dataset_fp=None, pretrained_model="gpt2-medium",
         epochs=10, batch_size=64, log_interval=10,
         save_model=False, cached=False, no_cuda=False):
     global device
     device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
-    print('Preprocessing {} dataset...'.format(dataset))
+    print("Preprocessing {} dataset...".format(dataset))
     start = time.time()
-    if dataset == 'SST':
+    if dataset == "SST":
         idx2class = ["positive", "negative", "very positive", "very negative",
                      "neutral"]
         class2idx = {c: i for i, c in enumerate(idx2class)}
@@ -271,7 +272,7 @@ def train_discriminator(
         x = []
         y = []
-        for i in range(len(train_data)):
+        for i in trange(len(train_data), ascii=True):
             seq = TreebankWordDetokenizer().detokenize(
                 vars(train_data[i])["text"]
             )
@@ -283,7 +284,7 @@ def train_discriminator(
         test_x = []
         test_y = []
-        for i in range(len(test_data)):
+        for i in trange(len(test_data), ascii=True):
             seq = TreebankWordDetokenizer().detokenize(
                 vars(test_data[i])["text"]
             )
@@ -301,7 +302,7 @@ def train_discriminator(
             "default_class": 2,
         }
-    elif dataset == 'clickbait':
+    elif dataset == "clickbait":
         idx2class = ["non_clickbait", "clickbait"]
         class2idx = {c: i for i, c in enumerate(idx2class)}
@@ -317,15 +318,16 @@ def train_discriminator(
                 try:
                     data.append(eval(line))
                 except:
-                    print('Error evaluating line {}: {}'.format(
+                    print("Error evaluating line {}: {}".format(
                         i, line
                     ))
                     continue
         x = []
         y = []
-        y = []
-        for i, d in enumerate(data):
+        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
+            for i, line in enumerate(tqdm(f, ascii=True)):
                 try:
+                    d = eval(line)
                     seq = discriminator.tokenizer.encode(d["text"])
                     if len(seq) < max_length_seq:
@@ -338,9 +340,10 @@ def train_discriminator(
                         ))
                         continue
                     x.append(seq)
-                    y.append(d['label'])
+                    y.append(d["label"])
                 except:
-                    print("Error tokenizing line {}, skipping it".format(i))
+                    print("Error evaluating / tokenizing"
+                          " line {}, skipping it".format(i))
                     pass
         full_dataset = Dataset(x, y)
@@ -358,7 +361,7 @@ def train_discriminator(
             "default_class": 1,
         }
-    elif dataset == 'toxic':
+    elif dataset == "toxic":
         idx2class = ["non_toxic", "toxic"]
         class2idx = {c: i for i, c in enumerate(idx2class)}
@@ -368,21 +371,12 @@ def train_discriminator(
             cached_mode=cached
         ).to(device)
-        with open("datasets/toxic/toxic_train.txt") as f:
-            data = []
-            for i, line in enumerate(f):
-                try:
-                    data.append(eval(line))
-                except:
-                    print('Error evaluating line {}: {}'.format(
-                        i, line
-                    ))
-                    continue
         x = []
         y = []
-        for i, d in enumerate(data):
+        with open("datasets/toxic/toxic_train.txt") as f:
+            for i, line in enumerate(tqdm(f, ascii=True)):
                 try:
+                    d = eval(line)
                     seq = discriminator.tokenizer.encode(d["text"])
                     if len(seq) < max_length_seq:
@@ -395,9 +389,10 @@ def train_discriminator(
                         ))
                         continue
                     x.append(seq)
-                    y.append(int(np.sum(d['label']) > 0))
+                    y.append(int(np.sum(d["label"]) > 0))
                 except:
-                    print("Error tokenizing line {}, skipping it".format(i))
+                    print("Error evaluating / tokenizing"
+                          " line {}, skipping it".format(i))
                     pass
         full_dataset = Dataset(x, y)
@@ -415,18 +410,18 @@ def train_discriminator(
             "default_class": 0,
         }
-    else:  # if dataset == 'generic':
+    else:  # if dataset == "generic":
         # This assumes the input dataset is a TSV with the following structure:
         # class \t text
         if dataset_fp is None:
-            raise ValueError('When generic dataset is selected, '
-                             'dataset_fp needs to be specified aswell.')
+            raise ValueError("When generic dataset is selected, "
+                             "dataset_fp needs to be specified aswell.")
         classes = set()
         with open(dataset_fp) as f:
-            csv_reader = csv.reader(f, delimiter='\t')
-            for row in csv_reader:
+            csv_reader = csv.reader(f, delimiter="\t")
+            for row in tqdm(csv_reader, ascii=True):
                 if row:
                     classes.add(row[0])
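The generic branch above expects dataset_fp to point at a tab-separated file with one "class<TAB>text" pair per line. A small illustrative sketch of producing such a file; the file name and rows are made up and not part of the commit:

# Hypothetical example of the "class \t text" layout the generic branch reads.
rows = [
    ("positive", "This movie was a delight from start to finish."),
    ("negative", "I want those two hours of my life back."),
]
with open("my_generic_dataset.tsv", "w") as f:  # made-up path
    for label, text in rows:
        f.write("{}\t{}\n".format(label, text))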
@@ -442,8 +437,8 @@ def train_discriminator(
         x = []
         y = []
         with open(dataset_fp) as f:
-            csv_reader = csv.reader(f, delimiter='\t')
-            for i, row in enumerate(csv_reader):
+            csv_reader = csv.reader(f, delimiter="\t")
+            for i, row in enumerate(tqdm(csv_reader, ascii=True)):
                 if row:
                     label = row[0]
                     text = row[1]
@@ -458,7 +453,8 @@ def train_discriminator(
                         )
                     else:
-                        print("Line {} is longer than maximum length {}".format(
+                        print(
+                            "Line {} is longer than maximum length {}".format(
                             i, max_length_seq
                         ))
                         continue
@@ -487,12 +483,14 @@ def train_discriminator(
         }
     end = time.time()
-    print('Preprocessed {} data points'.format(
+    print("Preprocessed {} data points".format(
         len(train_dataset) + len(test_dataset))
     )
     print("Data preprocessing took: {:.3f}s".format(end - start))
     if cached:
+        print("Building representation cache...")
         start = time.time()
         train_loader = get_cached_data_loader(
@@ -524,7 +522,7 @@ def train_discriminator(
     for epoch in range(epochs):
         start = time.time()
-        print('\nEpoch', epoch + 1)
+        print("\nEpoch", epoch + 1)
         train_epoch(
             discriminator=discriminator,
@@ -553,31 +551,31 @@ def train_discriminator(
                    "{}_classifier_head_epoch_{}.pt".format(dataset, epoch))
-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description='Train a discriminator on top of GPT-2 representations')
-    parser.add_argument('--dataset', type=str, default='SST',
-                        choices=('SST', 'clickbait', 'toxic', 'generic'),
-                        help='dataset to train the discriminator on.'
-                             'In case of generic, the dataset is expected'
-                             'to be a TSBV file with structure: class \\t text')
-    parser.add_argument('--dataset_fp', type=str, default='',
-                        help='File path of the dataset to use. '
-                             'Needed only in case of generic datadset')
-    parser.add_argument('--pretrained_model', type=str, default='gpt2-medium',
-                        help='Pretrained model to use as encoder')
-    parser.add_argument('--epochs', type=int, default=10, metavar='N',
-                        help='Number of training epochs')
-    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
-                        help='input batch size for training (default: 64)')
-    parser.add_argument('--log_interval', type=int, default=10, metavar='N',
-                        help='how many batches to wait before logging training status')
-    parser.add_argument('--save_model', action='store_true',
-                        help='whether to save the model')
-    parser.add_argument('--cached', action='store_true',
-                        help='whether to cache the input representations')
-    parser.add_argument('--no_cuda', action='store_true',
-                        help='use to turn off cuda')
+        description="Train a discriminator on top of GPT-2 representations")
+    parser.add_argument("--dataset", type=str, default="SST",
+                        choices=("SST", "clickbait", "toxic", "generic"),
+                        help="dataset to train the discriminator on."
+                             "In case of generic, the dataset is expected"
+                             "to be a TSBV file with structure: class \\t text")
+    parser.add_argument("--dataset_fp", type=str, default="",
+                        help="File path of the dataset to use. "
+                             "Needed only in case of generic datadset")
+    parser.add_argument("--pretrained_model", type=str, default="gpt2-medium",
+                        help="Pretrained model to use as encoder")
+    parser.add_argument("--epochs", type=int, default=10, metavar="N",
+                        help="Number of training epochs")
+    parser.add_argument("--batch_size", type=int, default=64, metavar="N",
+                        help="input batch size for training (default: 64)")
+    parser.add_argument("--log_interval", type=int, default=10, metavar="N",
+                        help="how many batches to wait before logging training status")
+    parser.add_argument("--save_model", action="store_true",
+                        help="whether to save the model")
+    parser.add_argument("--cached", action="store_true",
+                        help="whether to cache the input representations")
+    parser.add_argument("--no_cuda", action="store_true",
+                        help="use to turn off cuda")
     args = parser.parse_args()
     train_discriminator(**(vars(args)))
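For completeness, a hedged sketch of calling train_discriminator directly, using only the arguments visible in this diff; the module name run_pplm_discrim_train is an assumption, not something this commit confirms:

# Assumption: the script is importable as run_pplm_discrim_train.
from run_pplm_discrim_train import train_discriminator

train_discriminator(
    dataset="SST",                   # or "clickbait", "toxic", "generic" (with dataset_fp=...)
    pretrained_model="gpt2-medium",
    epochs=10,
    batch_size=64,
    log_interval=10,
    cached=True,                     # precompute average representations; now shows a tqdm bar
    save_model=False,
    no_cuda=False,
)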