Unverified Commit 03cdb2a3 authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #254 from huggingface/python_2

Adding OpenAI GPT and Transformer-XL models, compatibility with Python 2
parents 2dfaf2f2 1e71f11d
version: 2 version: 2
jobs: jobs:
build: build_py3:
working_directory: ~/pytorch-pretrained-BERT working_directory: ~/pytorch-pretrained-BERT
docker: docker:
- image: circleci/python:3.7 - image: circleci/python:3.5
steps: steps:
- checkout - checkout
- run: sudo pip install --progress-bar off . - run: sudo pip install --progress-bar off .
- run: sudo pip install pytest - run: sudo pip install pytest ftfy spacy
- run: sudo python -m spacy download en
- run: python -m pytest -sv tests/ - run: python -m pytest -sv tests/
build_py2:
working_directory: ~/pytorch-pretrained-BERT
docker:
- image: circleci/python:2.7
steps:
- checkout
- run: sudo pip install --progress-bar off .
- run: sudo pip install pytest spacy
- run: sudo pip install ftfy==4.4.3
- run: sudo python -m spacy download en
- run: python -m pytest -sv tests/
workflows:
version: 2
build_and_test:
jobs:
- build_py3
- build_py2
\ No newline at end of file
This diff is collapsed.
...@@ -15,26 +15,26 @@ ...@@ -15,26 +15,26 @@
# limitations under the License. # limitations under the License.
"""BERT finetuning runner.""" """BERT finetuning runner."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import argparse
import csv import csv
import os
import logging import logging
import argparse import os
import random import random
from tqdm import tqdm, trange import sys
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -91,10 +91,12 @@ class DataProcessor(object): ...@@ -91,10 +91,12 @@ class DataProcessor(object):
@classmethod @classmethod
def _read_tsv(cls, input_file, quotechar=None): def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file.""" """Reads a tab separated value file."""
with open(input_file, "r", encoding='utf-8') as f: with open(input_file, "r") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar) reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = [] lines = []
for line in reader: for line in reader:
if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line)
lines.append(line) lines.append(line)
return lines return lines
...@@ -321,6 +323,10 @@ def main(): ...@@ -321,6 +323,10 @@ def main():
help="The output directory where the model predictions and checkpoints will be written.") help="The output directory where the model predictions and checkpoints will be written.")
## Other parameters ## Other parameters
parser.add_argument("--cache_dir",
default="",
type=str,
help="Where do you want to store the pre-trained models downloaded from s3")
parser.add_argument("--max_seq_length", parser.add_argument("--max_seq_length",
default=128, default=128,
type=int, type=int,
...@@ -380,9 +386,17 @@ def main(): ...@@ -380,9 +386,17 @@ def main():
help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n" help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
"0 (default value): dynamic loss scaling.\n" "0 (default value): dynamic loss scaling.\n"
"Positive power of 2: static loss scaling value.\n") "Positive power of 2: static loss scaling value.\n")
parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
args = parser.parse_args() args = parser.parse_args()
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
processors = { processors = {
"cola": ColaProcessor, "cola": ColaProcessor,
"mnli": MnliProcessor, "mnli": MnliProcessor,
...@@ -424,7 +438,8 @@ def main(): ...@@ -424,7 +438,8 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True) if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
task_name = args.task_name.lower() task_name = args.task_name.lower()
...@@ -447,8 +462,9 @@ def main(): ...@@ -447,8 +462,9 @@ def main():
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size() num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
# Prepare model # Prepare model
cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
model = BertForSequenceClassification.from_pretrained(args.bert_model, model = BertForSequenceClassification.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank), cache_dir=cache_dir,
num_labels = num_labels) num_labels = num_labels)
if args.fp16: if args.fp16:
model.half() model.half()
...@@ -545,15 +561,21 @@ def main(): ...@@ -545,15 +561,21 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
if args.do_train: if args.do_train:
# Save a trained model and the associated configuration
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
torch.save(model_to_save.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
# Load a trained model that you have fine-tuned with open(output_config_file, 'w') as f:
model_state_dict = torch.load(output_model_file) f.write(model_to_save.config.to_json_string())
model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict, num_labels=num_labels)
# Load a trained model and config that you have fine-tuned
config = BertConfig(output_config_file)
model = BertForSequenceClassification(config, num_labels=num_labels)
model.load_state_dict(torch.load(output_model_file))
else:
model = BertForSequenceClassification.from_pretrained(args.bert_model, num_labels=num_labels)
model.to(device) model.to(device)
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
......
...@@ -15,22 +15,22 @@ ...@@ -15,22 +15,22 @@
# limitations under the License. # limitations under the License.
"""BERT finetuning runner.""" """BERT finetuning runner."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function, unicode_literals
from __future__ import division
from __future__ import print_function
import os
import logging
import argparse import argparse
from tqdm import tqdm, trange import logging
import os
import random
from io import open
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForPreTraining from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from torch.utils.data import Dataset from torch.utils.data import Dataset
...@@ -179,16 +179,16 @@ class BERTDataset(Dataset): ...@@ -179,16 +179,16 @@ class BERTDataset(Dataset):
if self.line_buffer is None: if self.line_buffer is None:
# read first non-empty line of file # read first non-empty line of file
while t1 == "" : while t1 == "" :
t1 = self.file.__next__().strip() t1 = next(self.file).strip()
t2 = self.file.__next__().strip() t2 = next(self.file).strip()
else: else:
# use t2 from previous iteration as new t1 # use t2 from previous iteration as new t1
t1 = self.line_buffer t1 = self.line_buffer
t2 = self.file.__next__().strip() t2 = next(self.file).strip()
# skip empty rows that are used for separating documents and keep track of current doc id # skip empty rows that are used for separating documents and keep track of current doc id
while t2 == "" or t1 == "": while t2 == "" or t1 == "":
t1 = self.file.__next__().strip() t1 = next(self.file).strip()
t2 = self.file.__next__().strip() t2 = next(self.file).strip()
self.current_doc = self.current_doc+1 self.current_doc = self.current_doc+1
self.line_buffer = t2 self.line_buffer = t2
...@@ -222,15 +222,15 @@ class BERTDataset(Dataset): ...@@ -222,15 +222,15 @@ class BERTDataset(Dataset):
def get_next_line(self): def get_next_line(self):
""" Gets next line of random_file and starts over when reaching end of file""" """ Gets next line of random_file and starts over when reaching end of file"""
try: try:
line = self.random_file.__next__().strip() line = next(self.random_file).strip()
#keep track of which document we are currently looking at to later avoid having the same doc as t1 #keep track of which document we are currently looking at to later avoid having the same doc as t1
if line == "": if line == "":
self.current_random_doc = self.current_random_doc + 1 self.current_random_doc = self.current_random_doc + 1
line = self.random_file.__next__().strip() line = next(self.random_file).strip()
except StopIteration: except StopIteration:
self.random_file.close() self.random_file.close()
self.random_file = open(self.corpus_path, "r", encoding=self.encoding) self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
line = self.random_file.__next__().strip() line = next(self.random_file).strip()
return line return line
...@@ -419,6 +419,7 @@ def main(): ...@@ -419,6 +419,7 @@ def main():
help="The output directory where the model checkpoints will be written.") help="The output directory where the model checkpoints will be written.")
## Other parameters ## Other parameters
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--max_seq_length", parser.add_argument("--max_seq_length",
default=128, default=128,
type=int, type=int,
...@@ -506,7 +507,8 @@ def main(): ...@@ -506,7 +507,8 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir): if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True) if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
...@@ -575,7 +577,7 @@ def main(): ...@@ -575,7 +577,7 @@ def main():
if args.local_rank == -1: if args.local_rank == -1:
train_sampler = RandomSampler(train_dataset) train_sampler = RandomSampler(train_dataset)
else: else:
#TODO: check if this works with current data generator from disk that relies on file.__next__ #TODO: check if this works with current data generator from disk that relies on next(file)
# (it doesn't return item back by index) # (it doesn't return item back by index)
train_sampler = DistributedSampler(train_dataset) train_sampler = DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
......
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HugginFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" OpenAI GPT model fine-tuning script.
Adapted from https://github.com/huggingface/pytorch-openai-transformer-lm/blob/master/train.py
Itself adapted from https://github.com/openai/finetune-transformer-lm/blob/master/train.py
This script with default values fine-tunes and evaluates a pretrained OpenAI GPT on the RocStories dataset
"""
import argparse
import os
import csv
import random
import logging
from tqdm import tqdm, trange
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, OpenAIAdam, cached_path
# Remote archive fetched through cached_path() when neither dataset path is given.
ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"

# Script-wide logging: timestamped records at INFO level and above.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def accuracy(out, labels):
    """Count the rows of `out` whose highest-scoring column equals the label.

    The arg-max is taken along axis 1 of `out` and compared element-wise
    with `labels`. The result is the number of matches, not a ratio.
    """
    predictions = np.argmax(out, axis=1)
    hits = predictions == labels
    return np.sum(hits)
def load_rocstories_dataset(dataset_path):
    """ Read a ROCStories CSV file into a list of examples.

        The first row is treated as a header and skipped. For every other
        non-empty row, columns 1 to 4 are joined with single spaces to form the
        story, columns 5 and 6 are the two candidate continuations, and the last
        column is parsed as an integer and shifted down by one to give the label.

        Output a list of tuples(story, 1st continuation, 2nd continuation, label)
    """
    # newline='' hands line-ending handling to the csv module, which it needs
    # in order to parse quoted fields that contain line breaks correctly.
    with open(dataset_path, encoding='utf_8', newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row; tolerate an empty file
        output = []
        for line in tqdm(reader):
            if not line:
                # csv.reader yields an empty list for a blank line.
                continue
            output.append((' '.join(line[1:5]), line[5], line[6], int(line[-1]) - 1))
    return output
def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, delimiter_token, clf_token):
    """ Convert datasets of tuples(story, 1st continuation, 2nd continuation, label) into tensors.

        Each example yields two candidate sequences, one per continuation:
            [start_token] + story[:cap_length] + [delimiter_token] + cont[:cap_length] + [clf_token]

        For every dataset a tuple of four int64 tensors is produced:
            input_ids    (n_batch, 2, input_len): the sequences, right-padded with 0
            mc_token_ids (n_batch, 2): index of the last token (clf_token) of each sequence
            lm_labels    (n_batch, 2, input_len): the sequences shifted left by one, padded with -1
            mc_labels    (n_batch,): the label of each example

        Returns a list with one such tuple per input dataset.
    """
    processed = []
    for examples in encoded_datasets:
        n_examples = len(examples)
        padded_shape = (n_examples, 2, input_len)
        input_ids = np.zeros(padded_shape, dtype=np.int64)
        mc_token_ids = np.zeros((n_examples, 2), dtype=np.int64)
        lm_labels = np.full(padded_shape, fill_value=-1, dtype=np.int64)
        mc_labels = np.zeros((n_examples,), dtype=np.int64)
        for row, (story, cont1, cont2, mc_label) in enumerate(examples):
            # Both alternatives share the same story prefix.
            prefix = [start_token] + story[:cap_length] + [delimiter_token]
            for alt, continuation in enumerate((cont1, cont2)):
                sequence = prefix + continuation[:cap_length] + [clf_token]
                seq_len = len(sequence)
                input_ids[row, alt, :seq_len] = sequence
                mc_token_ids[row, alt] = seq_len - 1
                # Language-modeling target at position t is the token at t + 1.
                lm_labels[row, alt, :seq_len - 1] = sequence[1:]
            mc_labels[row] = mc_label
        arrays = (input_ids, mc_token_ids, lm_labels, mc_labels)
        processed.append(tuple(torch.tensor(a) for a in arrays))
    return processed
def main():
    """Fine-tune and/or evaluate an OpenAI GPT double-heads model.

    Steps, in order:
      1. Parse command-line arguments and seed the random number generators.
      2. Load the tokenizer and model, registering three special tokens.
      3. Load both CSV datasets, tokenize them and pack them into tensors.
      4. With --do_train: train, save the weights to the output directory
         and reload them into a fresh model.
      5. With --do_eval: compute loss and accuracy on the evaluation set and
         write them to ``eval_results.txt`` in the output directory.

    Raises:
        ValueError: if neither --do_train nor --do_eval is given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='openai-gpt',
                        help='pretrained model name')
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument('--train_dataset', type=str, default='')
    parser.add_argument('--eval_dataset', type=str, default='')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_train_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    # A clipping norm is a real number; parsing it as float still accepts "1".
    parser.add_argument('--max_grad_norm', type=float, default=1)
    parser.add_argument('--learning_rate', type=float, default=6.25e-5)
    parser.add_argument('--warmup_proportion', type=float, default=0.002)
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--lm_coef', type=float, default=0.9)
    parser.add_argument('--n_valid', type=int, default=374)

    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()
    print(args)

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Seed every random number generator the script relies on.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {}, n_gpu {}".format(device, n_gpu))

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Load tokenizer and model
    # These loading functions also add new tokens and embeddings called `special tokens`
    # These new embeddings will be fine-tuned on the RocStories dataset
    special_tokens = ['_start_', '_delimiter_', '_classify_']
    tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name, special_tokens=special_tokens)
    special_tokens_ids = list(tokenizer.convert_tokens_to_ids(token) for token in special_tokens)
    model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name, num_special_tokens=len(special_tokens))
    model.to(device)

    # Load and encode the datasets
    if not args.train_dataset and not args.eval_dataset:
        # NOTE(review): the downloaded archive path is never used afterwards, so
        # running without explicit dataset paths still fails when the empty
        # paths are opened below - confirm the intended handling of the archive.
        roc_stories = cached_path(ROCSTORIES_URL)

    def tokenize_and_encode(obj):
        """ Tokenize and encode a nested object """
        if isinstance(obj, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
        elif isinstance(obj, int):
            return obj
        return list(tokenize_and_encode(o) for o in obj)

    logger.info("Encoding dataset...")
    train_dataset = load_rocstories_dataset(args.train_dataset)
    eval_dataset = load_rocstories_dataset(args.eval_dataset)
    datasets = (train_dataset, eval_dataset)
    encoded_datasets = tokenize_and_encode(datasets)

    # Compute the max input length for the Transformer
    # Story and continuation are each capped at max_length; the 3 extra
    # positions hold the start, delimiter and classify tokens.
    max_length = model.config.n_positions // 2 - 2
    input_length = max(
        len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3
        for dataset in encoded_datasets
        for story, cont1, cont2, _ in dataset
    )
    input_length = min(input_length, model.config.n_positions)  # Max size of input for the pre-trained model

    # Prepare inputs tensors and dataloaders
    tensor_datasets = pre_process_datasets(encoded_datasets, input_length, max_length, *special_tokens_ids)
    train_tensor_dataset, eval_tensor_dataset = tensor_datasets[0], tensor_datasets[1]

    train_data = TensorDataset(*train_tensor_dataset)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

    eval_data = TensorDataset(*eval_tensor_dataset)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    # Parameters whose name matches `no_decay` get no weight decay; all others
    # use the value requested on the command line.
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    num_train_optimization_steps = len(train_data) * args.num_train_epochs // args.train_batch_size
    optimizer = OpenAIAdam(optimizer_grouped_parameters,
                           lr=args.learning_rate,
                           warmup=args.warmup_proportion,
                           max_grad_norm=args.max_grad_norm,
                           weight_decay=args.weight_decay,
                           t_total=num_train_optimization_steps)

    if args.do_train:
        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            # Running totals are reset each epoch, so the reported train loss
            # is the average over the last epoch only.
            tr_loss = 0
            nb_tr_steps = 0
            tqdm_bar = tqdm(train_dataloader, desc="Training")
            for batch in tqdm_bar:
                batch = tuple(t.to(device) for t in batch)
                input_ids, mc_token_ids, lm_labels, mc_labels = batch
                losses = model(input_ids, mc_token_ids, lm_labels, mc_labels)
                loss = args.lm_coef * losses[0] + losses[1]
                loss.backward()
                optimizer.step()
                # Clear gradients so they do not accumulate into the next step.
                optimizer.zero_grad()
                step_loss = loss.item()
                tr_loss += step_loss
                # Exponential moving average of the loss, shown in the progress bar.
                if exp_average_loss is None:
                    exp_average_loss = step_loss
                else:
                    exp_average_loss = 0.7 * exp_average_loss + 0.3 * step_loss
                nb_tr_steps += 1
                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, optimizer.get_lr()[0])

    # Save a trained model
    if args.do_train:
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        config = model.config
        torch.save(model_to_save.state_dict(), output_model_file)

        # Load a trained model that you have fine-tuned
        model_state_dict = torch.load(output_model_file)
        model = OpenAIGPTDoubleHeadsModel(config)
        model.load_state_dict(model_state_dict)
        model.to(device)

    if args.do_eval:
        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(device) for t in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels = batch
            with torch.no_grad():
                # Called with labels to obtain the losses, and without labels
                # to obtain the multiple-choice logits.
                _, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels)
                _, mc_logits = model(input_ids, mc_token_ids)

            mc_logits = mc_logits.detach().cpu().numpy()
            mc_labels = mc_labels.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(mc_logits, mc_labels)

            eval_loss += mc_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        # Loss is averaged per batch, accuracy per example.
        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        train_loss = tr_loss/nb_tr_steps if args.do_train else None
        result = {'eval_loss': eval_loss,
                  'eval_accuracy': eval_accuracy,
                  'train_loss': train_loss}

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
# Run the fine-tuning / evaluation entry point only when executed as a script.
if __name__ == '__main__':
    main()
...@@ -15,29 +15,36 @@ ...@@ -15,29 +15,36 @@
# limitations under the License. # limitations under the License.
"""Run BERT on SQuAD.""" """Run BERT on SQuAD."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import argparse import argparse
import collections import collections
import logging
import json import json
import logging
import math import math
import os import os
import random import random
import pickle import sys
from tqdm import tqdm, trange from io import open
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
BertTokenizer,
whitespace_tokenize)
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -863,7 +870,8 @@ def main(): ...@@ -863,7 +870,8 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
raise ValueError("Output directory () already exists and is not empty.") raise ValueError("Output directory () already exists and is not empty.")
os.makedirs(args.output_dir, exist_ok=True) if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
...@@ -879,7 +887,7 @@ def main(): ...@@ -879,7 +887,7 @@ def main():
# Prepare model # Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model, model = BertForQuestionAnswering.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank)) cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)))
if args.fp16: if args.fp16:
model.half() model.half()
...@@ -909,7 +917,7 @@ def main(): ...@@ -909,7 +917,7 @@ def main():
if args.fp16: if args.fp16:
try: try:
from apex.optimizer import FP16_Optimizer from apex.optimizers import FP16_Optimizer
from apex.optimizers import FusedAdam from apex.optimizers import FusedAdam
except ImportError: except ImportError:
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.") raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
...@@ -993,14 +1001,19 @@ def main(): ...@@ -993,14 +1001,19 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
if args.do_train: if args.do_train:
# Save a trained model and the associated configuration
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
torch.save(model_to_save.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
# Load a trained model that you have fine-tuned output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
model_state_dict = torch.load(output_model_file) with open(output_config_file, 'w') as f:
model = BertForQuestionAnswering.from_pretrained(args.bert_model, state_dict=model_state_dict) f.write(model_to_save.config.to_json_string())
# Load a trained model and config that you have fine-tuned
config = BertConfig(output_config_file)
model = BertForQuestionAnswering(config)
model.load_state_dict(torch.load(output_model_file))
else: else:
model = BertForQuestionAnswering.from_pretrained(args.bert_model) model = BertForQuestionAnswering.from_pretrained(args.bert_model)
......
...@@ -15,22 +15,25 @@ ...@@ -15,22 +15,25 @@
# limitations under the License. # limitations under the License.
"""BERT finetuning runner.""" """BERT finetuning runner."""
import argparse
import csv
import logging import logging
import os import os
import argparse
import random import random
from tqdm import tqdm, trange import sys
import csv from io import open
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForMultipleChoice from pytorch_pretrained_bert.modeling import BertForMultipleChoice
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.tokenization import BertTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -65,17 +68,17 @@ class SwagExample(object): ...@@ -65,17 +68,17 @@ class SwagExample(object):
def __repr__(self): def __repr__(self):
l = [ l = [
f"swag_id: {self.swag_id}", "swag_id: {}".format(self.swag_id),
f"context_sentence: {self.context_sentence}", "context_sentence: {}".format(self.context_sentence),
f"start_ending: {self.start_ending}", "start_ending: {}".format(self.start_ending),
f"ending_0: {self.endings[0]}", "ending_0: {}".format(self.endings[0]),
f"ending_1: {self.endings[1]}", "ending_1: {}".format(self.endings[1]),
f"ending_2: {self.endings[2]}", "ending_2: {}".format(self.endings[2]),
f"ending_3: {self.endings[3]}", "ending_3: {}".format(self.endings[3]),
] ]
if self.label is not None: if self.label is not None:
l.append(f"label: {self.label}") l.append("label: {}".format(self.label))
return ", ".join(l) return ", ".join(l)
...@@ -102,7 +105,11 @@ class InputFeatures(object): ...@@ -102,7 +105,11 @@ class InputFeatures(object):
def read_swag_examples(input_file, is_training): def read_swag_examples(input_file, is_training):
with open(input_file, 'r', encoding='utf-8') as f: with open(input_file, 'r', encoding='utf-8') as f:
reader = csv.reader(f) reader = csv.reader(f)
lines = list(reader) lines = []
for line in reader:
if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line)
lines.append(line)
if is_training and lines[0][-1] != 'label': if is_training and lines[0][-1] != 'label':
raise ValueError( raise ValueError(
...@@ -184,15 +191,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -184,15 +191,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
label = example.label label = example.label
if example_index < 5: if example_index < 5:
logger.info("*** Example ***") logger.info("*** Example ***")
logger.info(f"swag_id: {example.swag_id}") logger.info("swag_id: {}".format(example.swag_id))
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features): for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
logger.info(f"choice: {choice_idx}") logger.info("choice: {}".format(choice_idx))
logger.info(f"tokens: {' '.join(tokens)}") logger.info("tokens: {}".format(' '.join(tokens)))
logger.info(f"input_ids: {' '.join(map(str, input_ids))}") logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
logger.info(f"input_mask: {' '.join(map(str, input_mask))}") logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}") logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
if is_training: if is_training:
logger.info(f"label: {label}") logger.info("label: {}".format(label))
features.append( features.append(
InputFeatures( InputFeatures(
...@@ -344,7 +351,8 @@ def main(): ...@@ -344,7 +351,8 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir): if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True) if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
...@@ -359,7 +367,7 @@ def main(): ...@@ -359,7 +367,7 @@ def main():
# Prepare model # Prepare model
model = BertForMultipleChoice.from_pretrained(args.bert_model, model = BertForMultipleChoice.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank), cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
num_choices=4) num_choices=4)
if args.fp16: if args.fp16:
model.half() model.half()
...@@ -461,18 +469,25 @@ def main(): ...@@ -461,18 +469,25 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
# Save a trained model
if args.do_train:
# Save a trained model and the associated configuration
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin") output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
torch.save(model_to_save.state_dict(), output_model_file) torch.save(model_to_save.state_dict(), output_model_file)
output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
# Load a trained model that you have fine-tuned with open(output_config_file, 'w') as f:
model_state_dict = torch.load(output_model_file) f.write(model_to_save.config.to_json_string())
model = BertForMultipleChoice.from_pretrained(args.bert_model,
state_dict=model_state_dict, # Load a trained model and config that you have fine-tuned
num_choices=4) config = BertConfig(output_config_file)
model = BertForMultipleChoice(config, num_choices=4)
model.load_state_dict(torch.load(output_model_file))
else:
model = BertForMultipleChoice.from_pretrained(args.bert_model, num_choices=4)
model.to(device) model.to(device)
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0): if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True) eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training = True)
eval_features = convert_examples_to_features( eval_features = convert_examples_to_features(
......
# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Transformer XL model evaluation script.
Adapted from https://github.com/kimiyoung/transformer-xl.
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
This script with default values evaluates a pretrained Transformer-XL on WikiText 103
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import logging
import time
import math
import torch
from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus
# Configure logging for the script: every record carries a timestamp, its
# level and the emitting logger's name; INFO and above are shown.
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def main():
    """Evaluate a pretrained Transformer-XL language model.

    Parses the command-line options, loads the pre-processed corpus and the
    pretrained model selected by ``--model_name``, evaluates the split(s)
    chosen with ``--split`` and logs the loss and perplexity of each.

    Raises:
        ValueError: if an evaluated split yields no segments at all.
    """
    parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
    parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                        help='pretrained model name')
    parser.add_argument('--split', type=str, default='test',
                        choices=['all', 'valid', 'test'],
                        help='which split to evaluate')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='batch size')
    parser.add_argument('--tgt_len', type=int, default=128,
                        help='number of tokens to predict')
    parser.add_argument('--ext_len', type=int, default=0,
                        help='length of the extended context')
    parser.add_argument('--mem_len', type=int, default=1600,
                        help='length of the retained previous heads')
    parser.add_argument('--clamp_len', type=int, default=1000,
                        help='max positional embedding index')
    parser.add_argument('--no_cuda', action='store_true',
                        help='Do not use CUDA even though CUDA is available')
    # NOTE(review): --work_dir and --no_log are parsed but never read in this
    # function; they are kept so existing command lines keep working.
    parser.add_argument('--work_dir', type=str, required=True,
                        help='path to the work_dir')
    parser.add_argument('--no_log', action='store_true',
                        help='do not log the eval result')
    parser.add_argument('--same_length', action='store_true',
                        help='set same length attention with masking')
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()
    assert args.ext_len >= 0, 'extended context length must be non-negative'

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    logger.info("device: {}".format(device))

    # Load a pre-processed dataset.
    # You can also build the corpus yourself using TransfoXLCorpus methods.
    # The pre-processing involves computing word frequencies to prepare the
    # Adaptive input and SoftMax, and tokenizing the dataset.
    # The pre-processed corpus is a conversion (using the conversion script).
    corpus = TransfoXLCorpus.from_pretrained(args.model_name)

    va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)

    # Load a pre-trained model
    model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
    model = model.to(device)

    logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
        args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    if args.clamp_len > 0:
        model.clamp_len = args.clamp_len
    if args.same_length:
        model.same_length = True

    ###############################################################################
    # Evaluation code
    ###############################################################################
    def evaluate(eval_iter):
        """Return the mean segment loss over ``eval_iter``, weighted by ``seq_len``.

        Each item of ``eval_iter`` is unpacked as ``(data, target, seq_len)``;
        the memories returned by the model for one segment are fed to the next.
        """
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_len, total_loss = 0, 0.
        # Counted explicitly so an empty iterator is detected instead of
        # failing on an unbound loop variable after the loop.
        num_segments = 0
        start_time = time.time()
        with torch.no_grad():
            mems = None
            for data, target, seq_len in eval_iter:
                loss, mems = model(data, target, mems)
                loss = loss.mean()
                total_loss += seq_len * loss.item()
                total_len += seq_len
                num_segments += 1
            total_time = time.time() - start_time
        if num_segments == 0:
            raise ValueError("the evaluation iterator produced no segments")
        logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
            total_time, 1000 * total_time / num_segments))
        return total_loss / total_len

    # Run on the requested split(s); a split that is not evaluated stays None.
    if args.split == 'all':
        test_loss = evaluate(te_iter)
        valid_loss = evaluate(va_iter)
    elif args.split == 'valid':
        valid_loss = evaluate(va_iter)
        test_loss = None
    elif args.split == 'test':
        test_loss = evaluate(te_iter)
        valid_loss = None

    def format_log(loss, split):
        """Format the loss and the matching perplexity (exp(loss)) of one split."""
        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
            split, loss, math.exp(loss))
        return log_str

    log_str = ''
    if valid_loss is not None:
        log_str += format_log(valid_loss, 'valid')
    if test_loss is not None:
        log_str += format_log(test_loss, 'test')

    logger.info('=' * 100)
    logger.info(log_str)
    logger.info('=' * 100)


if __name__ == '__main__':
    main()
__version__ = "0.4.0" __version__ = "0.5.0"
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus)
from .modeling import (BertConfig, BertModel, BertForPreTraining, from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction, BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice, BertForSequenceClassification, BertForMultipleChoice,
BertForTokenClassification, BertForQuestionAnswering) BertForTokenClassification, BertForQuestionAnswering,
load_tf_weights_in_bert)
from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel,
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
load_tf_weights_in_openai_gpt)
from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl)
from .optimization import BertAdam from .optimization import BertAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE from .optimization_openai import OpenAIAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path
# coding: utf8 # coding: utf8
def main(): def main():
import sys import sys
if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
"convert_tf_checkpoint_to_pytorch",
"convert_openai_checkpoint",
"convert_transfo_xl_checkpoint"
]:
print(
"Should be used as one of: \n"
">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]` or \n"
">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
else:
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
try: try:
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ModuleNotFoundError: except ImportError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see " "In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.") "https://www.tensorflow.org/install/ for installation instructions.")
...@@ -17,6 +29,37 @@ def main(): ...@@ -17,6 +29,37 @@ def main():
TF_CONFIG = sys.argv.pop() TF_CONFIG = sys.argv.pop()
TF_CHECKPOINT = sys.argv.pop() TF_CHECKPOINT = sys.argv.pop()
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
elif sys.argv[1] == "convert_openai_checkpoint":
from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
PYTORCH_DUMP_OUTPUT = sys.argv[3]
if len(sys.argv) == 5:
OPENAI_GPT_CONFIG = sys.argv[4]
else:
OPENAI_GPT_CONFIG = ""
convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
OPENAI_GPT_CONFIG,
PYTORCH_DUMP_OUTPUT)
else:
try:
from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
except ImportError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if 'ckpt' in sys.argv[2].lower():
TF_CHECKPOINT = sys.argv[2]
TF_DATASET_FILE = ""
else:
TF_DATASET_FILE = sys.argv[2]
TF_CHECKPOINT = ""
PYTORCH_DUMP_OUTPUT = sys.argv[3]
if len(sys.argv) == 5:
TF_CONFIG = sys.argv[4]
else:
TF_CONFIG = ""
convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert OpenAI GPT checkpoint."""
from __future__ import absolute_import, division, print_function
import argparse
from io import open
import torch
from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
OpenAIGPTConfig,
OpenAIGPTModel,
load_tf_weights_in_openai_gpt)
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
    """Convert an OpenAI GPT checkpoint into PyTorch weights and a config file.

    Args:
        openai_checkpoint_folder_path: folder holding the original checkpoint,
            passed as-is to ``load_tf_weights_in_openai_gpt``.
        openai_config_file: path to a config json file, or ``""`` to use the
            default ``OpenAIGPTConfig()``.
        pytorch_dump_folder_path: folder in which ``WEIGHTS_NAME`` and
            ``CONFIG_NAME`` are written.
    """
    # `os` is not among this script's visible file-level imports, so it is
    # imported locally to build the output paths portably.
    import os

    # Construct model
    if openai_config_file == "":
        config = OpenAIGPTConfig()
    else:
        config = OpenAIGPTConfig(openai_config_file)
    model = OpenAIGPTModel(config)

    # Load weights from numpy
    load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path)

    # Save pytorch-model
    pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
    pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())
if __name__ == "__main__":
    # Command-line entry point: collect the three paths and run the conversion.
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--openai_checkpoint_folder_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the TensorFlow checkpoint folder.")
    parser.add_argument("--pytorch_dump_folder_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    ## Optional parameters
    parser.add_argument("--openai_config_file",
                        default="",
                        type=str,
                        help="An optional config json file corresponding to the pre-trained OpenAI model. \n"
                             "This specifies the model architecture.")
    args = parser.parse_args()
    convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
                                         args.openai_config_file,
                                         args.pytorch_dump_folder_path)
...@@ -25,62 +25,16 @@ import tensorflow as tf ...@@ -25,62 +25,16 @@ import tensorflow as tf
import torch import torch
import numpy as np import numpy as np
from .modeling import BertConfig, BertForPreTraining from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
config_path = os.path.abspath(bert_config_file)
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
# Initialise PyTorch model # Initialise PyTorch model
config = BertConfig.from_json_file(bert_config_file) config = BertConfig.from_json_file(bert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config))) print("Building PyTorch model from configuration: {}".format(str(config)))
model = BertForPreTraining(config) model = BertForPreTraining(config)
for name, array in zip(names, arrays): # Load weights from tf checkpoint
name = name.split('/') load_tf_weights_in_bert(model, tf_checkpoint_path)
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
print("Skipping {}".format("/".join(name)))
continue
pointer = model
for m_name in name:
if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
l = re.split(r'_(\d+)', m_name)
else:
l = [m_name]
if l[0] == 'kernel' or l[0] == 'gamma':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias' or l[0] == 'beta':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, l[0])
if len(l) >= 2:
num = int(l[1])
pointer = pointer[num]
if m_name[-11:] == '_embeddings':
pointer = getattr(pointer, 'weight')
elif m_name == 'kernel':
array = np.transpose(array)
try:
assert pointer.shape == array.shape
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
# Save pytorch-model # Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path)) print("Save PyTorch model to {}".format(pytorch_dump_path))
......
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert Transformer XL checkpoint and datasets."""
from __future__ import absolute_import, division, print_function
import argparse
import os
import sys
from io import open
import torch
import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
WEIGHTS_NAME,
TransfoXLConfig,
TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl)
from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
VOCAB_NAME)
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
# We do this to be able to load python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
# A pickle records the module and class name of each object it stores, so the
# names below must resolve at load time. The class names `Vocab` and `Corpus`
# are aliased to this package's tokenizer and corpus classes...
data_utils.Vocab = data_utils.TransfoXLTokenizer
data_utils.Corpus = data_utils.TransfoXLCorpus
# ...and the module names `data_utils` and `vocabulary` are registered as
# aliases of this package's tokenization module (presumably the module names
# used when the pickles were written -- confirm against the original repo).
sys.modules['data_utils'] = data_utils
sys.modules['vocabulary'] = data_utils
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
                                             transfo_xl_config_file,
                                             pytorch_dump_folder_path,
                                             transfo_xl_dataset_file):
    """Convert a Transformer-XL dataset pickle and/or TensorFlow checkpoint.

    The two conversions are independent; each runs only when its input path
    is non-empty, and nothing happens when both are empty.

    Args:
        tf_checkpoint_path: TensorFlow checkpoint to convert, or ``""``.
        transfo_xl_config_file: config json file for the model, or ``""`` to
            use the default ``TransfoXLConfig()``.
        pytorch_dump_folder_path: folder receiving every output file.
        transfo_xl_dataset_file: pickled pre-processed corpus to convert into
            vocabulary and dataset dictionaries, or ``""``.
    """
    if transfo_xl_dataset_file:
        # Convert a pre-processed corpus (see original TensorFlow repo)
        # NOTE(security): unpickling executes arbitrary code; only convert
        # dataset files obtained from a trusted source.
        with open(transfo_xl_dataset_file, "rb") as fp:
            if sys.version_info[0] == 2:
                # Python 2's (c)Pickle.load() takes no `encoding` keyword.
                corpus = pickle.load(fp)
            else:
                # latin1 lets Python 3 read pickles written by Python 2.
                corpus = pickle.load(fp, encoding="latin1")
        # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
        pytorch_vocab_dump_path = os.path.join(pytorch_dump_folder_path, VOCAB_NAME)
        print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
        corpus_vocab_dict = corpus.vocab.__dict__
        torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)

        # The vocabulary was saved separately above, so drop it from the
        # corpus attributes before saving the rest.
        corpus_dict_no_vocab = corpus.__dict__
        corpus_dict_no_vocab.pop('vocab', None)
        pytorch_dataset_dump_path = os.path.join(pytorch_dump_folder_path, CORPUS_NAME)
        print("Save dataset to {}".format(pytorch_dataset_dump_path))
        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)

    if tf_checkpoint_path:
        # Convert a pre-trained TensorFlow model
        config_path = os.path.abspath(transfo_xl_config_file)
        tf_path = os.path.abspath(tf_checkpoint_path)

        print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path))
        # Initialise PyTorch model
        if transfo_xl_config_file == "":
            config = TransfoXLConfig()
        else:
            config = TransfoXLConfig(transfo_xl_config_file)
        print("Building PyTorch model from configuration: {}".format(str(config)))
        model = TransfoXLLMHeadModel(config)

        model = load_tf_weights_in_transfo_xl(model, config, tf_path)
        # Save pytorch-model
        pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME)
        pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME)
        print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path)))
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path)))
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
if __name__ == "__main__":
    # Command-line entry point: only the output folder is required; the
    # checkpoint and the dataset are each converted only when provided.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_dump_folder_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the folder to store the PyTorch model or dataset/vocab.")
    parser.add_argument("--tf_checkpoint_path",
                        default="",
                        type=str,
                        help="An optional path to a TensorFlow checkpoint path to be converted.")
    parser.add_argument("--transfo_xl_config_file",
                        default="",
                        type=str,
                        help="An optional config json file corresponding to the pre-trained Transformer-XL model. \n"
                             "This specifies the model architecture.")
    parser.add_argument("--transfo_xl_dataset_file",
                        default="",
                        type=str,
                        help="An optional dataset file to be converted in a vocabulary.")
    args = parser.parse_args()
    convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                             args.transfo_xl_config_file,
                                             args.pytorch_dump_folder_path,
                                             args.transfo_xl_dataset_file)
...@@ -3,31 +3,40 @@ Utilities for working with the local dataset cache. ...@@ -3,31 +3,40 @@ Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors. Copyright by the AllenNLP authors.
""" """
from __future__ import (absolute_import, division, print_function, unicode_literals)
import os import json
import logging import logging
import os
import shutil import shutil
import tempfile import tempfile
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set
from hashlib import sha256
from functools import wraps from functools import wraps
from hashlib import sha256
from tqdm import tqdm import sys
from io import open
import boto3 import boto3
from botocore.exceptions import ClientError
import requests import requests
from botocore.exceptions import ClientError
from tqdm import tqdm
logger = logging.getLogger(__name__) # pylint: disable=invalid-name try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert')) Path.home() / '.pytorch_pretrained_bert'))
except AttributeError:
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def url_to_filename(url: str, etag: str = None) -> str: def url_to_filename(url, etag=None):
""" """
Convert `url` into a hashed filename in a repeatable way. Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited If `etag` is specified, append its hash to the url's, delimited
...@@ -45,25 +54,25 @@ def url_to_filename(url: str, etag: str = None) -> str: ...@@ -45,25 +54,25 @@ def url_to_filename(url: str, etag: str = None) -> str:
return filename return filename
def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: def filename_to_url(filename, cache_dir=None):
""" """
Return the url and etag (which may be ``None``) stored for `filename`. Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path): if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir) cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename) cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path): if not os.path.exists(cache_path):
raise FileNotFoundError("file {} not found".format(cache_path)) raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + '.json' meta_path = cache_path + '.json'
if not os.path.exists(meta_path): if not os.path.exists(meta_path):
raise FileNotFoundError("file {} not found".format(meta_path)) raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path) as meta_file: with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file) metadata = json.load(meta_file)
url = metadata['url'] url = metadata['url']
etag = metadata['etag'] etag = metadata['etag']
...@@ -71,7 +80,7 @@ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[ ...@@ -71,7 +80,7 @@ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[
return url, etag return url, etag
def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: def cached_path(url_or_filename, cache_dir=None):
""" """
Given something that might be a URL (or might be a local path), Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and determine which. If it's a URL, download the file and cache it, and
...@@ -80,9 +89,9 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = ...@@ -80,9 +89,9 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(url_or_filename, Path): if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename) url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path): if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir) cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename) parsed = urlparse(url_or_filename)
...@@ -95,13 +104,13 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = ...@@ -95,13 +104,13 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
return url_or_filename return url_or_filename
elif parsed.scheme == '': elif parsed.scheme == '':
# File, but it doesn't exist. # File, but it doesn't exist.
raise FileNotFoundError("file {} not found".format(url_or_filename)) raise EnvironmentError("file {} not found".format(url_or_filename))
else: else:
# Something unknown # Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url: str) -> Tuple[str, str]: def split_s3_path(url):
"""Split a full s3 path into the bucket name and path.""" """Split a full s3 path into the bucket name and path."""
parsed = urlparse(url) parsed = urlparse(url)
if not parsed.netloc or not parsed.path: if not parsed.netloc or not parsed.path:
...@@ -114,19 +123,19 @@ def split_s3_path(url: str) -> Tuple[str, str]: ...@@ -114,19 +123,19 @@ def split_s3_path(url: str) -> Tuple[str, str]:
return bucket_name, s3_path return bucket_name, s3_path
def s3_request(func: Callable): def s3_request(func):
""" """
Wrapper function for s3 requests in order to create more helpful error Wrapper function for s3 requests in order to create more helpful error
messages. messages.
""" """
@wraps(func) @wraps(func)
def wrapper(url: str, *args, **kwargs): def wrapper(url, *args, **kwargs):
try: try:
return func(url, *args, **kwargs) return func(url, *args, **kwargs)
except ClientError as exc: except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404: if int(exc.response["Error"]["Code"]) == 404:
raise FileNotFoundError("file {} not found".format(url)) raise EnvironmentError("file {} not found".format(url))
else: else:
raise raise
...@@ -134,7 +143,7 @@ def s3_request(func: Callable): ...@@ -134,7 +143,7 @@ def s3_request(func: Callable):
@s3_request @s3_request
def s3_etag(url: str) -> Optional[str]: def s3_etag(url):
"""Check ETag on S3 object.""" """Check ETag on S3 object."""
s3_resource = boto3.resource("s3") s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url) bucket_name, s3_path = split_s3_path(url)
...@@ -143,14 +152,14 @@ def s3_etag(url: str) -> Optional[str]: ...@@ -143,14 +152,14 @@ def s3_etag(url: str) -> Optional[str]:
@s3_request @s3_request
def s3_get(url: str, temp_file: IO) -> None: def s3_get(url, temp_file):
"""Pull a file directly from S3.""" """Pull a file directly from S3."""
s3_resource = boto3.resource("s3") s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url) bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url: str, temp_file: IO) -> None: def http_get(url, temp_file):
req = requests.get(url, stream=True) req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length') content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None total = int(content_length) if content_length is not None else None
...@@ -162,17 +171,18 @@ def http_get(url: str, temp_file: IO) -> None: ...@@ -162,17 +171,18 @@ def http_get(url: str, temp_file: IO) -> None:
progress.close() progress.close()
def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: def get_from_cache(url, cache_dir=None):
""" """
Given a URL, look for the corresponding dataset in the local cache. Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file. If it's not there, download it. Then return the path to the cached file.
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path): if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
cache_dir = str(cache_dir) cache_dir = str(cache_dir)
os.makedirs(cache_dir, exist_ok=True) if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
# Get eTag to add to filename, if it exists. # Get eTag to add to filename, if it exists.
if url.startswith("s3://"): if url.startswith("s3://"):
...@@ -213,7 +223,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: ...@@ -213,7 +223,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
logger.info("creating metadata file for %s", cache_path) logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag} meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json' meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file: with open(meta_path, 'w', encoding="utf-8") as meta_file:
json.dump(meta, meta_file) json.dump(meta, meta_file)
logger.info("removing temp file %s", temp_file.name) logger.info("removing temp file %s", temp_file.name)
...@@ -221,7 +231,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: ...@@ -221,7 +231,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
return cache_path return cache_path
def read_set_from_file(filename: str) -> Set[str]: def read_set_from_file(filename):
''' '''
Extract a de-duped collection (set) of text from a file. Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line. Expected file format is one item per line.
...@@ -233,7 +243,7 @@ def read_set_from_file(filename: str) -> Set[str]: ...@@ -233,7 +243,7 @@ def read_set_from_file(filename: str) -> Set[str]:
return collection return collection
def get_file_extension(path: str, dot=True, lower: bool = True): def get_file_extension(path, dot=True, lower=True):
ext = os.path.splitext(path)[1] ext = os.path.splitext(path)[1]
ext = ext if dot else ext[1:] ext = ext if dot else ext[1:]
return ext.lower() if lower else ext return ext.lower() if lower else ext
...@@ -15,18 +15,18 @@ ...@@ -15,18 +15,18 @@
# limitations under the License. # limitations under the License.
"""PyTorch BERT model.""" """PyTorch BERT model."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function, unicode_literals
from __future__ import division
from __future__ import print_function
import os
import copy import copy
import json import json
import math
import logging import logging
import math
import os
import shutil
import tarfile import tarfile
import tempfile import tempfile
import shutil import sys
from io import open
import torch import torch
from torch import nn from torch import nn
...@@ -47,6 +47,68 @@ PRETRAINED_MODEL_ARCHIVE_MAP = { ...@@ -47,6 +47,68 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
} }
CONFIG_NAME = 'bert_config.json' CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin' WEIGHTS_NAME = 'pytorch_model.bin'
TF_WEIGHTS_NAME = 'model.ckpt'
def load_tf_weights_in_bert(model, tf_checkpoint_path):
    """Load the variables of a TensorFlow checkpoint into a PyTorch model.

    Every variable stored in the checkpoint is read, its ``/``-separated name
    is walked attribute by attribute starting from ``model``, and the tensor
    reached that way gets its ``.data`` replaced by the checkpoint array.

    Args:
        model: object whose attribute tree mirrors the checkpoint variable
            names (after the renaming rules applied below).
        tf_checkpoint_path: path to the TensorFlow checkpoint; it is made
            absolute before being handed to TensorFlow.

    Returns:
        The ``model`` object that was passed in, with its weights overwritten.

    Raises:
        ImportError: if ``numpy`` or ``tensorflow`` cannot be imported.
        AssertionError: if a checkpoint array and its target tensor differ in
            shape; ``args`` holds ``(target_shape, array_shape)``.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
              "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    print("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        print("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    # ``re.fullmatch`` only exists on Python >= 3.4; an anchored ``match`` is
    # equivalent and also works on Python 2.
    indexed_scope = re.compile(r'[A-Za-z]+_\d+\Z')

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
        # which are not required for using pretrained model
        if any(n in ["adam_v", "adam_m"] for n in name):
            print("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            # "layer_3" -> ['layer', '3', '']: attribute name followed by an index.
            if indexed_scope.match(m_name):
                parts = re.split(r'_(\d+)', m_name)
            else:
                parts = [m_name]
            # Translate TensorFlow variable names to the PyTorch attribute names.
            if parts[0] == 'kernel' or parts[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif parts[0] == 'output_bias' or parts[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif parts[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            else:
                pointer = getattr(pointer, parts[0])
            if len(parts) >= 2:
                num = int(parts[1])
                pointer = pointer[num]
        # The final adjustments depend on the last path component only.
        last_name = name[-1]
        if last_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif last_name == 'kernel':
            array = np.transpose(array)
        # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
        if pointer.shape != array.shape:
            raise AssertionError(pointer.shape, array.shape)
        print("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
def gelu(x): def gelu(x):
"""Implementation of the gelu activation function. """Implementation of the gelu activation function.
...@@ -102,7 +164,8 @@ class BertConfig(object): ...@@ -102,7 +164,8 @@ class BertConfig(object):
initializer_range: The sttdev of the truncated_normal_initializer for initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices. initializing all weight matrices.
""" """
if isinstance(vocab_size_or_config_json_file, str): if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read()) json_config = json.loads(reader.read())
for key, value in json_config.items(): for key, value in json_config.items():
...@@ -281,8 +344,10 @@ class BertIntermediate(nn.Module): ...@@ -281,8 +344,10 @@ class BertIntermediate(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BertIntermediate, self).__init__() super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
if isinstance(config.hidden_act, str) else config.hidden_act self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
...@@ -354,8 +419,10 @@ class BertPredictionHeadTransform(nn.Module): ...@@ -354,8 +419,10 @@ class BertPredictionHeadTransform(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__() super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
if isinstance(config.hidden_act, str) else config.hidden_act self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -416,12 +483,12 @@ class BertPreTrainingHeads(nn.Module): ...@@ -416,12 +483,12 @@ class BertPreTrainingHeads(nn.Module):
return prediction_scores, seq_relationship_score return prediction_scores, seq_relationship_score
class PreTrainedBertModel(nn.Module): class BertPreTrainedModel(nn.Module):
""" An abstract class to handle weights initialization and """ An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models. a simple interface for dowloading and loading pretrained models.
""" """
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super(PreTrainedBertModel, self).__init__() super(BertPreTrainedModel, self).__init__()
if not isinstance(config, BertConfig): if not isinstance(config, BertConfig):
raise ValueError( raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. " "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
...@@ -445,13 +512,14 @@ class PreTrainedBertModel(nn.Module): ...@@ -445,13 +512,14 @@ class PreTrainedBertModel(nn.Module):
module.bias.data.zero_() module.bias.data.zero_()
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
from_tf=False, *inputs, **kwargs):
""" """
Instantiate a PreTrainedBertModel from a pre-trained model file or a pytorch state dict. Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed. Download and cache the pre-trained model file if needed.
Params: Params:
pretrained_model_name: either: pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of: - a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased` . `bert-base-uncased`
. `bert-large-uncased` . `bert-large-uncased`
...@@ -463,24 +531,28 @@ class PreTrainedBertModel(nn.Module): ...@@ -463,24 +531,28 @@ class PreTrainedBertModel(nn.Module):
- a path or url to a pretrained model archive containing: - a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model . `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.ckpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached. cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of Google pre-trained models
*inputs, **kwargs: additional input for the specific Bert class *inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification) (ex: num_labels for BertForSequenceClassification)
""" """
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP: if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name] archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
else: else:
archive_file = pretrained_model_name archive_file = pretrained_model_name_or_path
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError: except EnvironmentError:
logger.error( logger.error(
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file " "We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format( "associated to this path or url.".format(
pretrained_model_name, pretrained_model_name_or_path,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()), ', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
archive_file)) archive_file))
return None return None
...@@ -490,7 +562,7 @@ class PreTrainedBertModel(nn.Module): ...@@ -490,7 +562,7 @@ class PreTrainedBertModel(nn.Module):
logger.info("loading archive file {} from cache at {}".format( logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file)) archive_file, resolved_archive_file))
tempdir = None tempdir = None
if os.path.isdir(resolved_archive_file): if os.path.isdir(resolved_archive_file) or from_tf:
serialization_dir = resolved_archive_file serialization_dir = resolved_archive_file
else: else:
# Extract archive to temp dir # Extract archive to temp dir
...@@ -506,10 +578,17 @@ class PreTrainedBertModel(nn.Module): ...@@ -506,10 +578,17 @@ class PreTrainedBertModel(nn.Module):
logger.info("Model config {}".format(config)) logger.info("Model config {}".format(config))
# Instantiate model. # Instantiate model.
model = cls(config, *inputs, **kwargs) model = cls(config, *inputs, **kwargs)
if state_dict is None: if state_dict is None and not from_tf:
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path) state_dict = torch.load(weights_path, map_location='cpu' if not torch.cuda.is_available() else None)
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
if from_tf:
# Directly load from a TensorFlow checkpoint
weights_path = os.path.join(serialization_dir, TF_WEIGHTS_NAME)
return load_tf_weights_in_bert(model, weights_path)
# Load from a PyTorch state_dict
old_keys = [] old_keys = []
new_keys = [] new_keys = []
for key in state_dict.keys(): for key in state_dict.keys():
...@@ -540,20 +619,23 @@ class PreTrainedBertModel(nn.Module): ...@@ -540,20 +619,23 @@ class PreTrainedBertModel(nn.Module):
for name, child in module._modules.items(): for name, child in module._modules.items():
if child is not None: if child is not None:
load(child, prefix + name + '.') load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'bert') else 'bert.') start_prefix = ''
if not hasattr(model, 'bert') and any(s.startswith('bert.') for s in state_dict.keys()):
start_prefix = 'bert.'
load(model, prefix=start_prefix)
if len(missing_keys) > 0: if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format( logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys)) model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format( logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys)) model.__class__.__name__, unexpected_keys))
if tempdir: if len(error_msgs) > 0:
# Clean up temp dir raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
shutil.rmtree(tempdir) model.__class__.__name__, "\n\t".join(error_msgs)))
return model return model
class BertModel(PreTrainedBertModel): class BertModel(BertPreTrainedModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer"). """BERT model ("Bidirectional Embedding Representations from a Transformer").
Params: Params:
...@@ -581,7 +663,7 @@ class BertModel(PreTrainedBertModel): ...@@ -581,7 +663,7 @@ class BertModel(PreTrainedBertModel):
to the last attention block of shape [batch_size, sequence_length, hidden_size], to the last attention block of shape [batch_size, sequence_length, hidden_size],
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated to the first character of the classifier pretrained on top of the hidden state associated to the first character of the
input (`CLF`) to train on the Next-Sentence task (see BERT's paper). input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
Example usage: Example usage:
```python ```python
...@@ -636,7 +718,7 @@ class BertModel(PreTrainedBertModel): ...@@ -636,7 +718,7 @@ class BertModel(PreTrainedBertModel):
return encoded_layers, pooled_output return encoded_layers, pooled_output
class BertForPreTraining(PreTrainedBertModel): class BertForPreTraining(BertPreTrainedModel):
"""BERT model with pre-training heads. """BERT model with pre-training heads.
This module comprises the BERT model followed by the two pre-training heads: This module comprises the BERT model followed by the two pre-training heads:
- the masked language modeling head, and - the masked language modeling head, and
...@@ -656,10 +738,10 @@ class BertForPreTraining(PreTrainedBertModel): ...@@ -656,10 +738,10 @@ class BertForPreTraining(PreTrainedBertModel):
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences. a batch has varying length sentences.
`masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length] `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked), the loss
is only computed for the labels set in [0, ..., vocab_size] is only computed for the labels set in [0, ..., vocab_size]
`next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size] `next_sentence_label`: optional next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1]. with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence. 0 => next sentence is the continuation, 1 => next sentence is a random sentence.
...@@ -707,7 +789,7 @@ class BertForPreTraining(PreTrainedBertModel): ...@@ -707,7 +789,7 @@ class BertForPreTraining(PreTrainedBertModel):
return prediction_scores, seq_relationship_score return prediction_scores, seq_relationship_score
class BertForMaskedLM(PreTrainedBertModel): class BertForMaskedLM(BertPreTrainedModel):
"""BERT model with the masked language modeling head. """BERT model with the masked language modeling head.
This module comprises the BERT model followed by the masked language modeling head. This module comprises the BERT model followed by the masked language modeling head.
...@@ -768,7 +850,7 @@ class BertForMaskedLM(PreTrainedBertModel): ...@@ -768,7 +850,7 @@ class BertForMaskedLM(PreTrainedBertModel):
return prediction_scores return prediction_scores
class BertForNextSentencePrediction(PreTrainedBertModel): class BertForNextSentencePrediction(BertPreTrainedModel):
"""BERT model with next sentence prediction head. """BERT model with next sentence prediction head.
This module comprises the BERT model followed by the next sentence classification head. This module comprises the BERT model followed by the next sentence classification head.
...@@ -830,7 +912,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel): ...@@ -830,7 +912,7 @@ class BertForNextSentencePrediction(PreTrainedBertModel):
return seq_relationship_score return seq_relationship_score
class BertForSequenceClassification(PreTrainedBertModel): class BertForSequenceClassification(BertPreTrainedModel):
"""BERT model for classification. """BERT model for classification.
This module is composed of the BERT model with a linear layer on top of This module is composed of the BERT model with a linear layer on top of
the pooled output. the pooled output.
...@@ -875,7 +957,7 @@ class BertForSequenceClassification(PreTrainedBertModel): ...@@ -875,7 +957,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
logits = model(input_ids, token_type_ids, input_mask) logits = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config, num_labels=2): def __init__(self, config, num_labels):
super(BertForSequenceClassification, self).__init__(config) super(BertForSequenceClassification, self).__init__(config)
self.num_labels = num_labels self.num_labels = num_labels
self.bert = BertModel(config) self.bert = BertModel(config)
...@@ -896,7 +978,7 @@ class BertForSequenceClassification(PreTrainedBertModel): ...@@ -896,7 +978,7 @@ class BertForSequenceClassification(PreTrainedBertModel):
return logits return logits
class BertForMultipleChoice(PreTrainedBertModel): class BertForMultipleChoice(BertPreTrainedModel):
"""BERT model for multiple choice tasks. """BERT model for multiple choice tasks.
This module is composed of the BERT model with a linear layer on top of This module is composed of the BERT model with a linear layer on top of
the pooled output. the pooled output.
...@@ -940,7 +1022,7 @@ class BertForMultipleChoice(PreTrainedBertModel): ...@@ -940,7 +1022,7 @@ class BertForMultipleChoice(PreTrainedBertModel):
logits = model(input_ids, token_type_ids, input_mask) logits = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config, num_choices=2): def __init__(self, config, num_choices):
super(BertForMultipleChoice, self).__init__(config) super(BertForMultipleChoice, self).__init__(config)
self.num_choices = num_choices self.num_choices = num_choices
self.bert = BertModel(config) self.bert = BertModel(config)
...@@ -965,7 +1047,7 @@ class BertForMultipleChoice(PreTrainedBertModel): ...@@ -965,7 +1047,7 @@ class BertForMultipleChoice(PreTrainedBertModel):
return reshaped_logits return reshaped_logits
class BertForTokenClassification(PreTrainedBertModel): class BertForTokenClassification(BertPreTrainedModel):
"""BERT model for token-level classification. """BERT model for token-level classification.
This module is composed of the BERT model with a linear layer on top of This module is composed of the BERT model with a linear layer on top of
the full hidden state of the last layer. the full hidden state of the last layer.
...@@ -1010,7 +1092,7 @@ class BertForTokenClassification(PreTrainedBertModel): ...@@ -1010,7 +1092,7 @@ class BertForTokenClassification(PreTrainedBertModel):
logits = model(input_ids, token_type_ids, input_mask) logits = model(input_ids, token_type_ids, input_mask)
``` ```
""" """
def __init__(self, config, num_labels=2): def __init__(self, config, num_labels):
super(BertForTokenClassification, self).__init__(config) super(BertForTokenClassification, self).__init__(config)
self.num_labels = num_labels self.num_labels = num_labels
self.bert = BertModel(config) self.bert = BertModel(config)
...@@ -1038,7 +1120,7 @@ class BertForTokenClassification(PreTrainedBertModel): ...@@ -1038,7 +1120,7 @@ class BertForTokenClassification(PreTrainedBertModel):
return logits return logits
class BertForQuestionAnswering(PreTrainedBertModel): class BertForQuestionAnswering(BertPreTrainedModel):
"""BERT model for Question Answering (span extraction). """BERT model for Question Answering (span extraction).
This module is composed of the BERT model with a linear layer on top of This module is composed of the BERT model with a linear layer on top of
the sequence output that computes start_logits and end_logits the sequence output that computes start_logits and end_logits
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for OpenAI GPT model."""
import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
def warmup_cosine(x, warmup=0.002):
    """Learning-rate multiplier: linear warm-up followed by a cosine curve.

    Args:
        x: training progress, either a Python number or a tensor.
        warmup: progress value up to which the multiplier is ``x / warmup``.
            With ``warmup == 0`` there is no warm-up phase.

    Returns:
        ``x / warmup`` while ``x <= warmup``, otherwise
        ``0.5 * (1 + cos(pi * x))``.
    """
    # Branch lazily: evaluating both terms divided by zero for warmup == 0
    # and called torch.cos on plain floats, which raises TypeError.
    if warmup != 0 and x <= warmup:
        return x / warmup
    if torch.is_tensor(x):
        return 0.5 * (1 + torch.cos(math.pi * x))
    return 0.5 * (1 + math.cos(math.pi * x))
def warmup_constant(x, warmup=0.002):
    """Learning-rate multiplier: linear warm-up, then constant.

    Args:
        x: training progress.
        warmup: progress value up to which the multiplier is ``x / warmup``.
            With ``warmup == 0`` there is no warm-up phase.

    Returns:
        ``x / warmup`` while ``x <= warmup``, otherwise ``1.0``.
    """
    # Only divide inside the warm-up branch so that warmup == 0 cannot
    # raise ZeroDivisionError.
    if warmup != 0 and x <= warmup:
        return x / warmup
    return 1.0
def warmup_linear(x, warmup=0.002):
    """Learning-rate multiplier: linear warm-up scaled by a linear decay.

    Args:
        x: training progress.
        warmup: progress value up to which the warm-up factor ``x / warmup``
            applies. With ``warmup == 0`` there is no warm-up phase.

    Returns:
        ``(x / warmup) * (1 - x)`` while ``x <= warmup``, otherwise ``1 - x``.
    """
    # Only divide inside the warm-up branch so that warmup == 0 cannot
    # raise ZeroDivisionError.
    if warmup != 0 and x <= warmup:
        return (x / warmup) * (1 - x)
    return 1.0 * (1 - x)
# Lookup table from schedule name to the function computing the
# learning-rate multiplier for a given training progress.
SCHEDULES = dict(
    warmup_cosine=warmup_cosine,
    warmup_constant=warmup_constant,
    warmup_linear=warmup_linear,
)
class OpenAIAdam(Optimizer):
    """Implements Open AI version of Adam algorithm with weight decay fix.

    Args:
        params: iterable of parameters or parameter groups to optimize.
        lr: base learning rate; must be >= 0.0 when given.
        schedule: key of ``SCHEDULES`` selecting the learning-rate schedule.
        warmup: warm-up value handed to the schedule, in ``[0.0, 1.0[`` or -1.
        t_total: total number of steps used to compute the training progress;
            -1 disables the schedule and the base ``lr`` is used as is.
        b1: decay rate of the running average of the gradient, in ``[0, 1[``.
        b2: decay rate of the running average of the squared gradient,
            in ``[0, 1[``.
        e: non-negative value added to the denominator of the update.
        weight_decay: weight-decay factor, applied after the Adam update.
        vector_l2: if true, weight decay is also applied to parameters with
            at most one dimension.
        max_grad_norm: if > 0, each parameter's gradient is clipped to this
            norm before the update.
        **kwargs: accepted and ignored.

    Raises:
        ValueError: if ``lr``, ``schedule``, ``warmup``, ``b1``, ``b2`` or
            ``e`` is outside its accepted range.
    """

    def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
                 b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
                 vector_l2=False, max_grad_norm=-1, **kwargs):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {}".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {}".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {}".format(e))
        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
                        max_grad_norm=max_grad_norm)
        super(OpenAIAdam, self).__init__(params, defaults)

    @staticmethod
    def _scheduled_lr(group, step):
        """Return the learning rate of ``group`` after ``step`` updates."""
        if group['t_total'] == -1:
            return group['lr']
        schedule_fct = SCHEDULES[group['schedule']]
        # float() forces true division: on Python 2 ``int / int`` floors, which
        # would keep the progress (and therefore the learning rate) at 0.
        progress = float(step) / group['t_total']
        return group['lr'] * schedule_fct(progress, group['warmup'])

    def get_lr(self):
        """Return the current scheduled learning rate of every parameter.

        Returns:
            A list with one entry per parameter, or ``[0]`` as soon as a
            parameter without optimizer state (never stepped) is met.
        """
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                lr.append(self._scheduled_lr(group, state['step']))
        return lr

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['b1'], group['b2']

                state['step'] += 1

                # Add grad clipping (in place, so ``grad`` sees the clipped values)
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient.
                # Keyword scalars replace the deprecated positional-scalar overloads.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                denom = exp_avg_sq.sqrt().add_(group['e'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                lr_scheduled = self._scheduled_lr(group, state['step'])

                step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Add weight decay at the end (fixed version)
                if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
                    p.data.add_(p.data, alpha=-lr_scheduled * group['weight_decay'])

        return loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment