Unverified Commit 562f8640 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub

Merge branch 'master' into fix-xlnet-squad2.0

parents ca99a2d5 8618bf15
......@@ -21,7 +21,6 @@ import psutil
import time
from tqdm import trange, tqdm
import numpy as np
import psutil
import torch
import torch.nn as nn
......
......@@ -3,4 +3,4 @@ tensorboard>=1.14.0
tensorboardX==1.8
psutil==5.6.3
scipy==1.3.1
transformers==2.0.0
transformers
# Plug and Play Language Models: a Simple Approach to Controlled Text Generation
Authors: [Sumanth Dathathri](https://dathath.github.io/), [Andrea Madotto](https://andreamad8.github.io/), Janice Lan, Jane Hung, Eric Frank, [Piero Molino](https://w4nderlu.st/), [Jason Yosinski](http://yosinski.com/), and [Rosanne Liu](http://www.rosanneliu.com/)
This folder contains the original code used to run the Plug and Play Language Model (PPLM).
Paper link: https://arxiv.org/abs/1912.02164
Blog link: https://eng.uber.com/pplm
Please check out the repo under uber-research for more information: https://github.com/uber-research/PPLM
## Setup
```bash
git clone https://github.com/huggingface/transformers && cd transformers
pip install [--editable] .
pip install nltk torchtext # additional requirements.
cd examples/pplm
```
## PPLM-BoW
### Example command for bag-of-words control
```bash
python run_pplm.py -B military --cond_text "The potato" --length 50 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.03 --window_length 5 --kl_scale 0.01 --gm_scale 0.99 --colorama --sample
```
### Tuning hyperparameters for bag-of-words control
1. Increase `--stepsize` to intensify topic control, and decrease its value to soften the control. `--stepsize 0` recovers the original uncontrolled GPT-2 model.
2. If the generated text is repetitive (e.g. "science science experiment experiment"), there are several options to consider: <br>
a) Reduce the `--stepsize` <br>
b) Increase `--kl_scale` (the KL-loss coefficient) or decrease `--gm_scale` (the gm-scaling term) <br>
c) Add `--grad_length xx`, where xx is an integer <= length (e.g. `--grad_length 30`). An illustrative command combining these adjustments is shown below.
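For instance (values here are illustrative, not tuned), a softer-control run that reduces the step size, raises the KL coefficient, decreases the gm-scaling term, and caps the perturbation horizon with `--grad_length` might look like:
```bash
python run_pplm.py -B military --cond_text "The potato" --length 50 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.02 --gm_scale 0.97 --grad_length 30 --colorama --sample
```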
## PPLM-Discrim
### Example command for discriminator based sentiment control
```bash
python run_pplm.py -D sentiment --class_label 2 --cond_text "My dog died" --length 50 --gamma 1.0 --num_iterations 10 --num_samples 10 --stepsize 0.04 --kl_scale 0.01 --gm_scale 0.95 --sample
```
### Tuning hyperparameters for discriminator control
1. Increase `--stepsize` to intensify attribute control, and decrease its value to soften the control. `--stepsize 0` recovers the original uncontrolled GPT-2 model.
2. Use `--class_label 3` for negative sentiment and `--class_label 2` for positive sentiment; an example negative-sentiment command is shown below.
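For example, a negative-sentiment run would keep the command above and only swap the class label (the other values are illustrative, not tuned):
```bash
python run_pplm.py -D sentiment --class_label 3 --cond_text "My dog died" --length 50 --gamma 1.0 --num_iterations 10 --num_samples 10 --stepsize 0.04 --kl_scale 0.01 --gm_scale 0.95 --sample
```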
import torch
class ClassificationHead(torch.nn.Module):
"""Classification Head for transformer encoders"""
def __init__(self, class_size, embed_size):
super(ClassificationHead, self).__init__()
self.class_size = class_size
self.embed_size = embed_size
# self.mlp1 = torch.nn.Linear(embed_size, embed_size)
# self.mlp2 = (torch.nn.Linear(embed_size, class_size))
self.mlp = torch.nn.Linear(embed_size, class_size)
def forward(self, hidden_state):
# hidden_state = F.relu(self.mlp1(hidden_state))
# hidden_state = self.mlp2(hidden_state)
logits = self.mlp(hidden_state)
return logits
#! /usr/bin/env python3
# coding=utf-8
#Copyright (c) 2019 Uber Technologies, Inc.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
"""
Example command with bag of words:
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
Example command with discriminator:
python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
"""
import argparse
import json
from operator import add
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import trange
from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel
from pplm_classification_head import ClassificationHead
PPLM_BOW = 1
PPLM_DISCRIM = 2
PPLM_BOW_DISCRIM = 3
SMALL_CONST = 1e-15
BIG_CONST = 1e10
BAG_OF_WORDS_ARCHIVE_MAP = {
'legal': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/legal.txt",
'military': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/military.txt",
'politics': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/politics.txt",
'religion': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/religion.txt",
'science': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/science.txt",
'space': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/space.txt",
'technology': "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/bow/technology.txt",
}
DISCRIMINATOR_MODELS_PARAMS = {
"clickbait": {
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/clickbait_classifier_head.pt",
"class_size": 2,
"embed_size": 1024,
"class_vocab": {"non_clickbait": 0, "clickbait": 1},
"default_class": 1,
"pretrained_model": "gpt2-medium",
},
"sentiment": {
"url": "https://s3.amazonaws.com/models.huggingface.co/bert/pplm/discriminators/SST_classifier_head.pt",
"class_size": 5,
"embed_size": 1024,
"class_vocab": {"very_positive": 2, "very_negative": 3},
"default_class": 3,
"pretrained_model": "gpt2-medium",
},
}
def to_var(x, requires_grad=False, volatile=False, device='cuda'):
if torch.cuda.is_available() and device == 'cuda':
x = x.cuda()
elif device != 'cuda':
x = x.to(device)
return Variable(x, requires_grad=requires_grad, volatile=volatile)
def top_k_filter(logits, k, probs=False):
"""
Masks everything but the k top entries as -infinity (-1e10, i.e. -BIG_CONST).
Used to mask logits such that e^-infinity -> 0 won't contribute to the
sum of the denominator.
"""
if k == 0:
return logits
else:
values = torch.topk(logits, k)[0]
batch_mins = values[:, -1].view(-1, 1).expand_as(logits)
if probs:
return torch.where(logits < batch_mins,
torch.ones_like(logits) * 0.0, logits)
return torch.where(logits < batch_mins,
torch.ones_like(logits) * -BIG_CONST,
logits)
def perturb_past(
past,
model,
last,
unpert_past=None,
unpert_logits=None,
accumulated_hidden=None,
grad_norms=None,
stepsize=0.01,
one_hot_bows_vectors=None,
classifier=None,
class_label=None,
loss_type=0,
num_iterations=3,
horizon_length=1,
window_length=0,
decay=False,
gamma=1.5,
kl_scale=0.01,
device='cuda',
):
# Generate initial perturbed past
grad_accumulator = [
(np.zeros(p.shape).astype("float32"))
for p in past
]
if accumulated_hidden is None:
accumulated_hidden = 0
if decay:
decay_mask = torch.arange(
0.,
1.0 + SMALL_CONST,
1.0 / (window_length)
)[1:]
else:
decay_mask = 1.0
# Generate a window mask so that the gradient perturbation is applied only
# to the most recent `window_length` positions of the past
_, _, _, curr_length, _ = past[0].shape
if curr_length > window_length and window_length > 0:
ones_key_val_shape = (
tuple(past[0].shape[:-2])
+ tuple([window_length])
+ tuple(past[0].shape[-1:])
)
zeros_key_val_shape = (
tuple(past[0].shape[:-2])
+ tuple([curr_length - window_length])
+ tuple(past[0].shape[-1:])
)
ones_mask = torch.ones(ones_key_val_shape)
ones_mask = decay_mask * ones_mask.permute(0, 1, 2, 4, 3)
ones_mask = ones_mask.permute(0, 1, 2, 4, 3)
window_mask = torch.cat(
(ones_mask, torch.zeros(zeros_key_val_shape)),
dim=-2
).to(device)
else:
window_mask = torch.ones_like(past[0]).to(device)
# accumulate perturbations for num_iterations
loss_per_iter = []
new_accumulated_hidden = None
for i in range(num_iterations):
print("Iteration ", i + 1)
curr_perturbation = [
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
for p_ in grad_accumulator
]
# Compute hidden using perturbed past
perturbed_past = list(map(add, past, curr_perturbation))
_, _, _, curr_length, _ = curr_perturbation[0].shape
all_logits, _, all_hidden = model(last, past=perturbed_past)
hidden = all_hidden[-1]
new_accumulated_hidden = accumulated_hidden + torch.sum(
hidden,
dim=1
).detach()
# TODO: Check the layer-norm consistency of this with trained discriminator (Sumanth)
logits = all_logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
loss = 0.0
loss_list = []
if loss_type == PPLM_BOW or loss_type == PPLM_BOW_DISCRIM:
for one_hot_bow in one_hot_bows_vectors:
bow_logits = torch.mm(probs, torch.t(one_hot_bow))
bow_loss = -torch.log(torch.sum(bow_logits))
loss += bow_loss
loss_list.append(bow_loss)
print(" pplm_bow_loss:", loss.data.cpu().numpy())
if loss_type == 2 or loss_type == 3:
ce_loss = torch.nn.CrossEntropyLoss()
# TODO: why do we need this assignment rather than using unpert_past directly? (Sumanth)
curr_unpert_past = unpert_past
curr_probs = torch.unsqueeze(probs, dim=1)
wte = model.resize_token_embeddings()
for _ in range(horizon_length):
inputs_embeds = torch.matmul(curr_probs, wte.weight.data)
_, curr_unpert_past, curr_all_hidden = model(
past=curr_unpert_past,
inputs_embeds=inputs_embeds
)
curr_hidden = curr_all_hidden[-1]
new_accumulated_hidden = new_accumulated_hidden + torch.sum(
curr_hidden, dim=1)
prediction = classifier(new_accumulated_hidden /
(curr_length + 1 + horizon_length))
label = torch.tensor(prediction.shape[0] * [class_label],
device=device,
dtype=torch.long)
discrim_loss = ce_loss(prediction, label)
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
loss += discrim_loss
loss_list.append(discrim_loss)
kl_loss = 0.0
if kl_scale > 0.0:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = (
unpert_probs + SMALL_CONST *
(unpert_probs <= SMALL_CONST).float().to(device).detach()
)
correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
device).detach()
corrected_probs = probs + correction.detach()
kl_loss = kl_scale * (
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
)
print(' kl_loss', kl_loss.data.cpu().numpy())
loss += kl_loss
loss_per_iter.append(loss.data.cpu().numpy())
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
# compute gradients
loss.backward()
# calculate gradient norms
if grad_norms is not None and loss_type == PPLM_BOW:
grad_norms = [
torch.max(grad_norms[index], torch.norm(p_.grad * window_mask))
for index, p_ in enumerate(curr_perturbation)
]
else:
grad_norms = [
(torch.norm(p_.grad * window_mask) + SMALL_CONST)
for index, p_ in enumerate(curr_perturbation)
]
# normalize gradients
grad = [
-stepsize *
(p_.grad * window_mask / grad_norms[
index] ** gamma).data.cpu().numpy()
for index, p_ in enumerate(curr_perturbation)
]
# accumulate gradient
grad_accumulator = list(map(add, grad, grad_accumulator))
# reset gradients, just to make sure
for p_ in curr_perturbation:
p_.grad.data.zero_()
# removing past from the graph
new_past = []
for p_ in past:
new_past.append(p_.detach())
past = new_past
# apply the accumulated perturbations to the past
grad_accumulator = [
to_var(torch.from_numpy(p_), requires_grad=True, device=device)
for p_ in grad_accumulator
]
pert_past = list(map(add, past, grad_accumulator))
return pert_past, new_accumulated_hidden, grad_norms, loss_per_iter
def get_classifier(
name: Optional[str], class_label: Union[str, int],
device: str
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
if name is None:
return None, None
params = DISCRIMINATOR_MODELS_PARAMS[name]
classifier = ClassificationHead(
class_size=params['class_size'],
embed_size=params['embed_size']
).to(device)
if "url" in params:
resolved_archive_file = cached_path(params["url"])
elif "path" in params:
resolved_archive_file = params["path"]
else:
raise ValueError("Either url or path have to be specified "
"in the discriminator model parameters")
classifier.load_state_dict(
torch.load(resolved_archive_file, map_location=device))
classifier.eval()
if isinstance(class_label, str):
if class_label in params["class_vocab"]:
label_id = params["class_vocab"][class_label]
else:
label_id = params["default_class"]
print("class_label {} not in class_vocab".format(class_label))
print("available values are: {}".format(params["class_vocab"]))
print("using default class {}".format(label_id))
elif isinstance(class_label, int):
if class_label in set(params["class_vocab"].values()):
label_id = class_label
else:
label_id = params["default_class"]
print("class_label {} not in class_vocab".format(class_label))
print("available values are: {}".format(params["class_vocab"]))
print("using default class {}".format(label_id))
else:
label_id = params["default_class"]
return classifier, label_id
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
List[List[List[int]]]:
bow_indices = []
for id_or_path in bag_of_words_ids_or_paths:
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
filepath = cached_path(BAG_OF_WORDS_ARCHIVE_MAP[id_or_path])
else:
filepath = id_or_path
with open(filepath, "r") as f:
words = f.read().strip().split("\n")
bow_indices.append(
[tokenizer.encode(word.strip(), add_prefix_space=True) for word in
words])
return bow_indices
def build_bows_one_hot_vectors(bow_indices, tokenizer, device='cuda'):
if bow_indices is None:
return None
one_hot_bows_vectors = []
for single_bow in bow_indices:
single_bow = list(filter(lambda x: len(x) <= 1, single_bow))
single_bow = torch.tensor(single_bow).to(device)
num_words = single_bow.shape[0]
one_hot_bow = torch.zeros(num_words, tokenizer.vocab_size).to(device)
one_hot_bow.scatter_(1, single_bow, 1)
one_hot_bows_vectors.append(one_hot_bow)
return one_hot_bows_vectors
def full_text_generation(
model,
tokenizer,
context=None,
num_samples=1,
device="cuda",
bag_of_words=None,
discrim=None,
class_label=None,
length=100,
stepsize=0.02,
temperature=1.0,
top_k=10,
sample=False,
num_iterations=3,
grad_length=10000,
horizon_length=1,
window_length=0,
decay=False,
gamma=1.5,
gm_scale=0.9,
kl_scale=0.01,
**kwargs
):
classifier, class_id = get_classifier(
discrim,
class_label,
device
)
bow_indices = []
if bag_of_words:
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
tokenizer)
if bag_of_words and classifier:
print("Both PPLM-BoW and PPLM-Discrim are on. This is not optimized.")
loss_type = PPLM_BOW_DISCRIM
elif bag_of_words:
loss_type = PPLM_BOW
print("Using PPLM-BoW")
elif classifier is not None:
loss_type = PPLM_DISCRIM
print("Using PPLM-Discrim")
else:
raise Exception("Specify either a bag of words or a discriminator")
unpert_gen_tok_text, _, _ = generate_text_pplm(
model=model,
tokenizer=tokenizer,
context=context,
device=device,
length=length,
sample=sample,
perturb=False
)
if device == 'cuda':
torch.cuda.empty_cache()
pert_gen_tok_texts = []
discrim_losses = []
losses_in_time = []
for i in range(num_samples):
pert_gen_tok_text, discrim_loss, loss_in_time = generate_text_pplm(
model=model,
tokenizer=tokenizer,
context=context,
device=device,
perturb=True,
bow_indices=bow_indices,
classifier=classifier,
class_label=class_id,
loss_type=loss_type,
length=length,
stepsize=stepsize,
temperature=temperature,
top_k=top_k,
sample=sample,
num_iterations=num_iterations,
grad_length=grad_length,
horizon_length=horizon_length,
window_length=window_length,
decay=decay,
gamma=gamma,
gm_scale=gm_scale,
kl_scale=kl_scale,
)
pert_gen_tok_texts.append(pert_gen_tok_text)
if classifier is not None:
discrim_losses.append(discrim_loss.data.cpu().numpy())
losses_in_time.append(loss_in_time)
if device == 'cuda':
torch.cuda.empty_cache()
return unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
def generate_text_pplm(
model,
tokenizer,
context=None,
past=None,
device="cuda",
perturb=True,
bow_indices=None,
classifier=None,
class_label=None,
loss_type=0,
length=100,
stepsize=0.02,
temperature=1.0,
top_k=10,
sample=False,
num_iterations=3,
grad_length=10000,
horizon_length=1,
window_length=0,
decay=False,
gamma=1.5,
gm_scale=0.9,
kl_scale=0.01,
):
output_so_far = None
if context:
context_t = torch.tensor(context, device=device, dtype=torch.long)
while len(context_t.shape) < 2:
context_t = context_t.unsqueeze(0)
output_so_far = context_t
# collect one hot vectors for bags of words
one_hot_bows_vectors = build_bows_one_hot_vectors(bow_indices, tokenizer,
device)
grad_norms = None
last = None
unpert_discrim_loss = 0
loss_in_time = []
for i in trange(length, ascii=True):
# Get past/probs for current output, except for the last token
# Note that GPT takes 2 inputs: past + current_token
# Run the model forward once to obtain the unperturbed logits, past and hidden states
if past is None and output_so_far is not None:
last = output_so_far[:, -1:]
if output_so_far.shape[1] > 1:
_, past, _ = model(output_so_far[:, :-1])
unpert_logits, unpert_past, unpert_all_hidden = model(output_so_far)
unpert_last_hidden = unpert_all_hidden[-1]
# check if we are above the max grad length
if i >= grad_length:
current_stepsize = stepsize * 0
else:
current_stepsize = stepsize
# modify the past if necessary
if not perturb or num_iterations == 0:
pert_past = past
else:
accumulated_hidden = unpert_last_hidden[:, :-1, :]
accumulated_hidden = torch.sum(accumulated_hidden, dim=1)
if past is not None:
pert_past, _, grad_norms, loss_this_iter = perturb_past(
past,
model,
last,
unpert_past=unpert_past,
unpert_logits=unpert_logits,
accumulated_hidden=accumulated_hidden,
grad_norms=grad_norms,
stepsize=current_stepsize,
one_hot_bows_vectors=one_hot_bows_vectors,
classifier=classifier,
class_label=class_label,
loss_type=loss_type,
num_iterations=num_iterations,
horizon_length=horizon_length,
window_length=window_length,
decay=decay,
gamma=gamma,
kl_scale=kl_scale,
device=device,
)
loss_in_time.append(loss_this_iter)
else:
pert_past = past
pert_logits, past, pert_all_hidden = model(last, past=pert_past)
pert_logits = pert_logits[:, -1, :] / temperature # + SMALL_CONST
pert_probs = F.softmax(pert_logits, dim=-1)
if classifier is not None:
ce_loss = torch.nn.CrossEntropyLoss()
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
label = torch.tensor([class_label], device=device,
dtype=torch.long)
unpert_discrim_loss = ce_loss(prediction, label)
print(
"unperturbed discrim loss",
unpert_discrim_loss.data.cpu().numpy()
)
else:
unpert_discrim_loss = 0
# Fuse the modified model and original model
if perturb:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
pert_probs = ((pert_probs ** gm_scale) * (
unpert_probs ** (1 - gm_scale))) # + SMALL_CONST
pert_probs = top_k_filter(pert_probs, k=top_k,
probs=True) # + SMALL_CONST
# rescale
if torch.sum(pert_probs) <= 1:
pert_probs = pert_probs / torch.sum(pert_probs)
else:
pert_logits = top_k_filter(pert_logits, k=top_k) # + SMALL_CONST
pert_probs = F.softmax(pert_logits, dim=-1)
# sample or greedy
if sample:
last = torch.multinomial(pert_probs, num_samples=1)
else:
_, last = torch.topk(pert_probs, k=1, dim=-1)
# update context/output_so_far appending the new token
output_so_far = (
last if output_so_far is None
else torch.cat((output_so_far, last), dim=1)
)
print(tokenizer.decode(output_so_far.tolist()[0]))
return output_so_far, unpert_discrim_loss, loss_in_time
def set_generic_model_params(discrim_weights, discrim_meta):
if discrim_weights is None:
raise ValueError('When using a generic discriminator, '
'discrim_weights need to be specified')
if discrim_meta is None:
raise ValueError('When using a generic discriminator, '
'discrim_meta need to be specified')
with open(discrim_meta, 'r') as discrim_meta_file:
meta = json.load(discrim_meta_file)
meta['path'] = discrim_weights
DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
def run_pplm_example(
pretrained_model="gpt2-medium",
cond_text="",
uncond=False,
num_samples=1,
bag_of_words=None,
discrim=None,
discrim_weights=None,
discrim_meta=None,
class_label=-1,
length=100,
stepsize=0.02,
temperature=1.0,
top_k=10,
sample=False,
num_iterations=3,
grad_length=10000,
horizon_length=1,
window_length=0,
decay=False,
gamma=1.5,
gm_scale=0.9,
kl_scale=0.01,
seed=0,
no_cuda=False,
colorama=False
):
# set Random seed
torch.manual_seed(seed)
np.random.seed(seed)
# set the device
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
if discrim == 'generic':
set_generic_model_params(discrim_weights, discrim_meta)
if discrim is not None:
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
"pretrained_model"
]
print("discrim = {}, pretrained_model set "
"to discriminator's = {}".format(discrim, pretrained_model))
# load pretrained model
model = GPT2LMHeadModel.from_pretrained(
pretrained_model,
output_hidden_states=True
)
model.to(device)
model.eval()
# load tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
# Freeze GPT-2 weights
for param in model.parameters():
param.requires_grad = False
# figure out conditioning text
if uncond:
tokenized_cond_text = tokenizer.encode(
[tokenizer.bos_token]
)
else:
raw_text = cond_text
while not raw_text:
print("Did you forget to add `--cond_text`? ")
raw_text = input("Model prompt >>> ")
tokenized_cond_text = tokenizer.encode(tokenizer.bos_token + raw_text)
print("= Prefix of sentence =")
print(tokenizer.decode(tokenized_cond_text))
print()
# generate unperturbed and perturbed texts
# full_text_generation returns:
# unpert_gen_tok_text, pert_gen_tok_texts, discrim_losses, losses_in_time
unpert_gen_tok_text, pert_gen_tok_texts, _, _ = full_text_generation(
model=model,
tokenizer=tokenizer,
context=tokenized_cond_text,
device=device,
num_samples=num_samples,
bag_of_words=bag_of_words,
discrim=discrim,
class_label=class_label,
length=length,
stepsize=stepsize,
temperature=temperature,
top_k=top_k,
sample=sample,
num_iterations=num_iterations,
grad_length=grad_length,
horizon_length=horizon_length,
window_length=window_length,
decay=decay,
gamma=gamma,
gm_scale=gm_scale,
kl_scale=kl_scale,
)
# untokenize unperturbed text
unpert_gen_text = tokenizer.decode(unpert_gen_tok_text.tolist()[0])
print("=" * 80)
print("= Unperturbed generated text =")
print(unpert_gen_text)
print()
generated_texts = []
bow_word_ids = set()
if bag_of_words and colorama:
bow_indices = get_bag_of_words_indices(bag_of_words.split(";"),
tokenizer)
for single_bow_list in bow_indices:
# filter out words that encode to more than one token
filtered = list(filter(lambda x: len(x) <= 1, single_bow_list))
# w[0] is safe because the filter above keeps only single-token words
bow_word_ids.update(w[0] for w in filtered)
# iterate through the perturbed texts
for i, pert_gen_tok_text in enumerate(pert_gen_tok_texts):
try:
# decode the perturbed text, highlighting BoW words if colorama is enabled
if colorama:
import colorama
pert_gen_text = ''
for word_id in pert_gen_tok_text.tolist()[0]:
if word_id in bow_word_ids:
pert_gen_text += '{}{}{}'.format(
colorama.Fore.RED,
tokenizer.decode([word_id]),
colorama.Style.RESET_ALL
)
else:
pert_gen_text += tokenizer.decode([word_id])
else:
pert_gen_text = tokenizer.decode(pert_gen_tok_text.tolist()[0])
print("= Perturbed generated text {} =".format(i + 1))
print(pert_gen_text)
print()
except:
pass
# keep the prefix, perturbed seq, original seq for each index
generated_texts.append(
(tokenized_cond_text, pert_gen_tok_text, unpert_gen_tok_text)
)
return
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
"--pretrained_model",
"-M",
type=str,
default="gpt2-medium",
help="pretrained model name or path to local checkpoint",
)
parser.add_argument(
"--cond_text", type=str, default="The lake",
help="Prefix texts to condition on"
)
parser.add_argument(
"--uncond", action="store_true",
help="Generate from end-of-text as prefix"
)
parser.add_argument(
"--num_samples",
type=int,
default=1,
help="Number of samples to generate from the modified latents",
)
parser.add_argument(
"--bag_of_words",
"-B",
type=str,
default=None,
help="Bags of words used for PPLM-BoW. "
"Either a BOW id (see list in code) or a filepath. "
"Multiple BoWs separated by ;",
)
parser.add_argument(
"--discrim",
"-D",
type=str,
default=None,
choices=("clickbait", "sentiment", "toxicity", "generic"),
help="Discriminator to use",
)
parser.add_argument('--discrim_weights', type=str, default=None,
help='Weights for the generic discriminator')
parser.add_argument('--discrim_meta', type=str, default=None,
help='Meta information for the generic discriminator')
parser.add_argument(
"--class_label",
type=int,
default=-1,
help="Class label used for the discriminator",
)
parser.add_argument("--length", type=int, default=100)
parser.add_argument("--stepsize", type=float, default=0.02)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_k", type=int, default=10)
parser.add_argument(
"--sample", action="store_true",
help="Generate from end-of-text as prefix"
)
parser.add_argument("--num_iterations", type=int, default=3)
parser.add_argument("--grad_length", type=int, default=10000)
parser.add_argument(
"--window_length",
type=int,
default=0,
help="Length of past which is being optimized; "
"0 corresponds to infinite window length",
)
parser.add_argument(
"--horizon_length",
type=int,
default=1,
help="Length of future to optimize over",
)
parser.add_argument("--decay", action="store_true",
help="whether to decay or not")
parser.add_argument("--gamma", type=float, default=1.5)
parser.add_argument("--gm_scale", type=float, default=0.9)
parser.add_argument("--kl_scale", type=float, default=0.01)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
parser.add_argument("--colorama", action="store_true",
help="colors keywords")
args = parser.parse_args()
run_pplm_example(**vars(args))
#! /usr/bin/env python3
# coding=utf-8
#Copyright (c) 2019 Uber Technologies, Inc.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import argparse
import csv
import json
import math
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim
import torch.optim as optim
import torch.utils.data as data
from nltk.tokenize.treebank import TreebankWordDetokenizer
from torchtext import data as torchtext_data
from torchtext import datasets
from tqdm import tqdm, trange
from transformers import GPT2Tokenizer, GPT2LMHeadModel
from pplm_classification_head import ClassificationHead
torch.manual_seed(0)
np.random.seed(0)
EPSILON = 1e-10
example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
max_length_seq = 100
class Discriminator(torch.nn.Module):
"""Transformer encoder followed by a Classification Head"""
def __init__(
self,
class_size,
pretrained_model="gpt2-medium",
cached_mode=False,
device='cpu'
):
super(Discriminator, self).__init__()
self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
self.embed_size = self.encoder.transformer.config.hidden_size
self.classifier_head = ClassificationHead(
class_size=class_size,
embed_size=self.embed_size
)
self.cached_mode = cached_mode
self.device = device
def get_classifier(self):
return self.classifier_head
def train_custom(self):
for param in self.encoder.parameters():
param.requires_grad = False
self.classifier_head.train()
def avg_representation(self, x):
mask = x.ne(0).unsqueeze(2).repeat(
1, 1, self.embed_size
).float().to(self.device).detach()
hidden, _ = self.encoder.transformer(x)
masked_hidden = hidden * mask
avg_hidden = torch.sum(masked_hidden, dim=1) / (
torch.sum(mask, dim=1).detach() + EPSILON
)
return avg_hidden
def forward(self, x):
if self.cached_mode:
avg_hidden = x.to(self.device)
else:
avg_hidden = self.avg_representation(x.to(self.device))
logits = self.classifier_head(avg_hidden)
probs = F.log_softmax(logits, dim=-1)
return probs
class Dataset(data.Dataset):
def __init__(self, X, y):
"""Reads source and target sequences from txt files."""
self.X = X
self.y = y
def __len__(self):
return len(self.X)
def __getitem__(self, index):
"""Returns one data pair (source and target)."""
data = {}
data["X"] = self.X[index]
data["y"] = self.y[index]
return data
def collate_fn(data):
def pad_sequences(sequences):
lengths = [len(seq) for seq in sequences]
padded_sequences = torch.zeros(
len(sequences),
max(lengths)
).long() # padding value = 0
for i, seq in enumerate(sequences):
end = lengths[i]
padded_sequences[i, :end] = seq[:end]
return padded_sequences, lengths
item_info = {}
for key in data[0].keys():
item_info[key] = [d[key] for d in data]
x_batch, _ = pad_sequences(item_info["X"])
y_batch = torch.tensor(item_info["y"], dtype=torch.long)
return x_batch, y_batch
def cached_collate_fn(data):
item_info = {}
for key in data[0].keys():
item_info[key] = [d[key] for d in data]
x_batch = torch.cat(item_info["X"], 0)
y_batch = torch.tensor(item_info["y"], dtype=torch.long)
return x_batch, y_batch
def train_epoch(data_loader, discriminator, optimizer,
epoch=0, log_interval=10, device='cpu'):
samples_so_far = 0
discriminator.train_custom()
for batch_idx, (input_t, target_t) in enumerate(data_loader):
input_t, target_t = input_t.to(device), target_t.to(device)
optimizer.zero_grad()
output_t = discriminator(input_t)
loss = F.nll_loss(output_t, target_t)
loss.backward(retain_graph=True)
optimizer.step()
samples_so_far += len(input_t)
if batch_idx % log_interval == 0:
print(
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
epoch + 1,
samples_so_far, len(data_loader.dataset),
100 * samples_so_far / len(data_loader.dataset), loss.item()
)
)
def evaluate_performance(data_loader, discriminator, device='cpu'):
discriminator.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for input_t, target_t in data_loader:
input_t, target_t = input_t.to(device), target_t.to(device)
output_t = discriminator(input_t)
# sum up batch loss
test_loss += F.nll_loss(output_t, target_t, reduction="sum").item()
# get the index of the max log-probability
pred_t = output_t.argmax(dim=1, keepdim=True)
correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()
test_loss /= len(data_loader.dataset)
print(
"Performance on test set: "
"Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
test_loss, correct, len(data_loader.dataset),
100. * correct / len(data_loader.dataset)
)
)
def predict(input_sentence, model, classes, cached=False, device='cpu'):
input_t = model.tokenizer.encode(input_sentence)
input_t = torch.tensor([input_t], dtype=torch.long, device=device)
if cached:
input_t = model.avg_representation(input_t)
log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
print("Input sentence:", input_sentence)
print("Predictions:", ", ".join(
"{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
zip(classes, log_probs)
))
def get_cached_data_loader(dataset, batch_size, discriminator,
shuffle=False, device='cpu'):
data_loader = torch.utils.data.DataLoader(dataset=dataset,
batch_size=batch_size,
collate_fn=collate_fn)
xs = []
ys = []
for batch_idx, (x, y) in enumerate(tqdm(data_loader, ascii=True)):
with torch.no_grad():
x = x.to(device)
avg_rep = discriminator.avg_representation(x).cpu().detach()
avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
xs += avg_rep_list
ys += y.cpu().numpy().tolist()
data_loader = torch.utils.data.DataLoader(
dataset=Dataset(xs, ys),
batch_size=batch_size,
shuffle=shuffle,
collate_fn=cached_collate_fn)
return data_loader
def train_discriminator(
dataset, dataset_fp=None, pretrained_model="gpt2-medium",
epochs=10, batch_size=64, log_interval=10,
save_model=False, cached=False, no_cuda=False):
device = "cuda" if torch.cuda.is_available() and not no_cuda else "cpu"
print("Preprocessing {} dataset...".format(dataset))
start = time.time()
if dataset == "SST":
idx2class = ["positive", "negative", "very positive", "very negative",
"neutral"]
class2idx = {c: i for i, c in enumerate(idx2class)}
discriminator = Discriminator(
class_size=len(idx2class),
pretrained_model=pretrained_model,
cached_mode=cached,
device=device
).to(device)
text = torchtext_data.Field()
label = torchtext_data.Field(sequential=False)
train_data, val_data, test_data = datasets.SST.splits(
text,
label,
fine_grained=True,
train_subtrees=True,
)
x = []
y = []
for i in trange(len(train_data), ascii=True):
seq = TreebankWordDetokenizer().detokenize(
vars(train_data[i])["text"]
)
seq = discriminator.tokenizer.encode(seq)
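# 50256 is the GPT-2 <|endoftext|> token id, prepended here as a BOS marker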
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
x.append(seq)
y.append(class2idx[vars(train_data[i])["label"]])
train_dataset = Dataset(x, y)
test_x = []
test_y = []
for i in trange(len(test_data), ascii=True):
seq = TreebankWordDetokenizer().detokenize(
vars(test_data[i])["text"]
)
seq = discriminator.tokenizer.encode(seq)
seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
test_x.append(seq)
test_y.append(class2idx[vars(test_data[i])["label"]])
test_dataset = Dataset(test_x, test_y)
discriminator_meta = {
"class_size": len(idx2class),
"embed_size": discriminator.embed_size,
"pretrained_model": pretrained_model,
"class_vocab": class2idx,
"default_class": 2,
}
elif dataset == "clickbait":
idx2class = ["non_clickbait", "clickbait"]
class2idx = {c: i for i, c in enumerate(idx2class)}
discriminator = Discriminator(
class_size=len(idx2class),
pretrained_model=pretrained_model,
cached_mode=cached,
device=device
).to(device)
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
data = []
for i, line in enumerate(f):
try:
data.append(eval(line))
except:
print("Error evaluating line {}: {}".format(
i, line
))
continue
x = []
y = []
with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
for i, line in enumerate(tqdm(f, ascii=True)):
try:
d = eval(line)
seq = discriminator.tokenizer.encode(d["text"])
if len(seq) < max_length_seq:
seq = torch.tensor(
[50256] + seq, device=device, dtype=torch.long
)
else:
print("Line {} is longer than maximum length {}".format(
i, max_length_seq
))
continue
x.append(seq)
y.append(d["label"])
except:
print("Error evaluating / tokenizing"
" line {}, skipping it".format(i))
pass
full_dataset = Dataset(x, y)
train_size = int(0.9 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
full_dataset, [train_size, test_size]
)
discriminator_meta = {
"class_size": len(idx2class),
"embed_size": discriminator.embed_size,
"pretrained_model": pretrained_model,
"class_vocab": class2idx,
"default_class": 1,
}
elif dataset == "toxic":
idx2class = ["non_toxic", "toxic"]
class2idx = {c: i for i, c in enumerate(idx2class)}
discriminator = Discriminator(
class_size=len(idx2class),
pretrained_model=pretrained_model,
cached_mode=cached,
device=device
).to(device)
x = []
y = []
with open("datasets/toxic/toxic_train.txt") as f:
for i, line in enumerate(tqdm(f, ascii=True)):
try:
d = eval(line)
seq = discriminator.tokenizer.encode(d["text"])
if len(seq) < max_length_seq:
seq = torch.tensor(
[50256] + seq, device=device, dtype=torch.long
)
else:
print("Line {} is longer than maximum length {}".format(
i, max_length_seq
))
continue
x.append(seq)
y.append(int(np.sum(d["label"]) > 0))
except:
print("Error evaluating / tokenizing"
" line {}, skipping it".format(i))
pass
full_dataset = Dataset(x, y)
train_size = int(0.9 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
full_dataset, [train_size, test_size]
)
discriminator_meta = {
"class_size": len(idx2class),
"embed_size": discriminator.embed_size,
"pretrained_model": pretrained_model,
"class_vocab": class2idx,
"default_class": 0,
}
else: # if dataset == "generic":
# This assumes the input dataset is a TSV with the following structure:
# class \t text
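# e.g. a (hypothetical) two-line file:
#   positive \t I loved this movie.
#   negative \t The plot was dull.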
if dataset_fp is None:
raise ValueError("When generic dataset is selected, "
"dataset_fp needs to be specified aswell.")
classes = set()
with open(dataset_fp) as f:
csv_reader = csv.reader(f, delimiter="\t")
for row in tqdm(csv_reader, ascii=True):
if row:
classes.add(row[0])
idx2class = sorted(classes)
class2idx = {c: i for i, c in enumerate(idx2class)}
discriminator = Discriminator(
class_size=len(idx2class),
pretrained_model=pretrained_model,
cached_mode=cached,
device=device
).to(device)
x = []
y = []
with open(dataset_fp) as f:
csv_reader = csv.reader(f, delimiter="\t")
for i, row in enumerate(tqdm(csv_reader, ascii=True)):
if row:
label = row[0]
text = row[1]
try:
seq = discriminator.tokenizer.encode(text)
if (len(seq) < max_length_seq):
seq = torch.tensor(
[50256] + seq,
device=device,
dtype=torch.long
)
else:
print(
"Line {} is longer than maximum length {}".format(
i, max_length_seq
))
continue
x.append(seq)
y.append(class2idx[label])
except:
print("Error tokenizing line {}, skipping it".format(i))
pass
full_dataset = Dataset(x, y)
train_size = int(0.9 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
full_dataset,
[train_size, test_size]
)
discriminator_meta = {
"class_size": len(idx2class),
"embed_size": discriminator.embed_size,
"pretrained_model": pretrained_model,
"class_vocab": class2idx,
"default_class": 0,
}
end = time.time()
print("Preprocessed {} data points".format(
len(train_dataset) + len(test_dataset))
)
print("Data preprocessing took: {:.3f}s".format(end - start))
if cached:
print("Building representation cache...")
start = time.time()
train_loader = get_cached_data_loader(
train_dataset, batch_size, discriminator,
shuffle=True, device=device
)
test_loader = get_cached_data_loader(
test_dataset, batch_size, discriminator, device=device
)
end = time.time()
print("Building representation cache took: {:.3f}s".format(end - start))
else:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
collate_fn=collate_fn)
if save_model:
with open("{}_classifier_head_meta.json".format(dataset),
"w") as meta_file:
json.dump(discriminator_meta, meta_file)
optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)
for epoch in range(epochs):
start = time.time()
print("\nEpoch", epoch + 1)
train_epoch(
discriminator=discriminator,
data_loader=train_loader,
optimizer=optimizer,
epoch=epoch,
log_interval=log_interval,
device=device
)
evaluate_performance(
data_loader=test_loader,
discriminator=discriminator,
device=device
)
end = time.time()
print("Epoch took: {:.3f}s".format(end - start))
print("\nExample prediction")
predict(example_sentence, discriminator, idx2class,
cached=cached, device=device)
if save_model:
# torch.save(discriminator.state_dict(),
# "{}_discriminator_{}.pt".format(
# args.dataset, epoch + 1
# ))
torch.save(discriminator.get_classifier().state_dict(),
"{}_classifier_head_epoch_{}.pt".format(dataset,
epoch + 1))
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Train a discriminator on top of GPT-2 representations")
parser.add_argument("--dataset", type=str, default="SST",
choices=("SST", "clickbait", "toxic", "generic"),
help="dataset to train the discriminator on."
"In case of generic, the dataset is expected"
"to be a TSBV file with structure: class \\t text")
parser.add_argument("--dataset_fp", type=str, default="",
help="File path of the dataset to use. "
"Needed only in case of generic datadset")
parser.add_argument("--pretrained_model", type=str, default="gpt2-medium",
help="Pretrained model to use as encoder")
parser.add_argument("--epochs", type=int, default=10, metavar="N",
help="Number of training epochs")
parser.add_argument("--batch_size", type=int, default=64, metavar="N",
help="input batch size for training (default: 64)")
parser.add_argument("--log_interval", type=int, default=10, metavar="N",
help="how many batches to wait before logging training status")
parser.add_argument("--save_model", action="store_true",
help="whether to save the model")
parser.add_argument("--cached", action="store_true",
help="whether to cache the input representations")
parser.add_argument("--no_cuda", action="store_true",
help="use to turn off cuda")
args = parser.parse_args()
train_discriminator(**(vars(args)))
......@@ -247,7 +247,11 @@ def main():
out = out[:, len(context_tokens):].tolist()
for o in out:
text = tokenizer.decode(o, clean_up_tokenization_spaces=True)
text = text[: text.find(args.stop_token) if args.stop_token else None]
if args.stop_token:
index = text.find(args.stop_token)
if index == -1:
index = None
text = text[:index]
print(text)
......
......@@ -22,6 +22,7 @@ import glob
import logging
import os
import random
import json
import numpy as np
import torch
......@@ -47,7 +48,14 @@ from transformers import (WEIGHTS_NAME, BertConfig,
XLNetTokenizer,
DistilBertConfig,
DistilBertForSequenceClassification,
DistilBertTokenizer)
DistilBertTokenizer,
AlbertConfig,
AlbertForSequenceClassification,
AlbertTokenizer,
XLMRobertaConfig,
XLMRobertaForSequenceClassification,
XLMRobertaTokenizer,
)
from transformers import AdamW, get_linear_schedule_with_warmup
......@@ -66,7 +74,9 @@ MODEL_CLASSES = {
'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
'albert': (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
'xlmroberta': (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
}
......@@ -99,6 +109,7 @@ def train(args, train_dataset, model, tokenizer):
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
if args.fp16:
......@@ -158,7 +169,7 @@ def train(args, train_dataset, model, tokenizer):
loss.backward()
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0 and not args.tpu:
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
......@@ -170,15 +181,23 @@ def train(args, train_dataset, model, tokenizer):
global_step += 1
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
# Log metrics
logs = {}
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model, tokenizer)
for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
eval_key = 'eval_{}'.format(key)
logs[eval_key] = value
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
learning_rate_scalar = scheduler.get_lr()[0]
logs['learning_rate'] = learning_rate_scalar
logs['loss'] = loss_scalar
logging_loss = tr_loss
for key, value in logs.items():
tb_writer.add_scalar(key, value, global_step)
print(json.dumps({**logs, **{'step': global_step}}))
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
......@@ -189,11 +208,6 @@ def train(args, train_dataset, model, tokenizer):
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
if args.tpu:
args.xla_model.optimizer_step(optimizer, barrier=True)
model.zero_grad()
global_step += 1
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
......@@ -221,7 +235,7 @@ def evaluate(args, model, tokenizer, prefix=""):
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu eval
......@@ -294,7 +308,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
else:
logger.info("Creating features from dataset file at %s", args.data_dir)
label_list = processor.get_labels()
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta']:
if task in ['mnli', 'mnli-mm'] and args.model_type in ['roberta', 'xlmroberta']:
# HACK(label indices are swapped in RoBERTa pretrained model)
label_list[1], label_list[2] = label_list[2], label_list[1]
examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
......@@ -370,7 +384,7 @@ def main():
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float,
help="Weight deay if we apply some.")
help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
......@@ -397,15 +411,6 @@ def main():
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--tpu', action='store_true',
help="Whether to run on the TPU defined in the environment variables")
parser.add_argument('--tpu_ip_address', type=str, default='',
help="TPU IP address if none are set in the environment variables")
parser.add_argument('--tpu_name', type=str, default='',
help="TPU name if none are set in the environment variables")
parser.add_argument('--xrt_tpu_config', type=str, default='',
help="XRT TPU config if none are set in the environment variables")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
......@@ -439,23 +444,6 @@ def main():
args.n_gpu = 1
args.device = device
if args.tpu:
if args.tpu_ip_address:
os.environ["TPU_IP_ADDRESS"] = args.tpu_ip_address
if args.tpu_name:
os.environ["TPU_NAME"] = args.tpu_name
if args.xrt_tpu_config:
os.environ["XRT_TPU_CONFIG"] = args.xrt_tpu_config
assert "TPU_IP_ADDRESS" in os.environ
assert "TPU_NAME" in os.environ
assert "XRT_TPU_CONFIG" in os.environ
import torch_xla
import torch_xla.core.xla_model as xm
args.device = xm.xla_device()
args.xla_model = xm
# Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
......@@ -509,7 +497,7 @@ def main():
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0) and not args.tpu:
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Create output directory if needed
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
......
......@@ -47,7 +47,8 @@ from transformers import (WEIGHTS_NAME, AdamW, get_linear_schedule_with_warmup,
GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer,
RobertaConfig, RobertaForMaskedLM, RobertaTokenizer,
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer,
CamembertConfig, CamembertForMaskedLM, CamembertTokenizer)
logger = logging.getLogger(__name__)
......@@ -58,7 +59,8 @@ MODEL_CLASSES = {
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
'bert': (BertConfig, BertForMaskedLM, BertTokenizer),
'roberta': (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer)
'distilbert': (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
'camembert': (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer)
}
......@@ -68,7 +70,7 @@ class TextDataset(Dataset):
directory, filename = os.path.split(file_path)
cached_features_file = os.path.join(directory, args.model_name_or_path + '_cached_lm_' + str(block_size) + '_' + filename)
if os.path.exists(cached_features_file):
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s", cached_features_file)
with open(cached_features_file, 'rb') as handle:
self.examples = pickle.load(handle)
......@@ -186,6 +188,13 @@ def train(args, train_dataset, model, tokenizer):
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
# Check if saved optimizer or scheduler states exist
if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
# Load in optimizer and scheduler states
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))
if args.fp16:
try:
from apex import amp
......@@ -214,13 +223,37 @@ def train(args, train_dataset, model, tokenizer):
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
if os.path.exists(args.model_name_or_path):
# set global_step to the global_step of the last saved checkpoint from the model path
global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", global_step)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
tr_loss, logging_loss = 0.0, 0.0
model_to_resize = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_resize.resize_token_embeddings(len(tokenizer))
model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
continue
inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
inputs = inputs.to(args.device)
labels = labels.to(args.device)
......@@ -268,11 +301,17 @@ def train(args, train_dataset, model, tokenizer):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
_rotate_checkpoints(args, checkpoint_prefix)
torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
logger.info("Saving optimizer and scheduler states to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
......@@ -297,7 +336,7 @@ def evaluate(args, model, tokenizer, prefix=""):
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu evaluate
......@@ -391,7 +430,7 @@ def main():
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float,
help="Weight deay if we apply some.")
help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
......@@ -431,7 +470,7 @@ def main():
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
args = parser.parse_args()
if args.model_type in ["bert", "roberta", "distilbert"] and not args.mlm:
if args.model_type in ["bert", "roberta", "distilbert", "camembert"] and not args.mlm:
raise ValueError("BERT and RoBERTa do not have LM heads but masked LM heads. They must be run using the --mlm "
"flag (masked language modeling).")
if args.eval_data_file is None and args.do_eval:
......
......@@ -226,7 +226,7 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu evaluate
......
......@@ -37,17 +37,22 @@ from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import WEIGHTS_NAME, BertConfig, BertForTokenClassification, BertTokenizer
from transformers import RobertaConfig, RobertaForTokenClassification, RobertaTokenizer
from transformers import DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer
from transformers import CamembertConfig, CamembertForTokenClassification, CamembertTokenizer
from transformers import XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer
logger = logging.getLogger(__name__)
ALL_MODELS = sum(
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig,
CamembertConfig, XLMRobertaConfig)),
())
MODEL_CLASSES = {
"bert": (BertConfig, BertForTokenClassification, BertTokenizer),
"roberta": (RobertaConfig, RobertaForTokenClassification, RobertaTokenizer),
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer)
"distilbert": (DistilBertConfig, DistilBertForTokenClassification, DistilBertTokenizer),
"camembert": (CamembertConfig, CamembertForTokenClassification, CamembertTokenizer),
"xlmroberta": (XLMRobertaConfig, XLMRobertaForTokenClassification, XLMRobertaTokenizer),
}
......@@ -125,7 +130,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
"attention_mask": batch[1],
"labels": batch[3]}
if args.model_type != "distilbert":
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
......@@ -215,7 +220,7 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
"attention_mask": batch[1],
"labels": batch[3]}
if args.model_type != "distilbert":
inputs["token_type_ids"]: batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
inputs["token_type_ids"] = batch[2] if args.model_type in ["bert", "xlnet"] else None # XLM and RoBERTa don"t use segment_ids
outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2]
......
......@@ -16,6 +16,8 @@
""" Finetuning the library models for question-answering on SQuAD (DistilBERT, Bert, XLM, XLNet)."""
from __future__ import absolute_import, division, print_function
from transformers.data.processors.squad import SquadV1Processor, SquadV2Processor, SquadResult
from transformers.data.metrics.squad_metrics import compute_predictions_logits, compute_predictions_log_probs, squad_evaluate
import argparse
import logging
......@@ -23,11 +25,9 @@ import os
import random
import glob
import timeit
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data.distributed import DistributedSampler
try:
......@@ -43,18 +43,12 @@ from transformers import (WEIGHTS_NAME, BertConfig,
XLMTokenizer, XLNetConfig,
XLNetForQuestionAnswering,
XLNetTokenizer,
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
from transformers import AdamW, get_linear_schedule_with_warmup
DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer,
AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer,
XLMConfig, XLMForQuestionAnswering, XLMTokenizer,
)
from utils_squad import (read_squad_examples, convert_examples_to_features,
RawResult, write_predictions,
RawResultExtended, write_predictions_extended)
# The follwing import is the official SQuAD evaluation script (2.0).
# You can remove it from the dependencies if you are using this script outside of the library
# We've added it here for automated tests (see examples/test_examples.py file)
from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features
logger = logging.getLogger(__name__)
......@@ -65,7 +59,8 @@ MODEL_CLASSES = {
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer)
'distilbert': (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
'albert': (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
}
def set_seed(args):
......@@ -101,11 +96,13 @@ def train(args, train_dataset, model, tokenizer):
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
......@@ -128,22 +125,28 @@ def train(args, train_dataset, model, tokenizer):
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
global_step = 1
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args) # Added here for reproducibility (even between python 2 and 3)
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
model.train()
batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0],
inputs = {
'input_ids': batch[0],
'attention_mask': batch[1],
'start_positions': batch[3],
'end_positions': batch[4]}
'end_positions': batch[4]
}
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]
if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[5],
'p_mask': batch[6]})
......@@ -175,8 +178,8 @@ def train(args, train_dataset, model, tokenizer):
model.zero_grad()
global_step += 1
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
# Log metrics
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model, tokenizer)
for key, value in results.items():
......@@ -185,8 +188,8 @@ def train(args, train_dataset, model, tokenizer):
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
logging_loss = tr_loss
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
......@@ -215,50 +218,72 @@ def evaluate(args, model, tokenizer, prefix=""):
os.makedirs(args.output_dir)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
eval_sampler = SequentialSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu evaluate
if args.n_gpu > 1:
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
model = torch.nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
all_results = []
start_time = timeit.default_timer()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
model.eval()
batch = tuple(t.to(args.device) for t in batch)
with torch.no_grad():
inputs = {'input_ids': batch[0],
inputs = {
'input_ids': batch[0],
'attention_mask': batch[1]
}
if args.model_type != 'distilbert':
inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2] # XLM doesn't use segment_ids
example_indices = batch[3]
# XLNet and XLM use more arguments for their predictions
if args.model_type in ['xlnet', 'xlm']:
inputs.update({'cls_index': batch[4],
'p_mask': batch[5]})
inputs.update({'cls_index': batch[4], 'p_mask': batch[5]})
outputs = model(**inputs)
for i, example_index in enumerate(example_indices):
eval_feature = features[example_index.item()]
unique_id = int(eval_feature.unique_id)
if args.model_type in ['xlnet', 'xlm']:
# XLNet uses a more complex post-processing procedure
result = RawResultExtended(unique_id = unique_id,
start_top_log_probs = to_list(outputs[0][i]),
start_top_index = to_list(outputs[1][i]),
end_top_log_probs = to_list(outputs[2][i]),
end_top_index = to_list(outputs[3][i]),
cls_logits = to_list(outputs[4][i]))
output = [to_list(output[i]) for output in outputs]
# Some models (XLNet, XLM) use 5 arguments for their predictions, while the other "simpler"
# models only use two.
if len(output) >= 5:
start_logits = output[0]
start_top_index = output[1]
end_logits = output[2]
end_top_index = output[3]
cls_logits = output[4]
result = SquadResult(
unique_id, start_logits, end_logits,
start_top_index=start_top_index,
end_top_index=end_top_index,
cls_logits=cls_logits
)
else:
result = RawResult(unique_id = unique_id,
start_logits = to_list(outputs[0][i]),
end_logits = to_list(outputs[1][i]))
start_logits, end_logits = output
result = SquadResult(
unique_id, start_logits, end_logits
)
all_results.append(result)
evalTime = timeit.default_timer() - start_time
......@@ -267,63 +292,84 @@ def evaluate(args, model, tokenizer, prefix=""):
# Compute predictions
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
if args.version_2_with_negative:
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
else:
output_null_log_odds_file = None
# XLNet and XLM use a more complex post-processing procedure
if args.model_type in ['xlnet', 'xlm']:
# XLNet uses a more complex post-processing procedure
write_predictions_extended(examples, features, all_results, args.n_best_size,
start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top
predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
args.max_answer_length, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.predict_file,
model.config.start_n_top, model.config.end_n_top,
output_nbest_file, output_null_log_odds_file,
start_n_top, end_n_top,
args.version_2_with_negative, tokenizer, args.verbose_logging)
else:
write_predictions(examples, features, all_results, args.n_best_size,
predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
args.max_answer_length, args.do_lower_case, output_prediction_file,
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
args.version_2_with_negative, args.null_score_diff_threshold)
# Evaluate with the official SQuAD script
evaluate_options = EVAL_OPTS(data_file=args.predict_file,
pred_file=output_prediction_file,
na_prob_file=output_null_log_odds_file)
results = evaluate_on_squad(evaluate_options)
# Compute the F1 and exact scores.
results = squad_evaluate(examples, predictions)
return results
def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training processes the dataset, and the others will use the cache
# Load data features from cache or dataset file
input_file = args.predict_file if evaluate else args.train_file
cached_features_file = os.path.join(os.path.dirname(input_file), 'cached_{}_{}_{}'.format(
input_dir = args.data_dir if args.data_dir else "."
cached_features_file = os.path.join(input_dir, 'cached_{}_{}_{}'.format(
'dev' if evaluate else 'train',
list(filter(None, args.model_name_or_path.split('/'))).pop(),
str(args.max_seq_length)))
str(args.max_seq_length))
)
# Init features and dataset from cache if it exists
if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
features_and_dataset = torch.load(cached_features_file)
features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
else:
logger.info("Creating features from dataset file at %s", input_file)
examples = read_squad_examples(input_file=input_file,
is_training=not evaluate,
version_2_with_negative=args.version_2_with_negative)
features = convert_examples_to_features(examples=examples,
logger.info("Creating features from dataset file at %s", input_dir)
if not args.data_dir and ((evaluate and not args.predict_file) or (not evaluate and not args.train_file)):
try:
import tensorflow_datasets as tfds
except ImportError:
raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")
if args.version_2_with_negative:
logger.warn("tensorflow_datasets does not handle version 2 of SQuAD.")
tfds_examples = tfds.load("squad")
examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
else:
processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
if evaluate:
examples = processor.get_dev_examples(args.data_dir, filename=args.predict_file)
else:
examples = processor.get_train_examples(args.data_dir, filename=args.train_file)
features, dataset = squad_convert_examples_to_features(
examples=examples,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
doc_stride=args.doc_stride,
max_query_length=args.max_query_length,
is_training=not evaluate,
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
return_dataset='pt'
)
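# Note: with return_dataset='pt', squad_convert_examples_to_features returns both the list of
# features and a ready-made TensorDataset, so the two can be cached together below.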
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
torch.save({"features": features, "dataset": dataset}, cached_features_file)
if args.local_rank == 0 and not evaluate:
torch.distributed.barrier() # Make sure only the first process in distributed training processes the dataset, and the others will use the cache
......@@ -355,10 +401,6 @@ def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--train_file", default=None, type=str, required=True,
help="SQuAD json for training. E.g., train-v1.1.json")
parser.add_argument("--predict_file", default=None, type=str, required=True,
help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
......@@ -367,6 +409,15 @@ def main():
help="The output directory where the model checkpoints and predictions will be written.")
## Other parameters
parser.add_argument("--data_dir", default=None, type=str,
help="The input data dir. Should contain the .json files for the task." +
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
parser.add_argument("--train_file", default=None, type=str,
help="The input training file. If a data dir is specified, will look for the file there" +
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
parser.add_argument("--predict_file", default=None, type=str,
help="The input evaluation file. If a data dir is specified, will look for the file there" +
"If no data dir or train/predict files are specified, will run with tensorflow_datasets.")
parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name")
parser.add_argument("--tokenizer_name", default="", type=str,
......@@ -405,7 +456,7 @@ def main():
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--weight_decay", default=0.0, type=float,
help="Weight deay if we apply some.")
help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
......@@ -540,7 +591,7 @@ def main():
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
model = model_class.from_pretrained(args.output_dir, force_download=True)
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
model.to(args.device)
......@@ -548,17 +599,23 @@ def main():
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
if args.do_train:
logger.info("Loading checkpoints saved during training for evaluation")
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce model loading logs
else:
logger.info("Loading checkpoint %s for evaluation", args.model_name_or_path)
checkpoints = [args.model_name_or_path]
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
# Reload the model
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
model = model_class.from_pretrained(checkpoint)
model = model_class.from_pretrained(checkpoint, force_download=True)
model.to(args.device)
# Evaluate
......
# coding=utf-8
# Copyright 2019 The HuggingFace Inc. team.
# Copyright (c) 2019 The HuggingFace Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning seq2seq models for sequence generation."""
import argparse
import functools
import logging
import os
import random
import sys
import numpy as np
from tqdm import tqdm, trange
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import (
AutoTokenizer,
BertForMaskedLM,
BertConfig,
PreTrainedEncoderDecoder,
Model2Model,
)
from utils_summarization import (
CNNDailyMailDataset,
encode_for_summarization,
fit_to_block_size,
build_lm_labels,
build_mask,
compute_token_type_ids,
)
logger = logging.getLogger(__name__)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# ------------
# Load dataset
# ------------
def load_and_cache_examples(args, tokenizer):
dataset = CNNDailyMailDataset(tokenizer, data_dir=args.data_dir)
return dataset
def collate(data, tokenizer, block_size):
""" List of tuple as an input. """
# remove the files with empty an story/summary, encode and fit to block
data = filter(lambda x: not (len(x[0]) == 0 or len(x[1]) == 0), data)
data = [
encode_for_summarization(story, summary, tokenizer) for story, summary in data
]
data = [
(
fit_to_block_size(story, block_size, tokenizer.pad_token_id),
fit_to_block_size(summary, block_size, tokenizer.pad_token_id),
)
for story, summary in data
]
stories = torch.tensor([story for story, summary in data])
summaries = torch.tensor([summary for story, summary in data])
encoder_token_type_ids = compute_token_type_ids(stories, tokenizer.cls_token_id)
encoder_mask = build_mask(stories, tokenizer.pad_token_id)
decoder_mask = build_mask(summaries, tokenizer.pad_token_id)
lm_labels = build_lm_labels(summaries, tokenizer.pad_token_id)
return (
stories,
summaries,
encoder_token_type_ids,
encoder_mask,
decoder_mask,
lm_labels,
)
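# Rough sketch of the output, assuming fit_to_block_size pads/truncates every sequence to
# `block_size`: for a batch of 4 (story, summary) pairs with block_size=512, `stories` and
# `summaries` are LongTensors of shape (4, 512), and the token type ids, attention masks and
# `lm_labels` share that shape with padding positions masked out.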
# ----------
# Optimizers
# ----------
class BertSumOptimizer(object):
""" Specific optimizer for BertSum.
As described in [1], the authors fine-tune BertSum for abstractive
summarization using two Adam optimizers with different warm-up steps and
learning rates. They also use a custom learning rate scheduler.
[1] Liu, Yang, and Mirella Lapata. "Text summarization with pretrained encoders."
arXiv preprint arXiv:1908.08345 (2019).
"""
def __init__(self, model, lr, warmup_steps, beta_1=0.99, beta_2=0.999, eps=1e-8):
self.encoder = model.encoder
self.decoder = model.decoder
self.lr = lr
self.warmup_steps = warmup_steps
self.optimizers = {
"encoder": Adam(
model.encoder.parameters(),
lr=lr["encoder"],
betas=(beta_1, beta_2),
eps=eps,
),
"decoder": Adam(
model.decoder.parameters(),
lr=lr["decoder"],
betas=(beta_1, beta_2),
eps=eps,
),
}
self._step = 0
def _update_rate(self, stack):
return self.lr[stack] * min(
self._step ** (-0.5), self._step * self.warmup_steps[stack] ** (-0.5)
)
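# Worked example of the schedule above, using the encoder settings passed in train() below
# (lr=0.002, warmup_steps=20000): at step 10000 the rate is
# 0.002 * min(10000 ** -0.5, 10000 * 20000 ** -0.5) = 0.002 * min(0.01, ~70.7) = 2e-5.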
def zero_grad(self):
self.optimizers["encoder"].zero_grad()
self.optimizers["decoder"].zero_grad()
def step(self):
self._step += 1
for stack, optimizer in self.optimizers.items():
new_rate = self._update_rate(stack)
for param_group in optimizer.param_groups:
param_group["lr"] = new_rate
optimizer.step()
# ------------
# Train
# ------------
def train(args, model, tokenizer):
""" Fine-tune the pretrained model on the corpus. """
set_seed(args)
# Load the data
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_dataset = load_and_cache_examples(args, tokenizer)
train_sampler = RandomSampler(train_dataset)
model_collate_fn = functools.partial(collate, tokenizer=tokenizer, block_size=512)
train_dataloader = DataLoader(
train_dataset,
sampler=train_sampler,
batch_size=args.train_batch_size,
collate_fn=model_collate_fn,
)
# Training schedule
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = t_total // (
len(train_dataloader) // args.gradient_accumulation_steps + 1
)
else:
t_total = (
len(train_dataloader)
// args.gradient_accumulation_steps
* args.num_train_epochs
)
# Prepare the optimizer
lr = {"encoder": 0.002, "decoder": 0.2}
warmup_steps = {"encoder": 20000, "decoder": 10000}
optimizer = BertSumOptimizer(model, lr, warmup_steps)
# Train
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(
" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size
)
logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size * args.gradient_accumulation_steps
# * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
)
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
model.zero_grad()
train_iterator = trange(args.num_train_epochs, desc="Epoch", disable=True)
global_step = 0
tr_loss = 0.0
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=True)
for step, batch in enumerate(epoch_iterator):
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
source = source.to(args.device)
target = target.to(args.device)
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
encoder_mask = encoder_mask.to(args.device)
decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
model.train()
outputs = model(
source,
target,
encoder_token_type_ids=encoder_token_type_ids,
encoder_attention_mask=encoder_mask,
decoder_attention_mask=decoder_mask,
decoder_lm_labels=lm_labels,
)
loss = outputs[0]
print(loss)
if args.gradient_accumulation_steps > 1:
loss /= args.gradient_accumulation_steps
loss.backward()
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
model.zero_grad()
global_step += 1
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
if args.max_steps > 0 and global_step > args.max_steps:
train_iterator.close()
break
return global_step, tr_loss / global_step
# ------------
# Train
# ------------
def evaluate(args, model, tokenizer, prefix=""):
set_seed(args)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
eval_dataset = load_and_cache_examples(args, tokenizer)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(
eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size
)
# multi-gpu evaluate
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
eval_loss = 0.0
nb_eval_steps = 0
model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
source, target, encoder_token_type_ids, encoder_mask, decoder_mask, lm_labels = batch
source = source.to(args.device)
target = target.to(args.device)
encoder_token_type_ids = encoder_token_type_ids.to(args.device)
encoder_mask = encoder_mask.to(args.device)
decoder_mask = decoder_mask.to(args.device)
lm_labels = lm_labels.to(args.device)
with torch.no_grad():
outputs = model(
source,
target,
encoder_token_type_ids=encoder_token_type_ids,
encoder_attention_mask=encoder_mask,
decoder_attention_mask=decoder_mask,
decoder_lm_labels=lm_labels,
)
lm_loss = outputs[0]
eval_loss += lm_loss.mean().item()
nb_eval_steps += 1
eval_loss = eval_loss / nb_eval_steps
perplexity = torch.exp(torch.tensor(eval_loss))
result = {"perplexity": perplexity}
# Save the evaluation's results
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
return result
def main():
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument(
"--data_dir",
default=None,
type=str,
required=True,
help="The input training data file (a text file).",
)
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
# Optional parameters
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--do_evaluate",
type=bool,
default=False,
help="Run model evaluation on out-of-sample data.",
)
parser.add_argument("--do_train", type=bool, default=False, help="Run training.")
parser.add_argument(
"--do_overwrite_output_dir",
type=bool,
default=False,
help="Whether to overwrite the output dir.",
)
parser.add_argument(
"--model_name_or_path",
default="bert-base-cased",
type=str,
help="The model checkpoint to initialize the encoder and decoder's weights with.",
)
parser.add_argument(
"--model_type",
default="bert",
type=str,
help="The decoder architecture to be fine-tuned.",
)
parser.add_argument(
"--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
)
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
)
parser.add_argument(
"--to_cpu", default=False, type=bool, help="Whether to force training on CPU."
)
parser.add_argument(
"--num_train_epochs",
default=10,
type=int,
help="Total number of training epochs to perform.",
)
parser.add_argument(
"--per_gpu_train_batch_size",
default=4,
type=int,
help="Batch size per GPU/CPU for training.",
)
parser.add_argument("--seed", default=42, type=int)
args = parser.parse_args()
if (
os.path.exists(args.output_dir)
and os.listdir(args.output_dir)
and args.do_train
and not args.do_overwrite_output_dir
):
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format(
args.output_dir
)
)
# Set up training device
if args.to_cpu or not torch.cuda.is_available():
args.device = torch.device("cpu")
args.n_gpu = 0
else:
args.device = torch.device("cuda")
args.n_gpu = torch.cuda.device_count()
# Load pretrained model and tokenizer. The decoder's weights are randomly initialized.
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
config = BertConfig.from_pretrained(args.model_name_or_path)
decoder_model = BertForMaskedLM(config)
model = Model2Model.from_pretrained(
args.model_name_or_path, decoder_model=decoder_model
)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.warning(
"Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
0,
args.device,
args.n_gpu,
False,
False,
)
logger.info("Training/evaluation parameters %s", args)
# Train the model
model.to(args.device)
if args.do_train:
global_step, tr_loss = train(args, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = (
model.module if hasattr(model, "module") else model
) # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))
# Evaluate the model
results = {}
if args.do_evaluate:
checkpoints = []
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
encoder_checkpoint = os.path.join(checkpoint, "encoder")
decoder_checkpoint = os.path.join(checkpoint, "decoder")
model = PreTrainedEncoderDecoder.from_pretrained(
encoder_checkpoint, decoder_checkpoint
)
model.to(args.device)
results = "placeholder"
return results
if __name__ == "__main__":
main()
# coding=utf-8
import datetime
import os
import math
import glob
import re
import tensorflow as tf
import collections
import numpy as np
from seqeval import metrics
import _pickle as pickle
from absl import logging
from transformers import TF2_WEIGHTS_NAME, BertConfig, BertTokenizer, TFBertForTokenClassification
from transformers import RobertaConfig, RobertaTokenizer, TFRobertaForTokenClassification
from transformers import DistilBertConfig, DistilBertTokenizer, TFDistilBertForTokenClassification
from transformers import create_optimizer, GradientAccumulator
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
from fastprogress import master_bar, progress_bar
from absl import flags
from absl import app
ALL_MODELS = sum(
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
())
MODEL_CLASSES = {
"bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
"roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
"distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer)
}
flags.DEFINE_string(
"data_dir", None,
"The input data dir. Should contain the .conll files (or other data files) "
"for the task.")
flags.DEFINE_string(
"model_type", None,
"Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
flags.DEFINE_string(
"model_name_or_path", None,
"Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
flags.DEFINE_string(
"output_dir", None,
"The output directory where the model checkpoints will be written.")
flags.DEFINE_string(
"labels", "",
"Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.")
flags.DEFINE_string(
"config_name", "",
"Pretrained config name or path if not the same as model_name")
flags.DEFINE_string(
"tokenizer_name", "",
"Pretrained tokenizer name or path if not the same as model_name")
flags.DEFINE_string(
"cache_dir", "",
"Where do you want to store the pre-trained models downloaded from s3")
flags.DEFINE_integer(
"max_seq_length", 128,
"The maximum total input sentence length after tokenization. "
"Sequences longer than this will be truncated, sequences shorter "
"will be padded.")
flags.DEFINE_string(
"tpu", None,
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url.")
flags.DEFINE_integer(
"num_tpu_cores", 8,
"Total number of TPU cores to use.")
flags.DEFINE_boolean(
"do_train", False,
"Whether to run training.")
flags.DEFINE_boolean(
"do_eval", False,
"Whether to run eval on the dev set.")
flags.DEFINE_boolean(
"do_predict", False,
"Whether to run predictions on the test set.")
flags.DEFINE_boolean(
"evaluate_during_training", False,
"Whether to run evaluation during training at each logging step.")
flags.DEFINE_boolean(
"do_lower_case", False,
"Set this flag if you are using an uncased model.")
flags.DEFINE_integer(
"per_device_train_batch_size", 8,
"Batch size per GPU/CPU/TPU for training.")
flags.DEFINE_integer(
"per_device_eval_batch_size", 8,
"Batch size per GPU/CPU/TPU for evaluation.")
flags.DEFINE_integer(
"gradient_accumulation_steps", 1,
"Number of updates steps to accumulate before performing a backward/update pass.")
flags.DEFINE_float(
"learning_rate", 5e-5,
"The initial learning rate for Adam.")
flags.DEFINE_float(
"weight_decay", 0.0,
"Weight decay if we apply some.")
flags.DEFINE_float(
"adam_epsilon", 1e-8,
"Epsilon for Adam optimizer.")
flags.DEFINE_float(
"max_grad_norm", 1.0,
"Max gradient norm.")
flags.DEFINE_integer(
"num_train_epochs", 3,
"Total number of training epochs to perform.")
flags.DEFINE_integer(
"max_steps", -1,
"If > 0: set total number of training steps to perform. Override num_train_epochs.")
flags.DEFINE_integer(
"warmup_steps", 0,
"Linear warmup over warmup_steps.")
flags.DEFINE_integer(
"logging_steps", 50,
"Log every X updates steps.")
flags.DEFINE_integer(
"save_steps", 50,
"Save checkpoint every X updates steps.")
flags.DEFINE_boolean(
"eval_all_checkpoints", False,
"Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
flags.DEFINE_boolean(
"no_cuda", False,
"Avoid using CUDA when available")
flags.DEFINE_boolean(
"overwrite_output_dir", False,
"Overwrite the content of the output directory")
flags.DEFINE_boolean(
"overwrite_cache", False,
"Overwrite the cached training and evaluation sets")
flags.DEFINE_integer(
"seed", 42,
"random seed for initialization")
flags.DEFINE_boolean(
"fp16", False,
"Whether to use 16-bit (mixed) precision instead of 32-bit")
flags.DEFINE_string(
"gpus", "0",
"Comma separated list of gpus devices. If only one, switch to single "
"gpu strategy, if None takes all the gpus available.")
def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id):
if args['max_steps'] > 0:
num_train_steps = args['max_steps'] * args['gradient_accumulation_steps']
args['num_train_epochs'] = 1
else:
num_train_steps = math.ceil(num_train_examples / train_batch_size) // args['gradient_accumulation_steps'] * args['num_train_epochs']
writer = tf.summary.create_file_writer("/tmp/mylogs")
with strategy.scope():
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
optimizer = create_optimizer(args['learning_rate'], num_train_steps, args['warmup_steps'])
if args['fp16']:
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, 'dynamic')
loss_metric = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
gradient_accumulator = GradientAccumulator()
logging.info("***** Running training *****")
logging.info(" Num examples = %d", num_train_examples)
logging.info(" Num Epochs = %d", args['num_train_epochs'])
logging.info(" Instantaneous batch size per device = %d", args['per_device_train_batch_size'])
logging.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
train_batch_size * args['gradient_accumulation_steps'])
logging.info(" Gradient Accumulation steps = %d", args['gradient_accumulation_steps'])
logging.info(" Total training steps = %d", num_train_steps)
model.summary()
@tf.function
def apply_gradients():
grads_and_vars = []
for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
if gradient is not None:
scaled_gradient = gradient / (args['n_device'] * args['gradient_accumulation_steps'])
grads_and_vars.append((scaled_gradient, variable))
else:
grads_and_vars.append((gradient, variable))
optimizer.apply_gradients(grads_and_vars, args['max_grad_norm'])
gradient_accumulator.reset()
@tf.function
def train_step(train_features, train_labels):
def step_fn(train_features, train_labels):
inputs = {'attention_mask': train_features['input_mask'], 'training': True}
if args['model_type'] != "distilbert":
inputs["token_type_ids"] = train_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None
with tf.GradientTape() as tape:
logits = model(train_features['input_ids'], **inputs)[0]
logits = tf.reshape(logits, (-1, len(labels) + 1))
active_loss = tf.reshape(train_features['input_mask'], (-1,))
active_logits = tf.boolean_mask(logits, active_loss)
train_labels = tf.reshape(train_labels, (-1,))
active_labels = tf.boolean_mask(train_labels, active_loss)
cross_entropy = loss_fct(active_labels, active_logits)
loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
grads = tape.gradient(loss, model.trainable_variables)
gradient_accumulator(grads)
return cross_entropy
per_example_losses = strategy.experimental_run_v2(step_fn, args=(train_features, train_labels))
mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)
return mean_loss
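# Sketch of the update flow implemented above: each train_step call accumulates the
# per-replica gradients in `gradient_accumulator`; every `gradient_accumulation_steps`
# batches the training loop runs apply_gradients(), which rescales the accumulated
# gradients by 1 / (n_device * gradient_accumulation_steps), applies them with the
# optimizer and resets the accumulator.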
current_time = datetime.datetime.now()
train_iterator = master_bar(range(args['num_train_epochs']))
global_step = 0
logging_loss = 0.0
for epoch in train_iterator:
epoch_iterator = progress_bar(train_dataset, total=num_train_steps, parent=train_iterator, display=args['n_device'] > 1)
step = 1
with strategy.scope():
for train_features, train_labels in epoch_iterator:
loss = train_step(train_features, train_labels)
if step % args['gradient_accumulation_steps'] == 0:
strategy.experimental_run_v2(apply_gradients)
loss_metric(loss)
global_step += 1
if args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0:
# Log metrics
if args['n_device'] == 1 and args['evaluate_during_training']: # Only evaluate when single GPU otherwise metrics may not average well
y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
report = metrics.classification_report(y_true, y_pred, digits=4)
logging.info("Eval at step " + str(global_step) + "\n" + report)
logging.info("eval_loss: " + str(eval_loss))
precision = metrics.precision_score(y_true, y_pred)
recall = metrics.recall_score(y_true, y_pred)
f1 = metrics.f1_score(y_true, y_pred)
with writer.as_default():
tf.summary.scalar("eval_loss", eval_loss, global_step)
tf.summary.scalar("precision", precision, global_step)
tf.summary.scalar("recall", recall, global_step)
tf.summary.scalar("f1", f1, global_step)
lr = optimizer.learning_rate
learning_rate = lr(step)
with writer.as_default():
tf.summary.scalar("lr", learning_rate, global_step)
tf.summary.scalar("loss", (loss_metric.result() - logging_loss) / args['logging_steps'], global_step)
logging_loss = loss_metric.result()
with writer.as_default():
tf.summary.scalar("loss", loss_metric.result(), step=step)
if args['save_steps'] > 0 and global_step % args['save_steps'] == 0:
# Save model checkpoint
output_dir = os.path.join(args['output_dir'], "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model.save_pretrained(output_dir)
logging.info("Saving model checkpoint to %s", output_dir)
train_iterator.child.comment = f'loss : {loss_metric.result()}'
step += 1
train_iterator.write(f'loss epoch {epoch + 1}: {loss_metric.result()}')
loss_metric.reset_states()
logging.info(" Training took time = {}".format(datetime.datetime.now() - current_time))
def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
eval_dataset, size = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode)
eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
preds = None
num_eval_steps = math.ceil(size / eval_batch_size)
master = master_bar(range(1))
eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args['n_device'] > 1)
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
loss = 0.0
logging.info("***** Running evaluation *****")
logging.info(" Num examples = %d", size)
logging.info(" Batch size = %d", eval_batch_size)
for eval_features, eval_labels in eval_iterator:
inputs = {'attention_mask': eval_features['input_mask'], 'training': False}
if args['model_type'] != "distilbert":
inputs["token_type_ids"] = eval_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None
with strategy.scope():
logits = model(eval_features['input_ids'], **inputs)[0]
tmp_logits = tf.reshape(logits, (-1, len(labels) + 1))
active_loss = tf.reshape(eval_features['input_mask'], (-1,))
active_logits = tf.boolean_mask(tmp_logits, active_loss)
tmp_eval_labels = tf.reshape(eval_labels, (-1,))
active_labels = tf.boolean_mask(tmp_eval_labels, active_loss)
cross_entropy = loss_fct(active_labels, active_logits)
loss += tf.reduce_sum(cross_entropy) * (1.0 / eval_batch_size)
if preds is None:
preds = logits.numpy()
label_ids = eval_labels.numpy()
else:
preds = np.append(preds, logits.numpy(), axis=0)
label_ids = np.append(label_ids, eval_labels.numpy(), axis=0)
preds = np.argmax(preds, axis=2)
y_pred = [[] for _ in range(label_ids.shape[0])]
y_true = [[] for _ in range(label_ids.shape[0])]
loss = loss / num_eval_steps
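# The loop below maps ids back to label strings; index 0 is reserved for the padding label
# (pad_token_label_id = 0), hence the "- 1" offset when indexing into `labels`.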
for i in range(label_ids.shape[0]):
for j in range(label_ids.shape[1]):
if label_ids[i, j] != pad_token_label_id:
y_pred[i].append(labels[preds[i, j] - 1])
y_true[i].append(labels[label_ids[i, j] - 1])
return y_true, y_pred, loss.numpy()
def load_cache(cached_file, max_seq_length):
name_to_features = {
"input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
"input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
"segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
"label_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
}
def _decode_record(record):
example = tf.io.parse_single_example(record, name_to_features)
features = {}
features['input_ids'] = example['input_ids']
features['input_mask'] = example['input_mask']
features['segment_ids'] = example['segment_ids']
return features, example['label_ids']
d = tf.data.TFRecordDataset(cached_file)
d = d.map(_decode_record, num_parallel_calls=4)
count = d.reduce(0, lambda x, _: x + 1)
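# Counting via reduce makes a full pass over the TFRecord file; the resulting size is
# needed later to derive the number of training/evaluation steps.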
return d, count.numpy()
def save_cache(features, cached_features_file):
writer = tf.io.TFRecordWriter(cached_features_file)
for (ex_index, feature) in enumerate(features):
if ex_index % 5000 == 0:
logging.info("Writing example %d of %d" % (ex_index, len(features)))
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
record_feature = collections.OrderedDict()
record_feature["input_ids"] = create_int_feature(feature.input_ids)
record_feature["input_mask"] = create_int_feature(feature.input_mask)
record_feature["segment_ids"] = create_int_feature(feature.segment_ids)
record_feature["label_ids"] = create_int_feature(feature.label_ids)
tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))
writer.write(tf_example.SerializeToString())
writer.close()
def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
drop_remainder = True if args['tpu'] or mode == 'train' else False
# Load data features from cache or dataset file
cached_features_file = os.path.join(args['data_dir'], "cached_{}_{}_{}.tf_record".format(mode,
list(filter(None, args['model_name_or_path'].split("/"))).pop(),
str(args['max_seq_length'])))
if os.path.exists(cached_features_file) and not args['overwrite_cache']:
logging.info("Loading features from cached file %s", cached_features_file)
dataset, size = load_cache(cached_features_file, args['max_seq_length'])
else:
logging.info("Creating features from dataset file at %s", args['data_dir'])
examples = read_examples_from_file(args['data_dir'], mode)
features = convert_examples_to_features(examples, labels, args['max_seq_length'], tokenizer,
cls_token_at_end=bool(args['model_type'] in ["xlnet"]),
# xlnet has a cls token at the end
cls_token=tokenizer.cls_token,
cls_token_segment_id=2 if args['model_type'] in ["xlnet"] else 0,
sep_token=tokenizer.sep_token,
sep_token_extra=bool(args['model_type'] in ["roberta"]),
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=bool(args['model_type'] in ["xlnet"]),
# pad on the left for xlnet
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
pad_token_segment_id=4 if args['model_type'] in ["xlnet"] else 0,
pad_token_label_id=pad_token_label_id
)
logging.info("Saving features into cached file %s", cached_features_file)
save_cache(features, cached_features_file)
dataset, size = load_cache(cached_features_file, args['max_seq_length'])
if mode == 'train':
dataset = dataset.repeat()
dataset = dataset.shuffle(buffer_size=8192, seed=args['seed'])
dataset = dataset.batch(batch_size, drop_remainder)
dataset = dataset.prefetch(buffer_size=batch_size)
return dataset, size
def main(_):
logging.set_verbosity(logging.INFO)
args = flags.FLAGS.flag_values_dict()
if os.path.exists(args['output_dir']) and os.listdir(
args['output_dir']) and args['do_train'] and not args['overwrite_output_dir']:
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
args['output_dir']))
if args['fp16']:
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
if args['tpu']:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args['tpu'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
args['n_device'] = args['num_tpu_cores']
elif len(args['gpus'].split(',')) > 1:
args['n_device'] = len([f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
elif args['no_cuda']:
args['n_device'] = 1
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
else:
args['n_device'] = len(args['gpus'].split(','))
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args['gpus'].split(',')[0])
logging.warning("n_device: %s, distributed training: %s, 16-bits training: %s",
args['n_device'], bool(args['n_device'] > 1), args['fp16'])
labels = get_labels(args['labels'])
num_labels = len(labels) + 1
pad_token_label_id = 0
config_class, model_class, tokenizer_class = MODEL_CLASSES[args['model_type']]
config = config_class.from_pretrained(args['config_name'] if args['config_name'] else args['model_name_or_path'],
num_labels=num_labels,
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
logging.info("Training/evaluation parameters %s", args)
# Training
if args['do_train']:
tokenizer = tokenizer_class.from_pretrained(args['tokenizer_name'] if args['tokenizer_name'] else args['model_name_or_path'],
do_lower_case=args['do_lower_case'],
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
with strategy.scope():
model = model_class.from_pretrained(args['model_name_or_path'],
from_pt=bool(".bin" in args['model_name_or_path']),
config=config,
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
model.layers[-1].activation = tf.keras.activations.softmax
train_batch_size = args['per_device_train_batch_size'] * args['n_device']
train_dataset, num_train_examples = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train")
train_dataset = strategy.experimental_distribute_dataset(train_dataset)
train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id)
if not os.path.exists(args['output_dir']):
os.makedirs(args['output_dir'])
logging.info("Saving model to %s", args['output_dir'])
model.save_pretrained(args['output_dir'])
tokenizer.save_pretrained(args['output_dir'])
# Evaluation
if args['do_eval']:
tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
checkpoints = []
results = []
if args['eval_all_checkpoints']:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args['output_dir'] + "/**/" + TF2_WEIGHTS_NAME, recursive=True), key=lambda f: int(''.join(filter(str.isdigit, f)) or -1)))
logging.info("Evaluate the following checkpoints: %s", checkpoints)
if len(checkpoints) == 0:
checkpoints.append(args['output_dir'])
for checkpoint in checkpoints:
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
with strategy.scope():
model = model_class.from_pretrained(checkpoint)
y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
report = metrics.classification_report(y_true, y_pred, digits=4)
if global_step:
results.append({global_step + "_report": report, global_step + "_loss": eval_loss})
output_eval_file = os.path.join(args['output_dir'], "eval_results.txt")
with tf.io.gfile.GFile(output_eval_file, "w") as writer:
for res in results:
for key, val in res.items():
if "loss" in key:
logging.info(key + " = " + str(val))
writer.write(key + " = " + str(val))
writer.write("\n")
else:
logging.info(key)
logging.info("\n" + report)
writer.write(key + "\n")
writer.write(report)
writer.write("\n")
if args['do_predict']:
tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
model = model_class.from_pretrained(args['output_dir'])
eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
predict_dataset, _ = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test")
y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
output_test_results_file = os.path.join(args['output_dir'], "test_results.txt")
output_test_predictions_file = os.path.join(args['output_dir'], "test_predictions.txt")
report = metrics.classification_report(y_true, y_pred, digits=4)
with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
report = metrics.classification_report(y_true, y_pred, digits=4)
logging.info("\n" + report)
writer.write(report)
writer.write("\n\nloss = " + str(pred_loss))
with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
with tf.io.gfile.GFile(os.path.join(args['data_dir'], "test.txt"), "r") as f:
example_id = 0
for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
writer.write(line)
if not y_pred[example_id]:
example_id += 1
elif y_pred[example_id]:
output_line = line.split()[0] + " " + y_pred[example_id].pop(0) + "\n"
writer.write(output_line)
else:
logging.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
if __name__ == "__main__":
flags.mark_flag_as_required("data_dir")
flags.mark_flag_as_required("output_dir")
flags.mark_flag_as_required("model_name_or_path")
flags.mark_flag_as_required("model_type")
app.run(main)
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning multi-lingual models on XNLI (Bert, DistilBERT, XLM).
Adapted from `examples/run_glue.py`"""
from __future__ import absolute_import, division, print_function
import argparse
import glob
import logging
import os
import random
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
try:
from torch.utils.tensorboard import SummaryWriter
except:
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange
from transformers import (WEIGHTS_NAME,
BertConfig, BertForSequenceClassification, BertTokenizer,
XLMConfig, XLMForSequenceClassification, XLMTokenizer,
DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import xnli_compute_metrics as compute_metrics
from transformers import xnli_output_modes as output_modes
from transformers import xnli_processors as processors
from transformers import glue_convert_examples_to_features as convert_examples_to_features
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, DistilBertConfig, XLMConfig)), ())
MODEL_CLASSES = {
'bert': (BertConfig, BertForSequenceClassification, BertTokenizer),
'xlm': (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
'distilbert': (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer)
}
def set_seed(args):
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.n_gpu > 0:
torch.cuda.manual_seed_all(args.seed)
def train(args, train_dataset, model, tokenizer):
""" Train the model """
if args.local_rank in [-1, 0]:
tb_writer = SummaryWriter()
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
# Prepare optimizer and schedule (linear warmup and decay)
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
{'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
if args.fp16:
try:
from apex import amp
except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
# multi-gpu training (should be after apex fp16 initialization)
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
# Distributed training (should be after apex fp16 initialization)
if args.local_rank != -1:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=True)
# Train!
logger.info("***** Running training *****")
logger.info(" Num examples = %d", len(train_dataset))
logger.info(" Num Epochs = %d", args.num_train_epochs)
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
logger.info(" Total optimization steps = %d", t_total)
global_step = 0
tr_loss, logging_loss = 0.0, 0.0
model.zero_grad()
train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
for _ in train_iterator:
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
for step, batch in enumerate(epoch_iterator):
model.train()
batch = tuple(t.to(args.device) for t in batch)
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'labels': batch[3]}
if args.model_type != 'distilbert':
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert'] else None # XLM and DistilBERT don't use segment_ids
outputs = model(**inputs)
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
if args.n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu parallel training
if args.gradient_accumulation_steps > 1:
loss = loss / args.gradient_accumulation_steps
if args.fp16:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
tr_loss += loss.item()
if (step + 1) % args.gradient_accumulation_steps == 0:
if args.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
else:
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
scheduler.step() # Update learning rate schedule
model.zero_grad()
global_step += 1
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
# Log metrics
if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well
results = evaluate(args, model, tokenizer)
for key, value in results.items():
tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
logging_loss = tr_loss
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
# Save model checkpoint
output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(output_dir)
torch.save(args, os.path.join(output_dir, 'training_args.bin'))
logger.info("Saving model checkpoint to %s", output_dir)
if args.max_steps > 0 and global_step > args.max_steps:
epoch_iterator.close()
break
if args.max_steps > 0 and global_step > args.max_steps:
train_iterator.close()
break
if args.local_rank in [-1, 0]:
tb_writer.close()
return global_step, tr_loss / global_step
def evaluate(args, model, tokenizer, prefix=""):
eval_task_names = (args.task_name,)
eval_outputs_dirs = (args.output_dir,)
results = {}
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
os.makedirs(eval_output_dir)
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
# multi-gpu eval
if args.n_gpu > 1:
model = torch.nn.DataParallel(model)
# Eval!
logger.info("***** Running evaluation {} *****".format(prefix))
logger.info(" Num examples = %d", len(eval_dataset))
logger.info(" Batch size = %d", args.eval_batch_size)
eval_loss = 0.0
nb_eval_steps = 0
preds = None
out_label_ids = None
for batch in tqdm(eval_dataloader, desc="Evaluating"):
model.eval()
batch = tuple(t.to(args.device) for t in batch)
with torch.no_grad():
inputs = {'input_ids': batch[0],
'attention_mask': batch[1],
'labels': batch[3]}
if args.model_type != 'distilbert':
inputs['token_type_ids'] = batch[2] if args.model_type in ['bert'] else None # XLM and DistilBERT don't use segment_ids
outputs = model(**inputs)
tmp_eval_loss, logits = outputs[:2]
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
if preds is None:
preds = logits.detach().cpu().numpy()
out_label_ids = inputs['labels'].detach().cpu().numpy()
else:
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
out_label_ids = np.append(out_label_ids, inputs['labels'].detach().cpu().numpy(), axis=0)
eval_loss = eval_loss / nb_eval_steps
if args.output_mode == "classification":
preds = np.argmax(preds, axis=1)
else:
raise ValueError('No other `output_mode` for XNLI.')
result = compute_metrics(eval_task, preds, out_label_ids)
results.update(result)
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
with open(output_eval_file, "w") as writer:
logger.info("***** Eval results {} *****".format(prefix))
for key in sorted(result.keys()):
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
return results
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
if args.local_rank not in [-1, 0] and not evaluate:
torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache
processor = processors[task](language=args.language, train_language=args.train_language)
output_mode = output_modes[task]
# Load data features from cache or dataset file
cached_features_file = os.path.join(args.data_dir, 'cached_{}_{}_{}_{}_{}'.format(
'test' if evaluate else 'train',
list(filter(None, args.model_name_or_path.split('/'))).pop(),
str(args.max_seq_length),
str(task),
str(args.train_language if (not evaluate and args.train_language is not None) else args.language)))
if os.path.exists(cached_features_file) and not args.overwrite_cache:
logger.info("Loading features from cached file %s", cached_features_file)
features = torch.load(cached_features_file)
else:
logger.info("Creating features from dataset file at %s", args.data_dir)
label_list = processor.get_labels()
examples = processor.get_test_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
features = convert_examples_to_features(examples,
tokenizer,
label_list=label_list,
max_length=args.max_seq_length,
output_mode=output_mode,
pad_on_left=False,
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
pad_token_segment_id=0,
)
if args.local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(features, cached_features_file)
if args.local_rank == 0 and not evaluate:
torch.distributed.barrier()  # Make sure only the first process in distributed training processes the dataset; the others will use the cache
# Convert to Tensors and build dataset
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
if output_mode == "classification":
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
else:
raise ValueError('No other `output_mode` for XNLI.')
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
return dataset
def main():
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--data_dir", default=None, type=str, required=True,
help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
parser.add_argument("--model_type", default=None, type=str, required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
parser.add_argument("--language", default=None, type=str, required=True,
help="Evaluation language. Also train language if `train_language` is set to None.")
parser.add_argument("--train_language", default=None, type=str,
help="Train language if is different of the evaluation language.")
parser.add_argument("--output_dir", default=None, type=str, required=True,
help="The output directory where the model predictions and checkpoints will be written.")
## Other parameters
parser.add_argument("--config_name", default="", type=str,
help="Pretrained config name or path if not the same as model_name")
parser.add_argument("--tokenizer_name", default="", type=str,
help="Pretrained tokenizer name or path if not the same as model_name")
parser.add_argument("--cache_dir", default="", type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument("--max_seq_length", default=128, type=int,
help="The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded.")
parser.add_argument("--do_train", action='store_true',
help="Whether to run training.")
parser.add_argument("--do_eval", action='store_true',
help="Whether to run eval on the test set.")
parser.add_argument("--evaluate_during_training", action='store_true',
help="Rul evaluation during training at each logging step.")
parser.add_argument("--do_lower_case", action='store_true',
help="Set this flag if you are using an uncased model.")
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
help="Batch size per GPU/CPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
help="Batch size per GPU/CPU for evaluation.")
parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.")
parser.add_argument("--learning_rate", default=5e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float,
help="Weight deay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float,
help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", default=1.0, type=float,
help="Max gradient norm.")
parser.add_argument("--num_train_epochs", default=3.0, type=float,
help="Total number of training epochs to perform.")
parser.add_argument("--max_steps", default=-1, type=int,
help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
parser.add_argument("--warmup_steps", default=0, type=int,
help="Linear warmup over warmup_steps.")
parser.add_argument('--logging_steps', type=int, default=50,
help="Log every X updates steps.")
parser.add_argument('--save_steps', type=int, default=50,
help="Save checkpoint every X updates steps.")
parser.add_argument("--eval_all_checkpoints", action='store_true',
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
parser.add_argument("--no_cuda", action='store_true',
help="Avoid using CUDA when available")
parser.add_argument('--overwrite_output_dir', action='store_true',
help="Overwrite the content of the output directory")
parser.add_argument('--overwrite_cache', action='store_true',
help="Overwrite the cached training and evaluation sets")
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
parser.add_argument('--fp16_opt_level', type=str, default='O1',
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
"See details at https://nvidia.github.io/apex/amp.html")
parser.add_argument("--local_rank", type=int, default=-1,
help="For distributed training: local_rank")
parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
args = parser.parse_args()
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))
# Setup distant debugging if needed
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
# Setup CUDA, GPU & distributed training
if args.local_rank == -1 or args.no_cuda:
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
args.n_gpu = torch.cuda.device_count()
else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.cuda.set_device(args.local_rank)
device = torch.device("cuda", args.local_rank)
torch.distributed.init_process_group(backend='nccl')
args.n_gpu = 1
args.device = device
# Setup logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
# Set seed
set_seed(args)
# Prepare XNLI task
args.task_name = 'xnli'
if args.task_name not in processors:
raise ValueError("Task not found: %s" % (args.task_name))
processor = processors[args.task_name](language=args.language, train_language=args.train_language)
args.output_mode = output_modes[args.task_name]
label_list = processor.get_labels()
num_labels = len(label_list)
# Load pretrained model and tokenizer
if args.local_rank not in [-1, 0]:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
args.model_type = args.model_type.lower()
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path,
num_labels=num_labels,
finetuning_task=args.task_name,
cache_dir=args.cache_dir if args.cache_dir else None)
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
do_lower_case=args.do_lower_case,
cache_dir=args.cache_dir if args.cache_dir else None)
model = model_class.from_pretrained(args.model_name_or_path,
from_tf=bool('.ckpt' in args.model_name_or_path),
config=config,
cache_dir=args.cache_dir if args.cache_dir else None)
if args.local_rank == 0:
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
model.to(args.device)
logger.info("Training/evaluation parameters %s", args)
# Training
if args.do_train:
train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
global_step, tr_loss = train(args, train_dataset, model, tokenizer)
logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
# Create output directory if needed
if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
os.makedirs(args.output_dir)
logger.info("Saving model checkpoint to %s", args.output_dir)
# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training
model_to_save.save_pretrained(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
# Load a trained model and vocabulary that you have fine-tuned
model = model_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
model.to(args.device)
# Evaluation
results = {}
if args.do_eval and args.local_rank in [-1, 0]:
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
checkpoints = [args.output_dir]
if args.eval_all_checkpoints:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging
logger.info("Evaluate the following checkpoints: %s", checkpoints)
for checkpoint in checkpoints:
global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else ""
model = model_class.from_pretrained(checkpoint)
model.to(args.device)
result = evaluate(args, model, tokenizer, prefix=prefix)
result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
results.update(result)
return results
if __name__ == "__main__":
main()
# Text Summarization with Pretrained Encoders
This folder contains part of the code necessary to reproduce the results on abstractive summarization from the article [Text Summarization with Pretrained Encoders](https://arxiv.org/pdf/1908.08345.pdf) by [Yang Liu](https://nlp-yang.github.io/) and [Mirella Lapata](https://homepages.inf.ed.ac.uk/mlap/). It can also be used to summarize any document.
The original code can be found on Yang Liu's [GitHub repository](https://github.com/nlpyang/PreSumm).
The model is loaded with pre-trained weights for abstractive summarization, trained on the CNN/Daily Mail dataset first on an extractive and then on an abstractive objective.
## Setup
```
git clone https://github.com/huggingface/transformers && cd transformers
pip install [--editable] .
pip install nltk py-rouge
cd examples/summarization
```
## Reproduce the authors' results on ROUGE
To reproduce the authors' results on the CNN/Daily Mail dataset, you first need to download both the CNN and Daily Mail datasets [from Kyunghyun Cho's website](https://cs.nyu.edu/~kcho/DMQA/) (the links next to "Stories") into the same folder. Then uncompress the archives by running:
```bash
tar -xvf cnn_stories.tgz && tar -xvf dailymail_stories.tgz
```
Then move all the stories into the same folder; we will refer to the path where you uncompressed both archives as `$DATA_PATH`. Run the following in the same folder as `run_summarization.py` (the `--summaries_output_dir` argument is optional):
```bash
python run_summarization.py \
--documents_dir $DATA_PATH \
--summaries_output_dir $SUMMARIES_PATH \
--no_cuda false \
--batch_size 4 \
--min_length 50 \
--max_length 200 \
--beam_size 5 \
--alpha 0.95 \
--block_trigram true \
--compute_rouge true
```
The script executes on GPU if one is available and `no_cuda` is not set to `true`. Inference on multiple GPUs is not supported yet. The ROUGE scores are displayed in the console at the end of evaluation and written to a `rouge_scores.txt` file. With a single Tesla V100 GPU and a batch size of 10, the script takes about 30 hours to process the 300,000 texts to summarize.
## Summarize any text
Put the documents that you would like to summarize in a folder (the path to which is referred to as `$DATA_PATH` below) and run the following in the same folder as `run_summarization.py` (the `--summaries_output_dir` argument is optional):
```bash
python run_summarization.py \
--documents_dir $DATA_PATH \
--summaries_output_dir $SUMMARIES_PATH \
--no_cuda false \
--batch_size 4 \
--min_length 50 \
--max_length 200 \
--beam_size 5 \
--alpha 0.95 \
--block_trigram true \
```
You may want to play around with `min_length`, `max_length` and `alpha` to suit your use case. If you want to compute ROUGE on another dataset you will need to tweak the stories/summaries loading in `utils_summarization.py` and tell it where to fetch the reference summaries, along the lines of the sketch below.
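The snippet below is a rough, hypothetical sketch of such a loader, not the actual `utils_summarization.py` API: it assumes each document is stored as `<name>.txt` with its reference summary in a matching `<name>.summary` file, and you would adapt the file patterns and parsing to however your dataset is actually laid out.
```python
# Hypothetical sketch: pair each document with its reference summary.
# Assumes <name>.txt documents and <name>.summary references live side by side
# in `data_dir`; adjust the patterns to your dataset's real layout.
import os
from glob import glob


def load_documents_and_summaries(data_dir):
    pairs = []
    for doc_path in sorted(glob(os.path.join(data_dir, "*.txt"))):
        summary_path = doc_path[: -len(".txt")] + ".summary"
        if not os.path.exists(summary_path):
            continue  # skip documents without a reference summary
        with open(doc_path, encoding="utf-8") as doc_file:
            document = doc_file.read().strip()
        with open(summary_path, encoding="utf-8") as summary_file:
            summary = summary_file.read().strip()
        pairs.append((document, summary))
    return pairs
```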
# coding=utf-8
# Copyright 2019 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BertAbs configuration """
import json
import logging
import sys
from transformers import PretrainedConfig
logger = logging.getLogger(__name__)
BERTABS_FINETUNED_CONFIG_MAP = {
"bertabs-finetuned-cnndm": "https://s3.amazonaws.com/models.huggingface.co/bert/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization-config.json",
}
class BertAbsConfig(PretrainedConfig):
r""" Class to store the configuration of the BertAbs model.
Arguments:
vocab_size: int
Number of tokens in the vocabulary.
max_pos: int
The maximum sequence length that this model will be used with.
enc_layers: int
The number of hidden layers in the Transformer encoder.
enc_hidden_size: int
The size of the encoder's layers.
enc_heads: int
The number of attention heads for each attention layer in the encoder.
enc_ff_size: int
The size of the encoder's feed-forward layers.
enc_dropout: float
The dropout probability for all fully connected layers in the
embeddings, layers, pooler and also the attention probabilities in
the encoder.
dec_layers: int
The number of hidden layers in the decoder.
dec_hidden_size: int
The size of the decoder's layers.
dec_heads: int
The number of attention heads for each attention layer in the decoder.
dec_ff_size: int
The size of the decoder's feed-forward layers.
dec_dropout: float
The dropout probability for all fully connected layers in the
embeddings, layers, pooler and also the attention probabilities in
the decoder.
"""
pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP
def __init__(
self,
vocab_size=30522,
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
**kwargs,
):
super(BertAbsConfig, self).__init__(**kwargs)
self.vocab_size = vocab_size
self.max_pos = max_pos
self.enc_layers = enc_layers
self.enc_hidden_size = enc_hidden_size
self.enc_heads = enc_heads
self.enc_ff_size = enc_ff_size
self.enc_dropout = enc_dropout
self.dec_layers = dec_layers
self.dec_hidden_size = dec_hidden_size
self.dec_heads = dec_heads
self.dec_ff_size = dec_ff_size
self.dec_dropout = dec_dropout
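# Example usage (a minimal sketch): build a configuration with the defaults above and
# inspect it as JSON; `to_json_string()` is inherited from `PretrainedConfig`.
#   config = BertAbsConfig(enc_layers=6, dec_layers=6)
#   print(config.to_json_string())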
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Convert BertExtAbs's checkpoints.
The script looks like it is doing something trivial but it is not. The "weights"
proposed by the authors are actually the entire model pickled. We need to load
the model within the original codebase to be able to only save its `state_dict`.
"""
import argparse
import os
from collections import namedtuple
import logging
import torch
from models.model_builder import AbsSummarizer # The authors' implementation
from model_bertabs import BertAbsSummarizer
from transformers import BertTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = 'Hello world! cécé herlolip'
BertAbsConfig = namedtuple(
"BertAbsConfig",
["temp_dir", "large", "use_bert_emb", "finetune_bert", "encoder", "share_emb", "max_pos", "enc_layers", "enc_hidden_size", "enc_heads", "enc_ff_size", "enc_dropout", "dec_layers", "dec_hidden_size", "dec_heads", "dec_ff_size", "dec_dropout"],
)
def convert_bertabs_checkpoints(path_to_checkpoints, dump_path):
""" Copy/paste and tweak the pre-trained weights provided by the creators
of BertAbs for the internal architecture.
"""
# Instantiate the authors' model with the pre-trained weights
config = BertAbsConfig(
temp_dir=".",
finetune_bert=False,
large=False,
share_emb=True,
use_bert_emb=False,
encoder="bert",
max_pos=512,
enc_layers=6,
enc_hidden_size=512,
enc_heads=8,
enc_ff_size=512,
enc_dropout=0.2,
dec_layers=6,
dec_hidden_size=768,
dec_heads=8,
dec_ff_size=2048,
dec_dropout=0.2,
)
checkpoints = torch.load(path_to_checkpoints, lambda storage, loc: storage)
original = AbsSummarizer(config, torch.device("cpu"), checkpoints)
original.eval()
new_model = BertAbsSummarizer(config, torch.device("cpu"))
new_model.eval()
# -------------------
# Convert the weights
# -------------------
logging.info("convert the model")
new_model.bert.load_state_dict(original.bert.state_dict())
new_model.decoder.load_state_dict(original.decoder.state_dict())
new_model.generator.load_state_dict(original.generator.state_dict())
# ----------------------------------
# Make sure the outputs are identical
# ----------------------------------
logging.info("Make sure that the models' outputs are identical")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# prepare the model inputs
encoder_input_ids = tokenizer.encode("This is sample éàalj'-.")
encoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(encoder_input_ids)))
encoder_input_ids = torch.tensor(encoder_input_ids).unsqueeze(0)
decoder_input_ids = tokenizer.encode("This is sample 3 éàalj'-.")
decoder_input_ids.extend([tokenizer.pad_token_id] * (512 - len(decoder_input_ids)))
decoder_input_ids = torch.tensor(decoder_input_ids).unsqueeze(0)
# failsafe to make sure the weights reset does not affect the
# loaded weights.
assert torch.max(torch.abs(original.generator[0].weight - new_model.generator[0].weight)) == 0
# forward pass
src = encoder_input_ids
tgt = decoder_input_ids
segs = token_type_ids = None
clss = None
mask_src = encoder_attention_mask = None
mask_tgt = decoder_attention_mask = None
mask_cls = None
# The original model does not apply the generator layer immediately but rather in
# the beam search (where it combines softmax + linear layer). Since we already
# apply the softmax in our generation process we only apply the linear layer here.
# We make sure that the outputs of the full stack are identical
output_original_model = original(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls)[0]
output_original_generator = original.generator(output_original_model)
output_converted_model = new_model(encoder_input_ids, decoder_input_ids, token_type_ids, encoder_attention_mask, decoder_attention_mask)[0]
output_converted_generator = new_model.generator(output_converted_model)
maximum_absolute_difference = torch.max(torch.abs(output_converted_model - output_original_model)).item()
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
maximum_absolute_difference = torch.max(torch.abs(output_converted_generator - output_original_generator)).item()
print("Maximum absolute difference beween weights: {:.2f}".format(maximum_absolute_difference))
are_identical = torch.allclose(output_converted_model, output_original_model, atol=1e-3)
if are_identical:
logging.info("all weights are equal up to 1e-3")
else:
raise ValueError("the weights are different. The new model is likely different from the original one.")
# The model has been saved with torch.save(model) and this is bound to the exact
# directory structure. We save the state_dict instead.
logging.info("saving the model's state dictionary")
torch.save(new_model.state_dict(), os.path.join(dump_path, "bertabs-finetuned-cnndm-extractive-abstractive-summarization-pytorch_model.bin"))
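# Reloading the converted weights later (a sketch, assuming the same `config` as above;
# `path_to_saved_state_dict` stands for the file written by the `torch.save` call just above):
#   model = BertAbsSummarizer(config, torch.device("cpu"))
#   model.load_state_dict(torch.load(path_to_saved_state_dict))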
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--bertabs_checkpoint_path",
default=None,
type=str,
required=True,
help="Path the official PyTorch dump.",
)
parser.add_argument(
"--pytorch_dump_folder_path",
default=None,
type=str,
required=True,
help="Path to the output PyTorch model.",
)
args = parser.parse_args()
convert_bertabs_checkpoints(
args.bertabs_checkpoint_path,
args.pytorch_dump_folder_path,
)