Commit eed51c5b authored by thomwolf

add OpenAI GPT

parent 793dcd23
__version__ = "0.4.0"
__version__ = "0.5.0"
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .tokenization_openai import OpenAIGPTTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForMultipleChoice,
BertForTokenClassification, BertForQuestionAnswering)
from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTDoubleHeadsModel
from .optimization import BertAdam
from .optimization_openai import OpenAIAdam
from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE
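With this bump, the OpenAI GPT entry points sit next to the BERT ones at the package root. A quick import sanity check (a sketch; it only exercises the new exports, nothing is downloaded):

    from pytorch_pretrained_bert import (OpenAIGPTTokenizer, OpenAIGPTConfig,
                                         OpenAIGPTModel, OpenAIGPTDoubleHeadsModel,
                                         OpenAIAdam)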
 # coding: utf8
 def main():
     import sys
-    try:
-        from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
-    except ModuleNotFoundError:
-        print("pytorch_pretrained_bert can only be used from the command line to convert TensorFlow models to PyTorch. "
-              "In that case, it requires TensorFlow to be installed. Please see "
-              "https://www.tensorflow.org/install/ for installation instructions.")
-        raise
-    if len(sys.argv) != 5:
-        # pylint: disable=line-too-long
-        print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
+    if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
+        "convert_tf_checkpoint_to_pytorch",
+        "convert_openai_checkpoint"
+    ]:
+        print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT` \n or `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
     else:
-        PYTORCH_DUMP_OUTPUT = sys.argv.pop()
-        TF_CONFIG = sys.argv.pop()
-        TF_CHECKPOINT = sys.argv.pop()
-        convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
+        if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
+            try:
+                from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
+            except ModuleNotFoundError:
+                print("pytorch_pretrained_bert can only be used from the command line to convert TensorFlow models to PyTorch. "
+                      "In that case, it requires TensorFlow to be installed. Please see "
+                      "https://www.tensorflow.org/install/ for installation instructions.")
+                raise
+            if len(sys.argv) != 5:
+                # pylint: disable=line-too-long
+                print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
+            else:
+                PYTORCH_DUMP_OUTPUT = sys.argv.pop()
+                TF_CONFIG = sys.argv.pop()
+                TF_CHECKPOINT = sys.argv.pop()
+                convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
+        else:
+            from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
+            OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
+            PYTORCH_DUMP_OUTPUT = sys.argv[3]
+            if len(sys.argv) == 5:
+                OPENAI_GPT_CONFIG = sys.argv[4]
+            else:
+                OPENAI_GPT_CONFIG = ""
+            convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
+                                                 OPENAI_GPT_CONFIG,
+                                                 PYTORCH_DUMP_OUTPUT)

 if __name__ == '__main__':
     main()
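For orientation, the new convert_openai_checkpoint sub-command boils down to the call below; a sketch only, the two paths are placeholders:

    from pytorch_pretrained_bert.convert_openai_checkpoint_to_pytorch import \
        convert_openai_checkpoint_to_pytorch

    convert_openai_checkpoint_to_pytorch(
        "/path/to/openai/model",   # folder with params_{0..9}.npy and the JSON name/shape files
        "",                        # empty string falls back to the default OpenAIGPTConfig()
        "/path/to/pytorch_dump")   # folder that receives the converted weights and config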
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Convert BERT checkpoint."""
+"""Convert OpenAI GPT checkpoint."""

 from __future__ import absolute_import
 from __future__ import division
@@ -20,45 +20,53 @@ from __future__ import print_function

+import os
 import re
+import json
 import argparse
 import tensorflow as tf
 import torch
 import numpy as np

-from .modeling import BertConfig, BertForPreTraining
+from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME

-def convert_openai_checkpoint_to_pytorch(open_checkpoint_folder_path, openai_config_file, pytorch_dump_path):
-def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',
-                                 path_names='./'):
+def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
     # Load weights from TF model
     print("Loading weights...")
-    names = json.load(open(path_names + 'parameters_names.json'))
-    shapes = json.load(open(path + 'params_shapes.json'))
+    names = json.load(open(openai_checkpoint_folder_path + '/parameters_names.json', "r", encoding='utf-8'))
+    shapes = json.load(open(openai_checkpoint_folder_path + '/params_shapes.json', "r", encoding='utf-8'))
     offsets = np.cumsum([np.prod(shape) for shape in shapes])
-    init_params = [np.load(path + 'params_{}.npy'.format(n)) for n in range(10)]
+    init_params = [np.load(openai_checkpoint_folder_path + '/params_{}.npy'.format(n)) for n in range(10)]
     init_params = np.split(np.concatenate(init_params, 0), offsets)[:-1]
     init_params = [param.reshape(shape) for param, shape in zip(init_params, shapes)]
-    if n_ctx > 0:
-        init_params[0] = init_params[0][:n_ctx]
-    if n_special > 0:
-        init_params[0] = np.concatenate(
-            [init_params[1],
-             (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
-             init_params[0]
-             ], 0)
-    else:
-        init_params[0] = np.concatenate(
-            [init_params[1],
-             init_params[0]
-             ], 0)
+    # if n_ctx > 0:
+    #     init_params[0] = init_params[0][:n_ctx]
+    # if n_special > 0:
+    #     init_params[0] = np.concatenate(
+    #         [init_params[1],
+    #          (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
+    #          init_params[0]
+    #          ], 0)
+    # else:
+    #     init_params[0] = np.concatenate(
+    #         [init_params[1],
+    #          init_params[0]
+    #          ], 0)
+    # del init_params[1]
+    # if n_transfer == -1:
+    #     n_transfer = 0
+    # else:
+    #     n_transfer = 1 + n_transfer * 12
+    init_params[0] = np.concatenate([init_params[1], init_params[0]], 0)
     del init_params[1]
-    if n_transfer == -1:
-        n_transfer = 0
-    else:
-        n_transfer = 1 + n_transfer * 12
     init_params = [arr.squeeze() for arr in init_params]

+    # Construct model
+    if openai_config_file == "":
+        config = OpenAIGPTConfig()
+    else:
+        config = OpenAIGPTConfig(openai_config_file)
+    model = OpenAIGPTModel(config)
     try:
         assert model.embed.weight.shape == init_params[0].shape
     except AssertionError as e:
@@ -66,8 +74,10 @@ def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',
         raise
     model.embed.weight.data = torch.from_numpy(init_params[0])
+    names.pop(0)
+    init_params.pop(0)

-    for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
+    for name, array in zip(names, init_params):  # names[1:n_transfer], init_params[1:n_transfer]):
         name = name[6:]  # skip "model/"
         assert name[-2:] == ":0"
         name = name[:-2]
@@ -78,64 +88,22 @@ def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',
                 l = re.split(r'(\d+)', m_name)
             else:
                 l = [m_name]
-            pointer = getattr(pointer, l[0])
-            if len(l) >= 2:
-                num = int(l[1])
-                pointer = pointer[num]
-        try:
-            assert pointer.shape == ip.shape
-        except AssertionError as e:
-            e.args += (pointer.shape, ip.shape)
-            raise
-        pointer.data = torch.from_numpy(ip)
-def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
-    config_path = os.path.abspath(bert_config_file)
-    tf_path = os.path.abspath(tf_checkpoint_path)
-    print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
-    # Load weights from TF model
-    init_vars = tf.train.list_variables(tf_path)
-    names = []
-    arrays = []
-    for name, shape in init_vars:
-        print("Loading TF weight {} with shape {}".format(name, shape))
-        array = tf.train.load_variable(tf_path, name)
-        names.append(name)
-        arrays.append(array)
-    # Initialise PyTorch model
-    config = BertConfig.from_json_file(bert_config_file)
-    print("Building PyTorch model from configuration: {}".format(str(config)))
-    model = BertForPreTraining(config)
-    for name, array in zip(names, arrays):
-        name = name.split('/')
-        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
-        # which are not required for using a pretrained model
-        if any(n in ["adam_v", "adam_m"] for n in name):
-            print("Skipping {}".format("/".join(name)))
-            continue
-        pointer = model
-        for m_name in name:
-            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
-                l = re.split(r'_(\d+)', m_name)
-            else:
-                l = [m_name]
-            if l[0] == 'kernel' or l[0] == 'gamma':
+            if l[0] == 'g':
                 pointer = getattr(pointer, 'weight')
-            elif l[0] == 'output_bias' or l[0] == 'beta':
+            elif l[0] == 'b':
                 pointer = getattr(pointer, 'bias')
-            elif l[0] == 'output_weights':
+            elif l[0] == 'w':
                 pointer = getattr(pointer, 'weight')
             else:
                 pointer = getattr(pointer, l[0])
             if len(l) >= 2:
                 num = int(l[1])
                 pointer = pointer[num]
-            if m_name[-11:] == '_embeddings':
-                pointer = getattr(pointer, 'weight')
-            elif m_name == 'kernel':
-                array = np.transpose(array)
-        try:
-            assert pointer.shape == array.shape
-        except AssertionError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
+        try:
+            assert pointer.shape == array.shape
+        except AssertionError as e:
@@ -145,30 +113,33 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
         pointer.data = torch.from_numpy(array)
     # Save pytorch-model
-    print("Save PyTorch model to {}".format(pytorch_dump_path))
-    torch.save(model.state_dict(), pytorch_dump_path)
+    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
+    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
+    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
+    torch.save(model.state_dict(), pytorch_weights_dump_path)
+    print("Save configuration file to {}".format(pytorch_config_dump_path))
+    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
+        f.write(config.to_json_string())

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ## Required parameters
parser.add_argument("--tf_checkpoint_path",
parser.add_argument("--openai_checkpoint_folder_path",
default = None,
type = str,
required = True,
help = "Path the TensorFlow checkpoint path.")
parser.add_argument("--bert_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
parser.add_argument("--openai_config_file",
default = "",
type = str,
help = "An optional config json file corresponding to the pre-trained OpenAI model. \n"
"This specifies the model architecture.")
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.bert_config_file,
args.pytorch_dump_path)
convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path,
args.pytorch_dump_folder_path,
args.openai_config_file)
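A converted dump can be sanity-checked by loading it back; a sketch with a placeholder folder (WEIGHTS_NAME and CONFIG_NAME are the file names the script writes above):

    import torch
    from pytorch_pretrained_bert.modeling_openai import (
        OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME)

    config = OpenAIGPTConfig('/path/to/pytorch_dump/' + CONFIG_NAME)
    model = OpenAIGPTModel(config)
    model.load_state_dict(torch.load('/path/to/pytorch_dump/' + WEIGHTS_NAME))
    model.eval()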
@@ -416,12 +416,12 @@ class BertPreTrainingHeads(nn.Module):
         return prediction_scores, seq_relationship_score


-class PreTrainedModel(nn.Module):
+class BertPreTrainedModel(nn.Module):
     """ An abstract class to handle weights initialization and
         a simple interface for downloading and loading pretrained models.
     """
     def __init__(self, config, *inputs, **kwargs):
-        super(PreTrainedModel, self).__init__()
+        super(BertPreTrainedModel, self).__init__()
         if not isinstance(config, BertConfig):
             raise ValueError(
                 "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
@@ -447,7 +447,7 @@ class PreTrainedModel(nn.Module):
     @classmethod
     def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
         """
-        Instantiate a PreTrainedModel from a pre-trained model file or a pytorch state dict.
+        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.

         Params:
@@ -547,13 +547,16 @@ class PreTrainedModel(nn.Module):
         if len(unexpected_keys) > 0:
             logger.info("Weights from pretrained model not used in {}: {}".format(
                 model.__class__.__name__, unexpected_keys))
+        if len(error_msgs) > 0:
+            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
+                model.__class__.__name__, "\n\t".join(error_msgs)))
         if tempdir:
             # Clean up temp dir
             shutil.rmtree(tempdir)
         return model


-class BertModel(PreTrainedModel):
+class BertModel(BertPreTrainedModel):
     """BERT model ("Bidirectional Encoder Representations from Transformers").

     Params:
@@ -636,7 +639,7 @@ class BertModel(PreTrainedModel):
         return encoded_layers, pooled_output


-class BertForPreTraining(PreTrainedModel):
+class BertForPreTraining(BertPreTrainedModel):
     """BERT model with pre-training heads.
     This module comprises the BERT model followed by the two pre-training heads:
         - the masked language modeling head, and
@@ -707,7 +710,7 @@ class BertForPreTraining(PreTrainedModel):
         return prediction_scores, seq_relationship_score


-class BertForMaskedLM(PreTrainedModel):
+class BertForMaskedLM(BertPreTrainedModel):
     """BERT model with the masked language modeling head.
     This module comprises the BERT model followed by the masked language modeling head.

@@ -768,7 +771,7 @@ class BertForMaskedLM(PreTrainedModel):
         return prediction_scores


-class BertForNextSentencePrediction(PreTrainedModel):
+class BertForNextSentencePrediction(BertPreTrainedModel):
     """BERT model with next sentence prediction head.
     This module comprises the BERT model followed by the next sentence classification head.

@@ -830,7 +833,7 @@ class BertForNextSentencePrediction(PreTrainedModel):
         return seq_relationship_score


-class BertForSequenceClassification(PreTrainedModel):
+class BertForSequenceClassification(BertPreTrainedModel):
     """BERT model for classification.
     This module is composed of the BERT model with a linear layer on top of
     the pooled output.
@@ -875,7 +878,7 @@ class BertForSequenceClassification(PreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2):
+    def __init__(self, config, num_labels):
         super(BertForSequenceClassification, self).__init__(config)
         self.num_labels = num_labels
         self.bert = BertModel(config)
@@ -896,7 +899,7 @@ class BertForSequenceClassification(PreTrainedModel):
         return logits
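Note the signature change just above: num_labels no longer defaults to 2, so callers must pass it explicitly (the same applies to num_choices and the token-level num_labels further down). A sketch, assuming a binary task and the stock 'bert-base-uncased' weights:

    from pytorch_pretrained_bert import BertForSequenceClassification

    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)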

-class BertForMultipleChoice(PreTrainedModel):
+class BertForMultipleChoice(BertPreTrainedModel):
     """BERT model for multiple choice tasks.
     This module is composed of the BERT model with a linear layer on top of
     the pooled output.
@@ -940,7 +943,7 @@ class BertForMultipleChoice(PreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_choices=2):
+    def __init__(self, config, num_choices):
         super(BertForMultipleChoice, self).__init__(config)
         self.num_choices = num_choices
         self.bert = BertModel(config)
@@ -965,7 +968,7 @@ class BertForMultipleChoice(PreTrainedModel):
         return reshaped_logits


-class BertForTokenClassification(PreTrainedModel):
+class BertForTokenClassification(BertPreTrainedModel):
     """BERT model for token-level classification.
     This module is composed of the BERT model with a linear layer on top of
     the full hidden state of the last layer.
@@ -1010,7 +1013,7 @@ class BertForTokenClassification(PreTrainedModel):
     logits = model(input_ids, token_type_ids, input_mask)
     ```
     """
-    def __init__(self, config, num_labels=2):
+    def __init__(self, config, num_labels):
         super(BertForTokenClassification, self).__init__(config)
         self.num_labels = num_labels
         self.bert = BertModel(config)
@@ -1031,7 +1034,7 @@ class BertForTokenClassification(PreTrainedModel):
         return logits


-class BertForQuestionAnswering(PreTrainedModel):
+class BertForQuestionAnswering(BertPreTrainedModel):
     """BERT model for Question Answering (span extraction).
     This module is composed of the BERT model with a linear layer on top of
     the sequence output that computes start_logits and end_logits
......
This diff is collapsed.
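The collapsed diff is presumably the new modeling_openai.py, which is not shown here. Going only by the names it exports to __init__.py and the way the converter uses them, a default GPT model would be built like this (a sketch, since the file itself is collapsed):

    from pytorch_pretrained_bert import OpenAIGPTConfig, OpenAIGPTModel

    config = OpenAIGPTConfig()   # default architecture, exactly what the converter does
    model = OpenAIGPTModel(config)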
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch optimization for OpenAI GPT model."""

 import math
 import torch
 from torch.optim import Optimizer
+from torch.optim.optimizer import required
 from torch.nn.utils import clip_grad_norm_

 def warmup_cosine(x, warmup=0.002):
@@ -25,26 +42,41 @@ SCHEDULES = {
 class OpenAIAdam(Optimizer):
     """Implements Open AI version of Adam algorithm with weight decay fix.
     """
-    def __init__(self, params, lr, schedule, warmup, t_total,
-                 b1=0.9, b2=0.999, e=1e-8, l2=0,
+    def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
+                 b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
                  vector_l2=False, max_grad_norm=-1, **kwargs):
-        if not 0.0 <= lr:
-            raise ValueError("Invalid learning rate: {}".format(lr))
+        if lr is not required and lr < 0.0:
+            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
         if schedule not in SCHEDULES:
             raise ValueError("Invalid schedule parameter: {}".format(schedule))
-        if not 0 <= warmup:
-            raise ValueError("Invalid warmup: {}".format(warmup))
+        if not 0.0 <= warmup < 1.0 and not warmup == -1:
+            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
         if not 0.0 <= b1 < 1.0:
             raise ValueError("Invalid b1 parameter: {}".format(b1))
         if not 0.0 <= b2 < 1.0:
             raise ValueError("Invalid b2 parameter: {}".format(b2))
-        if not 0.0 <= e:
+        if not e >= 0.0:
             raise ValueError("Invalid epsilon value: {}".format(e))
         defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
-                        b1=b1, b2=b2, e=e, l2=l2, vector_l2=vector_l2,
+                        b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
                         max_grad_norm=max_grad_norm)
         super(OpenAIAdam, self).__init__(params, defaults)
+    def get_lr(self):
+        lr = []
+        for group in self.param_groups:
+            for p in group['params']:
+                state = self.state[p]
+                if len(state) == 0:
+                    return [0]
+                if group['t_total'] != -1:
+                    schedule_fct = SCHEDULES[group['schedule']]
+                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                else:
+                    lr_scheduled = group['lr']
+                lr.append(lr_scheduled)
+        return lr
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -91,14 +123,18 @@ class OpenAIAdam(Optimizer):
                 bias_correction1 = 1 - beta1 ** state['step']
                 bias_correction2 = 1 - beta2 ** state['step']

-                schedule_fct = SCHEDULES[group['schedule']]
-                lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                if group['t_total'] != -1:
+                    schedule_fct = SCHEDULES[group['schedule']]
+                    lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                else:
+                    lr_scheduled = group['lr']

                 step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1

                 p.data.addcdiv_(-step_size, exp_avg, denom)

                 # Add weight decay at the end (fixed version)
-                if (len(p.size()) > 1 or group['vector_l2']) and group['l2'] > 0:
-                    p.data.add_(-lr_scheduled * group['l2'], p.data)
+                if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
+                    p.data.add_(-lr_scheduled * group['weight_decay'], p.data)

         return loss
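Usage sketch for the updated constructor; the numbers are illustrative and `model` is assumed to be defined by the surrounding training script:

    from pytorch_pretrained_bert import OpenAIAdam

    num_train_steps = 1000  # hypothetical total number of optimizer steps
    optimizer = OpenAIAdam(model.parameters(),
                           lr=6.25e-5,
                           schedule='warmup_linear',
                           warmup=0.002,             # warm the LR up over the first 0.2% of steps
                           t_total=num_train_steps,
                           weight_decay=0.01)
    current_lr = optimizer.get_lr()[0]  # the new get_lr() exposes the scheduled rate, e.g. for logging

With the new defaults (t_total=-1), both step() and get_lr() skip the schedule and use the constant lr, which is what the added `if group['t_total'] != -1` branches implement.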
+# coding=utf-8
+# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tokenization classes for OpenAI GPT."""

+import os
 import re
-import ftfy
 import json
-import spacy
 from tqdm import tqdm
+import logging

+from .file_utils import cached_path

+logger = logging.getLogger(__name__)

+PRETRAINED_VOCAB_ARCHIVE_MAP = {
+    'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
+}
+PRETRAINED_MERGES_ARCHIVE_MAP = {
+    'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
+}
+PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
+    'openai-gpt': 512,
+}
+VOCAB_NAME = 'vocab.json'
+MERGES_NAME = 'merges.txt'

 def get_pairs(word):
     """
@@ -32,16 +62,65 @@ def text_standardize(text):
     text = re.sub(r'[^\S\n]+', ' ', text)
     return text.strip()

-class TextEncoder(object):
+class OpenAIGPTTokenizer(object):
     """
     mostly a wrapper for a public python bpe tokenizer
     """
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
+        """
+        Instantiate an OpenAIGPTTokenizer from a pre-trained model file.
+        Download and cache the pre-trained model file if needed.
+        """
+        if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
+            vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
+            merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name]
+        else:
+            vocab_file = pretrained_model_name
+        if os.path.isdir(vocab_file):
+            merges_file = os.path.join(vocab_file, MERGES_NAME)
+            vocab_file = os.path.join(vocab_file, VOCAB_NAME)
+        # redirect to the cache, if necessary
+        try:
+            resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
+            resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
+        except FileNotFoundError:
+            logger.error(
+                "Model name '{}' was not found in model name list ({}). "
+                "We assumed '{}' was a path or url but couldn't find any file "
+                "associated to this path or url.".format(
+                    pretrained_model_name,
+                    ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
+                    vocab_file))
+            return None
+        if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
+            logger.info("loading vocabulary file {}".format(vocab_file))
+            logger.info("loading merges file {}".format(merges_file))
+        else:
+            logger.info("loading vocabulary file {} from cache at {}".format(
+                vocab_file, resolved_vocab_file))
+            logger.info("loading merges file {} from cache at {}".format(
+                merges_file, resolved_merges_file))
+        if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
+            # if we're using a pretrained model, ensure the tokenizer won't index sequences longer
+            # than the number of positional embeddings
+            max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
+            kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
+        # Instantiate tokenizer.
+        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        return tokenizer

-    def __init__(self, encoder_path, bpe_path):
+    def __init__(self, vocab_file, merges_file):
+        try:
+            import ftfy
+            import spacy
+        except ImportError:
+            raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.")
+
         self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
-        self.encoder = json.load(open(encoder_path))
+        self.encoder = json.load(open(vocab_file))
         self.decoder = {v:k for k,v in self.encoder.items()}
-        merges = open(bpe_path, encoding='utf-8').read().split('\n')[1:-1]
+        merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
         merges = [tuple(merge.split()) for merge in merges]
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}
@@ -89,7 +168,7 @@ class TextEncoder(object):
             self.cache[token] = word
             return word

-    def encode(self, texts, verbose=True):
+    def tokenize(self, texts, verbose=True):
         texts_tokens = []
         if verbose:
             for text in tqdm(texts, ncols=80, leave=False):
......
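Usage sketch for the renamed tokenizer (requires ftfy and spacy, per the import guard above); note that tokenize() takes a list of texts, unlike the BERT tokenizer which takes a single string:

    from pytorch_pretrained_bert import OpenAIGPTTokenizer

    tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')  # fetches vocab.json and merges.txt
    tokens = tokenizer.tokenize(["Welcome to OpenAI GPT in PyTorch"])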
@@ -37,8 +37,8 @@ from setuptools import find_packages, setup

 setup(
     name="pytorch_pretrained_bert",
-    version="0.4.0",
-    author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors",
+    version="0.5.0",
+    author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors, Open AI team Authors",
     author_email="thomas@huggingface.co",
     description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",
     long_description=open("README.md", "r", encoding='utf-8').read(),
@@ -55,7 +55,7 @@ setup(
               'tqdm'],
     entry_points={
       'console_scripts': [
-        "pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main"
+        "pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main",
       ]
     },
     python_requires='>=3.5.0',
......