Commit 1de35b62 authored by thomwolf's avatar thomwolf
Browse files

preparing for first release

parent 8513741b
# How to Contribute
BERT needs to maintain permanent compatibility with the pre-trained model files,
so we do not plan to make any major changes to this library (other than what was
promised in the README). However, we can accept small patches related to
re-factoring and documentation. To submit contributes, there are just a few
small guidelines you need to follow.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
...@@ -8,29 +8,26 @@ This implementation can load any pre-trained TensorFlow checkpoint for BERT (in ...@@ -8,29 +8,26 @@ This implementation can load any pre-trained TensorFlow checkpoint for BERT (in
The code to use, in addition, [the Multilingual and Chinese models](https://github.com/google-research/bert/blob/master/multilingual.md) will be added later this week (it's actually just the tokenization code that needs to be updated). The code to use, in addition, [the Multilingual and Chinese models](https://github.com/google-research/bert/blob/master/multilingual.md) will be added later this week (it's actually just the tokenization code that needs to be updated).
## Loading a TensorFlow checkpoint (e.g. [Google's pre-trained models](https://github.com/google-research/bert#pre-trained-models)) ## Installation, requirements, test
You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script.
This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`). This code was tested on Python 3.5+. The requirements are:
You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too. - PyTorch (>= 0.4.1)
- tqdm
To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch. To install the dependencies:
Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model: ````bash
pip install -r ./requirements.txt
````
```shell A series of tests is included in the [tests folder](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`).
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
python convert_tf_checkpoint_to_pytorch.py \ You can run the tests with the command:
--tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \ ```bash
--bert_config_file $BERT_BASE_DIR/bert_config.json \ python -m pytest -sv tests/
--pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
``` ```
You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models).
## PyTorch models for BERT ## PyTorch models for BERT
We included three PyTorch models in this repository that you will find in [`modeling.py`](modeling.py): We included three PyTorch models in this repository that you will find in [`modeling.py`](modeling.py):
...@@ -52,10 +49,15 @@ We detail them here. This model takes as inputs: ...@@ -52,10 +49,15 @@ We detail them here. This model takes as inputs:
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`), and - `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`), and
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details). - `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to a `sentence B` token (see BERT paper for more details).
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences. - `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max input sequence length in the current batch. It's the mask that we typically use for attention when a batch has varying length sentences.
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
This model outputs a tuple composed of: This model outputs a tuple composed of:
- `all_encoder_layers`: a list of torch.FloatTensor of size [batch_size, sequence_length, hidden_size] which is a list of the full sequences of hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), and - `encoded_layers`: controled by the value of the `output_encoded_layers` argument:
. `output_all_encoded_layers=True`: outputs a list of the encoded-hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
. `output_all_encoded_layers=False`: outputs only the encoded-hidden-states corresponding to the last attention block,
- `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated to the first character of the input (`CLF`) to train on the Next-Sentence task (see BERT's paper). - `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated to the first character of the input (`CLF`) to train on the Next-Sentence task (see BERT's paper).
An example on how to use this class is given in the `extract_features.py` script which can be used to extract the hidden states of the model for a given input. An example on how to use this class is given in the `extract_features.py` script which can be used to extract the hidden states of the model for a given input.
...@@ -76,26 +78,30 @@ The token-level classifier takes as input the full sequence of the last hidden s ...@@ -76,26 +78,30 @@ The token-level classifier takes as input the full sequence of the last hidden s
An example on how to use this class is given in the `run_squad.py` script which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task. An example on how to use this class is given in the `run_squad.py` script which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.
## Installation, requirements, test
This code was tested on Python 3.5+. The requirements are: ## Converting a TensorFlow checkpoint in a PyTorch checkpoint
- PyTorch (>= 0.4.1) You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) in a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script.
- tqdm
To install the dependencies: This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`).
````bash You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`) but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`) as these are needed for the PyTorch model too.
pip install -r ./requirements.txt
````
A series of tests is included in the [tests folder](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`). To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch.
You can run the tests with the command: Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model:
```bash
python -m pytest -sv tests/ ```shell
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
python convert_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
--bert_config_file $BERT_BASE_DIR/bert_config.json \
--pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
``` ```
You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models).
## Training on large batches: gradient accumulation, multi-GPU and distributed training ## Training on large batches: gradient accumulation, multi-GPU and distributed training
BERT-base and BERT-large are respectively 110M and 340M parameters models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most case a batch size of 32). BERT-base and BERT-large are respectively 110M and 340M parameters models and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most case a batch size of 32).
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/sh
python -m pytorch_pretrained_bert "$@"
\ No newline at end of file
...@@ -19,18 +19,17 @@ from __future__ import division ...@@ -19,18 +19,17 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse import argparse
import codecs
import collections import collections
import logging import logging
import json import json
import re import re
import torch import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
import tokenization from pytorch_pretrained_bert.tokenization import convert_to_unicode, BertTokenizer
from modeling import BertConfig, BertModel from pytorch_pretrained_bert.modeling import BertConfig, BertModel
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -171,7 +170,7 @@ def read_examples(input_file): ...@@ -171,7 +170,7 @@ def read_examples(input_file):
unique_id = 0 unique_id = 0
with open(input_file, "r") as reader: with open(input_file, "r") as reader:
while True: while True:
line = tokenization.convert_to_unicode(reader.readline()) line = convert_to_unicode(reader.readline())
if not line: if not line:
break break
line = line.strip() line = line.strip()
...@@ -227,13 +226,13 @@ def main(): ...@@ -227,13 +226,13 @@ def main():
n_gpu = 1 n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl') torch.distributed.init_process_group(backend='nccl')
logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1)) logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1)))
layer_indexes = [int(x) for x in args.layers.split(",")] layer_indexes = [int(x) for x in args.layers.split(",")]
bert_config = BertConfig.from_json_file(args.bert_config_file) bert_config = BertConfig.from_json_file(args.bert_config_file)
tokenizer = tokenization.FullTokenizer( tokenizer = BertTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
examples = read_examples(args.input_file) examples = read_examples(args.input_file)
......
...@@ -30,9 +30,9 @@ import torch ...@@ -30,9 +30,9 @@ import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
import tokenization from pytorch_pretrained_bert.tokenization import printable_text, convert_to_unicode, BertTokenizer
from modeling import BertConfig, BertForSequenceClassification from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification
from optimization import BERTAdam from pytorch_pretrained_bert.optimization import BERTAdam
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -122,9 +122,9 @@ class MrpcProcessor(DataProcessor): ...@@ -122,9 +122,9 @@ class MrpcProcessor(DataProcessor):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[3]) text_a = convert_to_unicode(line[3])
text_b = tokenization.convert_to_unicode(line[4]) text_b = convert_to_unicode(line[4])
label = tokenization.convert_to_unicode(line[0]) label = convert_to_unicode(line[0])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
...@@ -154,14 +154,14 @@ class MnliProcessor(DataProcessor): ...@@ -154,14 +154,14 @@ class MnliProcessor(DataProcessor):
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
if i == 0: if i == 0:
continue continue
guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0])) guid = "%s-%s" % (set_type, convert_to_unicode(line[0]))
text_a = tokenization.convert_to_unicode(line[8]) text_a = convert_to_unicode(line[8])
text_b = tokenization.convert_to_unicode(line[9]) text_b = convert_to_unicode(line[9])
label = tokenization.convert_to_unicode(line[-1]) label = convert_to_unicode(line[-1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)) InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples return examples
class ColaProcessor(DataProcessor): class ColaProcessor(DataProcessor):
"""Processor for the CoLA data set (GLUE version).""" """Processor for the CoLA data set (GLUE version)."""
...@@ -185,8 +185,8 @@ class ColaProcessor(DataProcessor): ...@@ -185,8 +185,8 @@ class ColaProcessor(DataProcessor):
examples = [] examples = []
for (i, line) in enumerate(lines): for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i) guid = "%s-%s" % (set_type, i)
text_a = tokenization.convert_to_unicode(line[3]) text_a = convert_to_unicode(line[3])
label = tokenization.convert_to_unicode(line[1]) label = convert_to_unicode(line[1])
examples.append( examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label)) InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples return examples
...@@ -273,7 +273,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer ...@@ -273,7 +273,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
logger.info("*** Example ***") logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid)) logger.info("guid: %s" % (example.guid))
logger.info("tokens: %s" % " ".join( logger.info("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in tokens])) [printable_text(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info( logger.info(
...@@ -281,11 +281,10 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer ...@@ -281,11 +281,10 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
logger.info("label: %s (id = %d)" % (example.label, label_id)) logger.info("label: %s (id = %d)" % (example.label, label_id))
features.append( features.append(
InputFeatures( InputFeatures(input_ids=input_ids,
input_ids=input_ids, input_mask=input_mask,
input_mask=input_mask, segment_ids=segment_ids,
segment_ids=segment_ids, label_id=label_id))
label_id=label_id))
return features return features
...@@ -307,7 +306,7 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length): ...@@ -307,7 +306,7 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
def accuracy(out, labels): def accuracy(out, labels):
outputs = np.argmax(out, axis=1) outputs = np.argmax(out, axis=1)
return np.sum(outputs==labels) return np.sum(outputs == labels)
def copy_optimizer_params_to_model(named_params_model, named_params_optimizer): def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
""" Utility function for optimize_on_cpu and 16-bits training. """ Utility function for optimize_on_cpu and 16-bits training.
...@@ -497,7 +496,7 @@ def main(): ...@@ -497,7 +496,7 @@ def main():
processor = processors[task_name]() processor = processors[task_name]()
label_list = processor.get_labels() label_list = processor.get_labels()
tokenizer = tokenization.FullTokenizer( tokenizer = BertTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
train_examples = None train_examples = None
......
...@@ -25,7 +25,6 @@ import json ...@@ -25,7 +25,6 @@ import json
import math import math
import os import os
import random import random
import six
from tqdm import tqdm, trange from tqdm import tqdm, trange
import numpy as np import numpy as np
...@@ -33,9 +32,9 @@ import torch ...@@ -33,9 +32,9 @@ import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
import tokenization from pytorch_pretrained_bert.tokenization import printable_text, whitespace_tokenize, BasicTokenizer, BertTokenizer
from modeling import BertConfig, BertForQuestionAnswering from pytorch_pretrained_bert.modeling import BertConfig, BertForQuestionAnswering
from optimization import BERTAdam from pytorch_pretrained_bert.optimization import BERTAdam
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', datefmt = '%m/%d/%Y %H:%M:%S',
...@@ -65,9 +64,9 @@ class SquadExample(object): ...@@ -65,9 +64,9 @@ class SquadExample(object):
def __repr__(self): def __repr__(self):
s = "" s = ""
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id)) s += "qas_id: %s" % (printable_text(self.qas_id))
s += ", question_text: %s" % ( s += ", question_text: %s" % (
tokenization.printable_text(self.question_text)) printable_text(self.question_text))
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens)) s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
if self.start_position: if self.start_position:
s += ", start_position: %d" % (self.start_position) s += ", start_position: %d" % (self.start_position)
...@@ -156,7 +155,7 @@ def read_squad_examples(input_file, is_training): ...@@ -156,7 +155,7 @@ def read_squad_examples(input_file, is_training):
# guaranteed to be preserved. # guaranteed to be preserved.
actual_text = " ".join(doc_tokens[start_position:(end_position + 1)]) actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
cleaned_answer_text = " ".join( cleaned_answer_text = " ".join(
tokenization.whitespace_tokenize(orig_answer_text)) whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1: if actual_text.find(cleaned_answer_text) == -1:
logger.warning("Could not find answer: '%s' vs. '%s'", logger.warning("Could not find answer: '%s' vs. '%s'",
actual_text, cleaned_answer_text) actual_text, cleaned_answer_text)
...@@ -290,11 +289,11 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -290,11 +289,11 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
logger.info("example_index: %s" % (example_index)) logger.info("example_index: %s" % (example_index))
logger.info("doc_span_index: %s" % (doc_span_index)) logger.info("doc_span_index: %s" % (doc_span_index))
logger.info("tokens: %s" % " ".join( logger.info("tokens: %s" % " ".join(
[tokenization.printable_text(x) for x in tokens])) [printable_text(x) for x in tokens]))
logger.info("token_to_orig_map: %s" % " ".join( logger.info("token_to_orig_map: %s" % " ".join([
["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)])) "%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
logger.info("token_is_max_context: %s" % " ".join([ logger.info("token_is_max_context: %s" % " ".join([
"%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context) "%d:%s" % (x, y) for (x, y) in token_is_max_context.items()
])) ]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info( logger.info(
...@@ -306,7 +305,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, ...@@ -306,7 +305,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
logger.info("start_position: %d" % (start_position)) logger.info("start_position: %d" % (start_position))
logger.info("end_position: %d" % (end_position)) logger.info("end_position: %d" % (end_position))
logger.info( logger.info(
"answer: %s" % (tokenization.printable_text(answer_text))) "answer: %s" % (printable_text(answer_text)))
features.append( features.append(
InputFeatures( InputFeatures(
...@@ -582,7 +581,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): ...@@ -582,7 +581,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
# and `pred_text`, and check if they are the same length. If they are # and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same # NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned. # length, we assume the characters are one-to-one aligned.
tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case) tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
tok_text = " ".join(tokenizer.tokenize(orig_text)) tok_text = " ".join(tokenizer.tokenize(orig_text))
...@@ -606,7 +605,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): ...@@ -606,7 +605,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
# We then project the characters in `pred_text` back to `orig_text` using # We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment. # the character-to-character alignment.
tok_s_to_ns_map = {} tok_s_to_ns_map = {}
for (i, tok_index) in six.iteritems(tok_ns_to_s_map): for (i, tok_index) in tok_ns_to_s_map.items():
tok_s_to_ns_map[tok_index] = i tok_s_to_ns_map[tok_index] = i
orig_start_position = None orig_start_position = None
...@@ -827,7 +826,7 @@ def main(): ...@@ -827,7 +826,7 @@ def main():
raise ValueError("Output directory () already exists and is not empty.") raise ValueError("Output directory () already exists and is not empty.")
os.makedirs(args.output_dir, exist_ok=True) os.makedirs(args.output_dir, exist_ok=True)
tokenizer = tokenization.FullTokenizer( tokenizer = BertTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case) vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
train_examples = None train_examples = None
......
...@@ -463,7 +463,7 @@ ...@@ -463,7 +463,7 @@
], ],
"source": [ "source": [
"bert_config = modeling_tensorflow.BertConfig.from_json_file(bert_config_file)\n", "bert_config = modeling_tensorflow.BertConfig.from_json_file(bert_config_file)\n",
"tokenizer = tokenization.FullTokenizer(\n", "tokenizer = tokenization.BertTokenizer(\n",
" vocab_file=vocab_file, do_lower_case=True)\n", " vocab_file=vocab_file, do_lower_case=True)\n",
"\n", "\n",
"eval_examples = read_squad_examples(\n", "eval_examples = read_squad_examples(\n",
......
This diff is collapsed.
This diff is collapsed.
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForQuestionAnswering)
from .optimization import BERTAdam
# coding: utf8
if __name__ == '__main__':
import sys
try:
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ModuleNotFoundError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if len(sys.argv) != 5:
# pylint: disable=line-too-long
print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
else:
PYTORCH_DUMP_OUTPUT = sys.argv.pop()
TF_CONFIG = sys.argv.pop()
TF_CHECKPOINT = sys.argv.pop()
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
...@@ -18,66 +18,39 @@ from __future__ import absolute_import ...@@ -18,66 +18,39 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os
import re import re
import argparse import argparse
import tensorflow as tf import tensorflow as tf
import torch import torch
import numpy as np import numpy as np
from modeling import BertConfig, BertModel from .modeling import BertConfig, BertForPreTraining
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the TensorFlow checkpoint path.")
parser.add_argument("--bert_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args()
def convert():
# Initialise PyTorch model
config = BertConfig.from_json_file(args.bert_config_file)
model = BertModel(config)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
config_path = os.path.abspath(bert_config_file)
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
# Load weights from TF model # Load weights from TF model
path = args.tf_checkpoint_path init_vars = tf.train.list_variables(tf_path)
print("Converting TensorFlow checkpoint from {}".format(path))
init_vars = tf.train.list_variables(path)
names = [] names = []
arrays = [] arrays = []
for name, shape in init_vars: for name, shape in init_vars:
print("Loading {} with shape {}".format(name, shape)) print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(path, name) array = tf.train.load_variable(tf_path, name)
print("Numpy array shape {}".format(array.shape))
names.append(name) names.append(name)
arrays.append(array) arrays.append(array)
# Initialise PyTorch model
config = BertConfig.from_json_file(bert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = BertForPreTraining(config)
for name, array in zip(names, arrays): for name, array in zip(names, arrays):
if not name.startswith("bert"):
print("Skipping {}".format(name))
continue
else:
name = name.replace("bert/", "") # skip "bert/"
print("Loading {}".format(name))
name = name.split('/') name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model # which are not required for using pretrained model
if name[0] in ['redictions', 'eq_relationship'] or name[-1] == "adam_v" or name[-1] == "adam_m": if name[-1] in ["adam_v", "adam_m"]:
print("Skipping {}".format("/".join(name))) print("Skipping {}".format("/".join(name)))
continue continue
pointer = model pointer = model
...@@ -88,6 +61,10 @@ def convert(): ...@@ -88,6 +61,10 @@ def convert():
l = [m_name] l = [m_name]
if l[0] == 'kernel': if l[0] == 'kernel':
pointer = getattr(pointer, 'weight') pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
else: else:
pointer = getattr(pointer, l[0]) pointer = getattr(pointer, l[0])
if len(l) >= 2: if len(l) >= 2:
...@@ -102,10 +79,34 @@ def convert(): ...@@ -102,10 +79,34 @@ def convert():
except AssertionError as e: except AssertionError as e:
e.args += (pointer.shape, array.shape) e.args += (pointer.shape, array.shape)
raise raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array) pointer.data = torch.from_numpy(array)
# Save pytorch-model # Save pytorch-model
torch.save(model.state_dict(), args.pytorch_dump_path) print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__": if __name__ == "__main__":
convert() parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the TensorFlow checkpoint path.")
parser.add_argument("--bert_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.bert_config_file,
args.pytorch_dump_path)
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
import os
import logging
import shutil
import tempfile
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set
from hashlib import sha256
from functools import wraps
from tqdm import tqdm
import boto3
from botocore.exceptions import ClientError
import requests
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert'))
def url_to_filename(url: str, etag: str = None) -> str:
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
"""
url_bytes = url.encode('utf-8')
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode('utf-8')
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()
return filename
def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise FileNotFoundError("file {} not found".format(cache_path))
meta_path = cache_path + '.json'
if not os.path.exists(meta_path):
raise FileNotFoundError("file {} not found".format(meta_path))
with open(meta_path) as meta_file:
metadata = json.load(meta_file)
url = metadata['url']
etag = metadata['etag']
return url, etag
def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str:
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
parsed = urlparse(url_or_filename)
if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif parsed.scheme == '':
# File, but it doesn't exist.
raise FileNotFoundError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url: str) -> Tuple[str, str]:
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path
def s3_request(func: Callable):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url: str, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise FileNotFoundError("file {} not found".format(url))
else:
raise
return wrapper
@s3_request
def s3_etag(url: str) -> Optional[str]:
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag
@s3_request
def s3_get(url: str, temp_file: IO) -> None:
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url: str, temp_file: IO) -> None:
req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url: str, cache_dir: str = None) -> str:
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
os.makedirs(cache_dir, exist_ok=True)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url)
else:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError("HEAD request failed for url {} with status code {}"
.format(url, response.status_code))
etag = response.headers.get("ETag")
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file)
else:
http_get(url, temp_file)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, 'wb') as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file:
json.dump(meta, meta_file)
logger.info("removing temp file %s", temp_file.name)
return cache_path
def read_set_from_file(filename: str) -> Set[str]:
'''
Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line.
'''
collection = set()
with open(filename, 'r') as file_:
for line in file_:
collection.add(line.rstrip())
return collection
def get_file_extension(path: str, dot=True, lower: bool = True):
ext = os.path.splitext(path)[1]
ext = ext if dot else ext[1:]
return ext.lower() if lower else ext
...@@ -20,27 +20,32 @@ from __future__ import print_function ...@@ -20,27 +20,32 @@ from __future__ import print_function
import collections import collections
import unicodedata import unicodedata
import six import os
import logging
from .file_utils import cached_path
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-base-multilingual': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}
def convert_to_unicode(text): def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input.""" """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3: if isinstance(text, str):
if isinstance(text, str): return text
return text elif isinstance(text, bytes):
elif isinstance(text, bytes): return text.decode("utf-8", "ignore")
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else: else:
raise ValueError("Not running on Python2 or Python 3?") raise ValueError("Unsupported string type: %s" % (type(text)))
def printable_text(text): def printable_text(text):
...@@ -48,22 +53,12 @@ def printable_text(text): ...@@ -48,22 +53,12 @@ def printable_text(text):
# These functions want `str` for both Python2 and Python3, but in one case # These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string. # it's a Unicode string and in the other it's a byte string.
if six.PY3: if isinstance(text, str):
if isinstance(text, str): return text
return text elif isinstance(text, bytes):
elif isinstance(text, bytes): return text.decode("utf-8", "ignore")
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else: else:
raise ValueError("Not running on Python2 or Python 3?") raise ValueError("Unsupported string type: %s" % (type(text)))
def load_vocab(vocab_file): def load_vocab(vocab_file):
...@@ -81,14 +76,6 @@ def load_vocab(vocab_file): ...@@ -81,14 +76,6 @@ def load_vocab(vocab_file):
return vocab return vocab
def convert_tokens_to_ids(vocab, tokens):
"""Converts a sequence of tokens into ids using the vocab."""
ids = []
for token in tokens:
ids.append(vocab[token])
return ids
def whitespace_tokenize(text): def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a peice of text.""" """Runs basic whitespace cleaning and splitting on a peice of text."""
text = text.strip() text = text.strip()
...@@ -98,11 +85,16 @@ def whitespace_tokenize(text): ...@@ -98,11 +85,16 @@ def whitespace_tokenize(text):
return tokens return tokens
class FullTokenizer(object): class BertTokenizer(object):
"""Runs end-to-end tokenziation.""" """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
def __init__(self, vocab_file, do_lower_case=True): def __init__(self, vocab_file, do_lower_case=True):
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
self.vocab = load_vocab(vocab_file) self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()])
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case) self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
...@@ -111,11 +103,52 @@ class FullTokenizer(object): ...@@ -111,11 +103,52 @@ class FullTokenizer(object):
for token in self.basic_tokenizer.tokenize(text): for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token): for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token) split_tokens.append(sub_token)
return split_tokens return split_tokens
def convert_tokens_to_ids(self, tokens): def convert_tokens_to_ids(self, tokens):
return convert_tokens_to_ids(self.vocab, tokens) """Converts a sequence of tokens into ids using the vocab."""
ids = []
for token in tokens:
ids.append(self.vocab[token])
return ids
def convert_ids_to_tokens(self, ids):
"""Converts a sequence of ids in wordpiece tokens using the vocab."""
tokens = []
for i in ids:
tokens.append(self.ids_to_tokens[i])
return tokens
@classmethod
def from_pretrained(cls, pretrained_model_name, do_lower_case=True):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
"""
if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
else:
vocab_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file)
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, do_lower_case)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name))
tokenizer = None
return tokenizer
class BasicTokenizer(object): class BasicTokenizer(object):
......
torch # This installs Pytorch for CUDA 8 only. If you are using a newer version,
tqdm # please visit http://pytorch.org/ and install the relevant version.
\ No newline at end of file torch>=0.4.1,<0.5.0
# progress bars in model download and training scripts
tqdm>=4.19
# Accessing files from S3 directly.
boto3
# Used for downloading models over HTTP
requests>=2.18
\ No newline at end of file
from setuptools import find_packages, setup
setup(
name="pytorch_pretrained_bert",
version="0.1.0",
author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors",
author_email="thomas@huggingface.co",
description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",
long_description=open("README.md", "r").read(),
long_description_content_type="text/markdown",
keywords='BERT NLP deep learning google',
license='Apache',
url="https://github.com/huggingface/pytorch-pretrained-BERT",
packages=find_packages(exclude=["*.tests", "*.tests.*",
"tests.*", "tests"]),
install_requires=['numpy',
'torch>=0.4.1',
'boto3',
'requests>=2.18',
'tqdm>=4.19'],
scripts=["bin/pytorch_pretrained_bert"],
python_requires='>=3.5.0',
tests_require=['pytest'],
classifiers=[
'Intended Audience :: Science/Research',
'Development Status :: 1 - Alpha',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
],
)
...@@ -34,7 +34,7 @@ class TokenizationTest(unittest.TestCase): ...@@ -34,7 +34,7 @@ class TokenizationTest(unittest.TestCase):
vocab_file = vocab_writer.name vocab_file = vocab_writer.name
tokenizer = tokenization.FullTokenizer(vocab_file) tokenizer = tokenization.BertTokenizer(vocab_file)
os.remove(vocab_file) os.remove(vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running") tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment