Commit 448937c0 authored by thomwolf's avatar thomwolf
Browse files

python 2 compatibility

parent ba37ddc5
...@@ -17,26 +17,26 @@ ...@@ -17,26 +17,26 @@
Adapted from https://github.com/kimiyoung/transformer-xl. Adapted from https://github.com/kimiyoung/transformer-xl.
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
""" """
from __future__ import absolute_import, division, print_function, unicode_literals
import os import os
import sys
import functools import functools
import argparse import argparse
import logging
import time import time
import math import math
import sys
from io import open
import torch import torch
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus
def logging(s, log_path, print_=True, log_=True): logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
if print_: datefmt = '%m/%d/%Y %H:%M:%S',
print(s) level = logging.INFO)
if log_: logger = logging.getLogger(__name__)
with open(log_path, 'a+') as f_log:
f_log.write(s + '\n')
def get_logger(log_path, **kwargs):
return functools.partial(logging, log_path=log_path, **kwargs)
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model') parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
# parser.add_argument('--data', type=str, default='../data/wikitext-103', # parser.add_argument('--data', type=str, default='../data/wikitext-103',
...@@ -71,8 +71,8 @@ assert args.ext_len >= 0, 'extended context length must be non-negative' ...@@ -71,8 +71,8 @@ assert args.ext_len >= 0, 'extended context length must be non-negative'
device = torch.device("cuda" if args.cuda else "cpu") device = torch.device("cuda" if args.cuda else "cpu")
# Get logger # Get logger
logging = get_logger(os.path.join(args.work_dir, 'log.txt'), # logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
log_=not args.no_log) # log_=not args.no_log)
# Load dataset # Load dataset
corpus = TransfoXLCorpus.from_pretrained(args.model_name) corpus = TransfoXLCorpus.from_pretrained(args.model_name)
...@@ -90,7 +90,7 @@ te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len, ...@@ -90,7 +90,7 @@ te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
model = TransfoXLModel.from_pretrained(args.model_name) model = TransfoXLModel.from_pretrained(args.model_name)
model = model.to(device) model = model.to(device)
logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format( logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len)) args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
model.reset_length(args.tgt_len, args.ext_len, args.mem_len) model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
...@@ -116,7 +116,7 @@ def evaluate(eval_iter): ...@@ -116,7 +116,7 @@ def evaluate(eval_iter):
total_loss += seq_len * loss.item() total_loss += seq_len * loss.item()
total_len += seq_len total_len += seq_len
total_time = time.time() - start_time total_time = time.time() - start_time
logging('Time : {:.2f}s, {:.2f}ms/segment'.format( logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
total_time, 1000 * total_time / (idx+1))) total_time, 1000 * total_time / (idx+1)))
return total_loss / total_len return total_loss / total_len
...@@ -146,6 +146,6 @@ if valid_loss is not None: ...@@ -146,6 +146,6 @@ if valid_loss is not None:
if test_loss is not None: if test_loss is not None:
log_str += format_log(test_loss, 'test') log_str += format_log(test_loss, 'test')
logging('=' * 100) logger.info('=' * 100)
logging(log_str) logger.info(log_str)
logging('=' * 100) logger.info('=' * 100)
...@@ -15,26 +15,27 @@ ...@@ -15,26 +15,27 @@
# limitations under the License. # limitations under the License.
"""BERT finetuning runner.""" """BERT finetuning runner."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import argparse
import csv import csv
import os
import logging import logging
import argparse import os
import random import random
from tqdm import tqdm, trange import sys
from io import open
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification from pytorch_pretrained_bert.modeling import BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.tokenization import BertTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -91,10 +92,12 @@ class DataProcessor(object): ...@@ -91,10 +92,12 @@ class DataProcessor(object):
@classmethod @classmethod
def _read_tsv(cls, input_file, quotechar=None): def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file.""" """Reads a tab separated value file."""
with open(input_file, "r", encoding='utf-8') as f: with open(input_file, "rb") as f:
reader = csv.reader(f, delimiter="\t", quotechar=quotechar) reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = [] lines = []
for line in reader: for line in reader:
if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line)
lines.append(line) lines.append(line)
return lines return lines
...@@ -429,7 +432,8 @@ def main(): ...@@ -429,7 +432,8 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train: if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True) if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
task_name = args.task_name.lower() task_name = args.task_name.lower()
...@@ -451,7 +455,7 @@ def main(): ...@@ -451,7 +455,7 @@ def main():
# Prepare model # Prepare model
model = BertForSequenceClassification.from_pretrained(args.bert_model, model = BertForSequenceClassification.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank), cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
num_labels = num_labels) num_labels = num_labels)
if args.fp16: if args.fp16:
model.half() model.half()
......
...@@ -15,26 +15,23 @@ ...@@ -15,26 +15,23 @@
# limitations under the License. # limitations under the License.
"""BERT finetuning runner.""" """BERT finetuning runner."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function, unicode_literals
from __future__ import division
from __future__ import print_function
import os
import logging
import argparse import argparse
from tqdm import tqdm, trange import logging
import os
import random
from io import open
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForPreTraining from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.tokenization import BertTokenizer
from torch.utils.data import Dataset
import random
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S', datefmt='%m/%d/%Y %H:%M:%S',
...@@ -185,16 +182,16 @@ class BERTDataset(Dataset): ...@@ -185,16 +182,16 @@ class BERTDataset(Dataset):
if self.line_buffer is None: if self.line_buffer is None:
# read first non-empty line of file # read first non-empty line of file
while t1 == "" : while t1 == "" :
t1 = self.file.__next__().strip() t1 = next(self.file).strip()
t2 = self.file.__next__().strip() t2 = next(self.file).strip()
else: else:
# use t2 from previous iteration as new t1 # use t2 from previous iteration as new t1
t1 = self.line_buffer t1 = self.line_buffer
t2 = self.file.__next__().strip() t2 = next(self.file).strip()
# skip empty rows that are used for separating documents and keep track of current doc id # skip empty rows that are used for separating documents and keep track of current doc id
while t2 == "" or t1 == "": while t2 == "" or t1 == "":
t1 = self.file.__next__().strip() t1 = next(self.file).strip()
t2 = self.file.__next__().strip() t2 = next(self.file).strip()
self.current_doc = self.current_doc+1 self.current_doc = self.current_doc+1
self.line_buffer = t2 self.line_buffer = t2
...@@ -228,15 +225,15 @@ class BERTDataset(Dataset): ...@@ -228,15 +225,15 @@ class BERTDataset(Dataset):
def get_next_line(self): def get_next_line(self):
""" Gets next line of random_file and starts over when reaching end of file""" """ Gets next line of random_file and starts over when reaching end of file"""
try: try:
line = self.random_file.__next__().strip() line = next(self.random_file).strip()
#keep track of which document we are currently looking at to later avoid having the same doc as t1 #keep track of which document we are currently looking at to later avoid having the same doc as t1
if line == "": if line == "":
self.current_random_doc = self.current_random_doc + 1 self.current_random_doc = self.current_random_doc + 1
line = self.random_file.__next__().strip() line = next(self.random_file).strip()
except StopIteration: except StopIteration:
self.random_file.close() self.random_file.close()
self.random_file = open(self.corpus_path, "r", encoding=self.encoding) self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
line = self.random_file.__next__().strip() line = next(self.random_file).strip()
return line return line
...@@ -425,6 +422,7 @@ def main(): ...@@ -425,6 +422,7 @@ def main():
help="The output directory where the model checkpoints will be written.") help="The output directory where the model checkpoints will be written.")
## Other parameters ## Other parameters
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--max_seq_length", parser.add_argument("--max_seq_length",
default=128, default=128,
type=int, type=int,
...@@ -513,7 +511,8 @@ def main(): ...@@ -513,7 +511,8 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir): if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True) if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
...@@ -579,7 +578,7 @@ def main(): ...@@ -579,7 +578,7 @@ def main():
if args.local_rank == -1: if args.local_rank == -1:
train_sampler = RandomSampler(train_dataset) train_sampler = RandomSampler(train_dataset)
else: else:
#TODO: check if this works with current data generator from disk that relies on file.__next__ #TODO: check if this works with current data generator from disk that relies on next(file)
# (it doesn't return item back by index) # (it doesn't return item back by index)
train_sampler = DistributedSampler(train_dataset) train_sampler = DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
......
...@@ -15,29 +15,36 @@ ...@@ -15,29 +15,36 @@
# limitations under the License. # limitations under the License.
"""Run BERT on SQuAD.""" """Run BERT on SQuAD."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import argparse import argparse
import collections import collections
import logging
import json import json
import logging
import math import math
import os import os
import random import random
import pickle import sys
from tqdm import tqdm, trange from io import open
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
BertTokenizer,
whitespace_tokenize)
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -784,7 +791,8 @@ def main(): ...@@ -784,7 +791,8 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir): if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory () already exists and is not empty.") raise ValueError("Output directory () already exists and is not empty.")
os.makedirs(args.output_dir, exist_ok=True) if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
...@@ -798,7 +806,7 @@ def main(): ...@@ -798,7 +806,7 @@ def main():
# Prepare model # Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model, model = BertForQuestionAnswering.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank)) cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)))
if args.fp16: if args.fp16:
model.half() model.half()
......
...@@ -15,22 +15,25 @@ ...@@ -15,22 +15,25 @@
# limitations under the License. # limitations under the License.
"""BERT finetuning runner.""" """BERT finetuning runner."""
import argparse
import csv
import logging import logging
import os import os
import argparse
import random import random
from tqdm import tqdm, trange import sys
import csv from io import open
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForMultipleChoice from pytorch_pretrained_bert.modeling import BertForMultipleChoice
from pytorch_pretrained_bert.optimization import BertAdam from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.tokenization import BertTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -65,17 +68,17 @@ class SwagExample(object): ...@@ -65,17 +68,17 @@ class SwagExample(object):
def __repr__(self): def __repr__(self):
l = [ l = [
f"swag_id: {self.swag_id}", "swag_id: {}".format(self.swag_id),
f"context_sentence: {self.context_sentence}", "context_sentence: {}".format(self.context_sentence),
f"start_ending: {self.start_ending}", "start_ending: {}".format(self.start_ending),
f"ending_0: {self.endings[0]}", "ending_0: {}".format(self.endings[0]),
f"ending_1: {self.endings[1]}", "ending_1: {}".format(self.endings[1]),
f"ending_2: {self.endings[2]}", "ending_2: {}".format(self.endings[2]),
f"ending_3: {self.endings[3]}", "ending_3: {}".format(self.endings[3]),
] ]
if self.label is not None: if self.label is not None:
l.append(f"label: {self.label}") l.append("label: {}".format(self.label))
return ", ".join(l) return ", ".join(l)
...@@ -102,7 +105,11 @@ class InputFeatures(object): ...@@ -102,7 +105,11 @@ class InputFeatures(object):
def read_swag_examples(input_file, is_training): def read_swag_examples(input_file, is_training):
with open(input_file, 'r', encoding='utf-8') as f: with open(input_file, 'r', encoding='utf-8') as f:
reader = csv.reader(f) reader = csv.reader(f)
lines = list(reader) lines = []
for line in reader:
if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line)
lines.append(line)
if is_training and lines[0][-1] != 'label': if is_training and lines[0][-1] != 'label':
raise ValueError( raise ValueError(
...@@ -184,15 +191,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -184,15 +191,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
label = example.label label = example.label
if example_index < 5: if example_index < 5:
logger.info("*** Example ***") logger.info("*** Example ***")
logger.info(f"swag_id: {example.swag_id}") logger.info("swag_id: {}".format(example.swag_id))
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features): for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
logger.info(f"choice: {choice_idx}") logger.info("choice: {}".format(choice_idx))
logger.info(f"tokens: {' '.join(tokens)}") logger.info("tokens: {}".format(' '.join(tokens)))
logger.info(f"input_ids: {' '.join(map(str, input_ids))}") logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
logger.info(f"input_mask: {' '.join(map(str, input_mask))}") logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}") logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
if is_training: if is_training:
logger.info(f"label: {label}") logger.info("label: {}".format(label))
features.append( features.append(
InputFeatures( InputFeatures(
...@@ -349,7 +356,8 @@ def main(): ...@@ -349,7 +356,8 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir): if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True) if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case) tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
...@@ -362,7 +370,7 @@ def main(): ...@@ -362,7 +370,7 @@ def main():
# Prepare model # Prepare model
model = BertForMultipleChoice.from_pretrained(args.bert_model, model = BertForMultipleChoice.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank), cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
num_choices=4) num_choices=4)
if args.fp16: if args.fp16:
model.half() model.half()
......
...@@ -15,7 +15,7 @@ def main(): ...@@ -15,7 +15,7 @@ def main():
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
try: try:
import tensorflow as tf import tensorflow as tf
except ModuleNotFoundError: except ImportError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see " "In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.") "https://www.tensorflow.org/install/ for installation instructions.")
...@@ -43,7 +43,7 @@ def main(): ...@@ -43,7 +43,7 @@ def main():
else: else:
try: try:
import tensorflow as tf import tensorflow as tf
except ModuleNotFoundError: except ImportError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see " "In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.") "https://www.tensorflow.org/install/ for installation instructions.")
......
...@@ -14,14 +14,18 @@ ...@@ -14,14 +14,18 @@
# limitations under the License. # limitations under the License.
"""Convert OpenAI GPT checkpoint.""" """Convert OpenAI GPT checkpoint."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import argparse import argparse
from io import open
import torch import torch
from pytorch_pretrained_bert.modeling_openai import load_tf_weights_in_openai_gpt, OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
OpenAIGPTConfig,
OpenAIGPTModel,
load_tf_weights_in_openai_gpt)
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
# Construct model # Construct model
......
...@@ -14,25 +14,31 @@ ...@@ -14,25 +14,31 @@
# limitations under the License. # limitations under the License.
"""Convert Transformer XL checkpoint and datasets.""" """Convert Transformer XL checkpoint and datasets."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function
from __future__ import division
from __future__ import print_function
import argparse
import os import os
import sys import sys
import argparse from io import open
import pickle
import tensorflow as tf
import torch import torch
import numpy as np
from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME, load_tf_weights_in_transfo_xl import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
WEIGHTS_NAME,
TransfoXLConfig,
TransfoXLModel,
load_tf_weights_in_transfo_xl)
from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
VOCAB_NAME)
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
# We do this to be able to load the python 2 datasets pickles # We do this to be able to load the python 2 datasets pickles
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 # See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
data_utils.Vocab = data_utils.TransfoXLTokenizer data_utils.Vocab = data_utils.TransfoXLTokenizer
data_utils.Corpus = data_utils.TransfoXLCorpus data_utils.Corpus = data_utils.TransfoXLCorpus
sys.modules['data_utils'] = data_utils sys.modules['data_utils'] = data_utils
......
...@@ -3,31 +3,39 @@ Utilities for working with the local dataset cache. ...@@ -3,31 +3,39 @@ Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors. Copyright by the AllenNLP authors.
""" """
from __future__ import (absolute_import, division, print_function, unicode_literals)
import os import json
import logging import logging
import os
import shutil import shutil
import tempfile import tempfile
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set
from hashlib import sha256
from functools import wraps from functools import wraps
from hashlib import sha256
from tqdm import tqdm from io import open
import boto3 import boto3
from botocore.exceptions import ClientError
import requests import requests
from botocore.exceptions import ClientError
from tqdm import tqdm
logger = logging.getLogger(__name__) # pylint: disable=invalid-name try:
from urllib.parse import urlparse
except ImportError:
from urlparse import urlparse
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', try:
from pathlib import Path
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert')) Path.home() / '.pytorch_pretrained_bert'))
except ImportError:
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
def url_to_filename(url: str, etag: str = None) -> str: def url_to_filename(url, etag=None):
""" """
Convert `url` into a hashed filename in a repeatable way. Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited If `etag` is specified, append its hash to the url's, delimited
...@@ -45,25 +53,23 @@ def url_to_filename(url: str, etag: str = None) -> str: ...@@ -45,25 +53,23 @@ def url_to_filename(url: str, etag: str = None) -> str:
return filename return filename
def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]: def filename_to_url(filename, cache_dir=None):
""" """
Return the url and etag (which may be ``None``) stored for `filename`. Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist. Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename) cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path): if not os.path.exists(cache_path):
raise FileNotFoundError("file {} not found".format(cache_path)) raise EnvironmentError("file {} not found".format(cache_path))
meta_path = cache_path + '.json' meta_path = cache_path + '.json'
if not os.path.exists(meta_path): if not os.path.exists(meta_path):
raise FileNotFoundError("file {} not found".format(meta_path)) raise EnvironmentError("file {} not found".format(meta_path))
with open(meta_path) as meta_file: with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file) metadata = json.load(meta_file)
url = metadata['url'] url = metadata['url']
etag = metadata['etag'] etag = metadata['etag']
...@@ -71,7 +77,7 @@ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[ ...@@ -71,7 +77,7 @@ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[
return url, etag return url, etag
def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str: def cached_path(url_or_filename, cache_dir=None):
""" """
Given something that might be a URL (or might be a local path), Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and determine which. If it's a URL, download the file and cache it, and
...@@ -80,10 +86,6 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = ...@@ -80,10 +86,6 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
parsed = urlparse(url_or_filename) parsed = urlparse(url_or_filename)
...@@ -95,13 +97,13 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = ...@@ -95,13 +97,13 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
return url_or_filename return url_or_filename
elif parsed.scheme == '': elif parsed.scheme == '':
# File, but it doesn't exist. # File, but it doesn't exist.
raise FileNotFoundError("file {} not found".format(url_or_filename)) raise EnvironmentError("file {} not found".format(url_or_filename))
else: else:
# Something unknown # Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url: str) -> Tuple[str, str]: def split_s3_path(url):
"""Split a full s3 path into the bucket name and path.""" """Split a full s3 path into the bucket name and path."""
parsed = urlparse(url) parsed = urlparse(url)
if not parsed.netloc or not parsed.path: if not parsed.netloc or not parsed.path:
...@@ -114,19 +116,19 @@ def split_s3_path(url: str) -> Tuple[str, str]: ...@@ -114,19 +116,19 @@ def split_s3_path(url: str) -> Tuple[str, str]:
return bucket_name, s3_path return bucket_name, s3_path
def s3_request(func: Callable): def s3_request(func):
""" """
Wrapper function for s3 requests in order to create more helpful error Wrapper function for s3 requests in order to create more helpful error
messages. messages.
""" """
@wraps(func) @wraps(func)
def wrapper(url: str, *args, **kwargs): def wrapper(url, *args, **kwargs):
try: try:
return func(url, *args, **kwargs) return func(url, *args, **kwargs)
except ClientError as exc: except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404: if int(exc.response["Error"]["Code"]) == 404:
raise FileNotFoundError("file {} not found".format(url)) raise EnvironmentError("file {} not found".format(url))
else: else:
raise raise
...@@ -134,7 +136,7 @@ def s3_request(func: Callable): ...@@ -134,7 +136,7 @@ def s3_request(func: Callable):
@s3_request @s3_request
def s3_etag(url: str) -> Optional[str]: def s3_etag(url):
"""Check ETag on S3 object.""" """Check ETag on S3 object."""
s3_resource = boto3.resource("s3") s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url) bucket_name, s3_path = split_s3_path(url)
...@@ -143,14 +145,14 @@ def s3_etag(url: str) -> Optional[str]: ...@@ -143,14 +145,14 @@ def s3_etag(url: str) -> Optional[str]:
@s3_request @s3_request
def s3_get(url: str, temp_file: IO) -> None: def s3_get(url, temp_file):
"""Pull a file directly from S3.""" """Pull a file directly from S3."""
s3_resource = boto3.resource("s3") s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url) bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url: str, temp_file: IO) -> None: def http_get(url, temp_file):
req = requests.get(url, stream=True) req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length') content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None total = int(content_length) if content_length is not None else None
...@@ -162,17 +164,16 @@ def http_get(url: str, temp_file: IO) -> None: ...@@ -162,17 +164,16 @@ def http_get(url: str, temp_file: IO) -> None:
progress.close() progress.close()
def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: def get_from_cache(url, cache_dir=None):
""" """
Given a URL, look for the corresponding dataset in the local cache. Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file. If it's not there, download it. Then return the path to the cached file.
""" """
if cache_dir is None: if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
os.makedirs(cache_dir, exist_ok=True) if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
# Get eTag to add to filename, if it exists. # Get eTag to add to filename, if it exists.
if url.startswith("s3://"): if url.startswith("s3://"):
...@@ -213,7 +214,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: ...@@ -213,7 +214,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
logger.info("creating metadata file for %s", cache_path) logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag} meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json' meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file: with open(meta_path, 'w', encoding="utf-8") as meta_file:
json.dump(meta, meta_file) json.dump(meta, meta_file)
logger.info("removing temp file %s", temp_file.name) logger.info("removing temp file %s", temp_file.name)
...@@ -221,7 +222,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str: ...@@ -221,7 +222,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
return cache_path return cache_path
def read_set_from_file(filename: str) -> Set[str]: def read_set_from_file(filename):
''' '''
Extract a de-duped collection (set) of text from a file. Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line. Expected file format is one item per line.
...@@ -233,7 +234,7 @@ def read_set_from_file(filename: str) -> Set[str]: ...@@ -233,7 +234,7 @@ def read_set_from_file(filename: str) -> Set[str]:
return collection return collection
def get_file_extension(path: str, dot=True, lower: bool = True): def get_file_extension(path, dot=True, lower=True):
ext = os.path.splitext(path)[1] ext = os.path.splitext(path)[1]
ext = ext if dot else ext[1:] ext = ext if dot else ext[1:]
return ext.lower() if lower else ext return ext.lower() if lower else ext
...@@ -15,18 +15,18 @@ ...@@ -15,18 +15,18 @@
# limitations under the License. # limitations under the License.
"""PyTorch BERT model.""" """PyTorch BERT model."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function, unicode_literals
from __future__ import division
from __future__ import print_function
import os
import copy import copy
import json import json
import math
import logging import logging
import math
import os
import shutil
import tarfile import tarfile
import tempfile import tempfile
import shutil import sys
from io import open
import torch import torch
from torch import nn from torch import nn
...@@ -56,7 +56,7 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path): ...@@ -56,7 +56,7 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
import re import re
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
except ModuleNotFoundError: except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.") "https://www.tensorflow.org/install/ for installation instructions.")
raise raise
...@@ -164,7 +164,8 @@ class BertConfig(object): ...@@ -164,7 +164,8 @@ class BertConfig(object):
initializer_range: The sttdev of the truncated_normal_initializer for initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices. initializing all weight matrices.
""" """
if isinstance(vocab_size_or_config_json_file, str): if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read()) json_config = json.loads(reader.read())
for key, value in json_config.items(): for key, value in json_config.items():
...@@ -343,8 +344,10 @@ class BertIntermediate(nn.Module): ...@@ -343,8 +344,10 @@ class BertIntermediate(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BertIntermediate, self).__init__() super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size) self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
if isinstance(config.hidden_act, str) else config.hidden_act self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
self.intermediate_act_fn = config.hidden_act
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states = self.dense(hidden_states) hidden_states = self.dense(hidden_states)
...@@ -416,8 +419,10 @@ class BertPredictionHeadTransform(nn.Module): ...@@ -416,8 +419,10 @@ class BertPredictionHeadTransform(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__() super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \ if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
if isinstance(config.hidden_act, str) else config.hidden_act self.transform_act_fn = ACT2FN[config.hidden_act]
else:
self.transform_act_fn = config.hidden_act
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12) self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
def forward(self, hidden_states): def forward(self, hidden_states):
...@@ -542,7 +547,7 @@ class BertPreTrainedModel(nn.Module): ...@@ -542,7 +547,7 @@ class BertPreTrainedModel(nn.Module):
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError: except EnvironmentError:
logger.error( logger.error(
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file " "We assumed '{}' was a path or url but couldn't find any file "
......
...@@ -24,6 +24,8 @@ import os ...@@ -24,6 +24,8 @@ import os
import shutil import shutil
import tarfile import tarfile
import tempfile import tempfile
import sys
from io import open
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -160,7 +162,8 @@ class OpenAIGPTConfig(object): ...@@ -160,7 +162,8 @@ class OpenAIGPTConfig(object):
initializer_range: The sttdev of the truncated_normal_initializer for initializer_range: The sttdev of the truncated_normal_initializer for
initializing all weight matrices. initializing all weight matrices.
""" """
if isinstance(vocab_size_or_config_json_file, str): if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
json_config = json.loads(reader.read()) json_config = json.loads(reader.read())
for key, value in json_config.items(): for key, value in json_config.items():
...@@ -442,7 +445,7 @@ class OpenAIGPTPreTrainedModel(nn.Module): ...@@ -442,7 +445,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
except FileNotFoundError: except EnvironmentError:
logger.error( logger.error(
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file " "We assumed '{}' was a path or url but couldn't find any file "
...@@ -641,7 +644,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel): ...@@ -641,7 +644,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
hidden_states = inputs_embeds + position_embeds + token_type_embeds hidden_states = inputs_embeds + position_embeds + token_type_embeds
for block in self.h: for block in self.h:
hidden_states = block(hidden_states) hidden_states = block(hidden_states)
return hidden_states.view(*input_shape, hidden_states.size(-1)) output_shape = input_shape + (hidden_states.size(-1),)
return hidden_states.view(*output_shape)
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel): class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
......
...@@ -27,6 +27,8 @@ import tarfile ...@@ -27,6 +27,8 @@ import tarfile
import tempfile import tempfile
import shutil import shutil
import collections import collections
import sys
from io import open
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -124,7 +126,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path): ...@@ -124,7 +126,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
try: try:
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
except ModuleNotFoundError: except ImportError:
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see " print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.") "https://www.tensorflow.org/install/ for installation instructions.")
raise raise
...@@ -239,7 +241,8 @@ class TransfoXLConfig(object): ...@@ -239,7 +241,8 @@ class TransfoXLConfig(object):
proj_init_std: parameters initialized by N(0, init_std) proj_init_std: parameters initialized by N(0, init_std)
init_std: parameters initialized by N(0, init_std) init_std: parameters initialized by N(0, init_std)
""" """
if isinstance(vocab_size_or_config_json_file, str): if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
and isinstance(vocab_size_or_config_json_file, unicode)):
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
json_config = json.loads(reader.read()) json_config = json.loads(reader.read())
for key, value in json_config.items(): for key, value in json_config.items():
...@@ -503,11 +506,12 @@ class RelMultiHeadAttn(nn.Module): ...@@ -503,11 +506,12 @@ class RelMultiHeadAttn(nn.Module):
return x return x
def _rel_shift(self, x, zero_triu=False): def _rel_shift(self, x, zero_triu=False):
zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]), zero_pad_shape = (x.size(0), 1) + x.size()[2:]
device=x.device, dtype=x.dtype) zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
x_padded = torch.cat([zero_pad, x], dim=1) x_padded = torch.cat([zero_pad, x], dim=1)
x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:]) x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
x_padded = x_padded.view(*x_padded_shape)
x = x_padded[1:].view_as(x) x = x_padded[1:].view_as(x)
...@@ -797,7 +801,8 @@ class AdaptiveEmbedding(nn.Module): ...@@ -797,7 +801,8 @@ class AdaptiveEmbedding(nn.Module):
emb_flat.index_copy_(0, indices_i, emb_i) emb_flat.index_copy_(0, indices_i, emb_i)
embed = emb_flat.view(*inp.size(), self.d_proj) embed_shape = inp.size() + (self.d_proj,)
embed = emb_flat.view(embed_shape)
embed.mul_(self.emb_scale) embed.mul_(self.emb_scale)
...@@ -905,7 +910,7 @@ class TransfoXLPreTrainedModel(nn.Module): ...@@ -905,7 +910,7 @@ class TransfoXLPreTrainedModel(nn.Module):
try: try:
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir) resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
resolved_config_file = cached_path(config_file, cache_dir=cache_dir) resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except FileNotFoundError: except EnvironmentError:
logger.error( logger.error(
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} " "We assumed '{}' was a path or url but couldn't find files {} and {} "
......
...@@ -14,14 +14,13 @@ ...@@ -14,14 +14,13 @@
# limitations under the License. # limitations under the License.
"""Tokenization classes.""" """Tokenization classes."""
from __future__ import absolute_import from __future__ import absolute_import, division, print_function, unicode_literals
from __future__ import division
from __future__ import print_function
import collections import collections
import unicodedata
import os
import logging import logging
import os
import unicodedata
from io import open
from .file_utils import cached_path from .file_utils import cached_path
...@@ -129,7 +128,7 @@ class BertTokenizer(object): ...@@ -129,7 +128,7 @@ class BertTokenizer(object):
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except FileNotFoundError: except EnvironmentError:
logger.error( logger.error(
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file " "We assumed '{}' was a path or url but couldn't find any file "
......
...@@ -13,11 +13,17 @@ ...@@ -13,11 +13,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Tokenization classes for OpenAI GPT.""" """Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import json
import logging
import os import os
import re import re
import json import sys
from io import open
from tqdm import tqdm from tqdm import tqdm
import logging
from .file_utils import cached_path from .file_utils import cached_path
...@@ -82,7 +88,7 @@ class OpenAIGPTTokenizer(object): ...@@ -82,7 +88,7 @@ class OpenAIGPTTokenizer(object):
try: try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
except FileNotFoundError: except EnvironmentError:
logger.error( logger.error(
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} and {} " "We assumed '{}' was a path or url but couldn't find files {} and {} "
...@@ -119,7 +125,7 @@ class OpenAIGPTTokenizer(object): ...@@ -119,7 +125,7 @@ class OpenAIGPTTokenizer(object):
self.max_len = max_len if max_len is not None else int(1e12) self.max_len = max_len if max_len is not None else int(1e12)
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
self.fix_text = ftfy.fix_text self.fix_text = ftfy.fix_text
self.encoder = json.load(open(vocab_file)) self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v:k for k,v in self.encoder.items()} self.decoder = {v:k for k,v in self.encoder.items()}
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
merges = [tuple(merge.split()) for merge in merges] merges = [tuple(merge.split()) for merge in merges]
...@@ -196,7 +202,7 @@ class OpenAIGPTTokenizer(object): ...@@ -196,7 +202,7 @@ class OpenAIGPTTokenizer(object):
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
"""Converts a sequence of tokens into ids using the vocab.""" """Converts a sequence of tokens into ids using the vocab."""
ids = [] ids = []
if isinstance(tokens, str): if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
if tokens in self.special_tokens: if tokens in self.special_tokens:
return self.special_tokens[tokens] return self.special_tokens[tokens]
else: else:
......
...@@ -16,16 +16,27 @@ ...@@ -16,16 +16,27 @@
""" Tokenization classes for Transformer XL model. """ Tokenization classes for Transformer XL model.
Adapted from https://github.com/kimiyoung/transformer-xl. Adapted from https://github.com/kimiyoung/transformer-xl.
""" """
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import os
import glob import glob
import logging import logging
import pickle import os
import torch import sys
from collections import Counter, OrderedDict from collections import Counter, OrderedDict
from io import open
import torch
import numpy as np
from .file_utils import cached_path from .file_utils import cached_path
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = { PRETRAINED_VOCAB_ARCHIVE_MAP = {
...@@ -55,7 +66,7 @@ class TransfoXLTokenizer(object): ...@@ -55,7 +66,7 @@ class TransfoXLTokenizer(object):
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
except FileNotFoundError: except EnvironmentError:
logger.error( logger.error(
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} " "We assumed '{}' was a path or url but couldn't find files {} "
...@@ -422,7 +433,7 @@ class TransfoXLCorpus(object): ...@@ -422,7 +433,7 @@ class TransfoXLCorpus(object):
# redirect to the cache, if necessary # redirect to the cache, if necessary
try: try:
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir) resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
except FileNotFoundError: except EnvironmentError:
logger.error( logger.error(
"Model name '{}' was not found in model name list ({}). " "Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find files {} " "We assumed '{}' was a path or url but couldn't find files {} "
......
...@@ -33,6 +33,7 @@ To create the package for pypi. ...@@ -33,6 +33,7 @@ To create the package for pypi.
7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory. 7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
""" """
from io import open
from setuptools import find_packages, setup from setuptools import find_packages, setup
setup( setup(
...@@ -58,7 +59,7 @@ setup( ...@@ -58,7 +59,7 @@ setup(
"pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main", "pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main",
] ]
}, },
python_requires='>=3.5.0', # python_requires='>=3.5.0',
tests_require=['pytest'], tests_require=['pytest'],
classifiers=[ classifiers=[
'Intended Audience :: Science/Research', 'Intended Audience :: Science/Research',
......
...@@ -18,6 +18,7 @@ from __future__ import print_function ...@@ -18,6 +18,7 @@ from __future__ import print_function
import os import os
import unittest import unittest
from io import open
from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer, from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer,
_is_whitespace, _is_control, _is_punctuation) _is_whitespace, _is_control, _is_punctuation)
...@@ -30,7 +31,7 @@ class TokenizationTest(unittest.TestCase): ...@@ -30,7 +31,7 @@ class TokenizationTest(unittest.TestCase):
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", "," "##ing", ","
] ]
with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer: with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
vocab_file = vocab_writer.name vocab_file = vocab_writer.name
...@@ -49,7 +50,7 @@ class TokenizationTest(unittest.TestCase): ...@@ -49,7 +50,7 @@ class TokenizationTest(unittest.TestCase):
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
"##ing", "," "##ing", ","
] ]
with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer: with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens])) vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
vocab_file = vocab_writer.name vocab_file = vocab_writer.name
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment