Commit 1de35b62 authored by thomwolf

preparing for first release

parent 8513741b
# How to Contribute
BERT needs to maintain permanent compatibility with the pre-trained model files,
so we do not plan to make any major changes to this library (other than what was
promised in the README). However, we can accept small patches related to
refactoring and documentation. To submit contributions, there are just a few
small guidelines you need to follow.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google.com/conduct/).
......@@ -8,29 +8,26 @@ This implementation can load any pre-trained TensorFlow checkpoint for BERT (in
The code to use [the Multilingual and Chinese models](https://github.com/google-research/bert/blob/master/multilingual.md) will, in addition, be added later this week (it's actually just the tokenization code that needs to be updated).
## Loading a TensorFlow checkpoint (e.g. [Google's pre-trained models](https://github.com/google-research/bert#pre-trained-models))
You can convert any TensorFlow checkpoint for BERT (in particular [the pre-trained models released by Google](https://github.com/google-research/bert#pre-trained-models)) into a PyTorch save file by using the [`convert_tf_checkpoint_to_pytorch.py`](convert_tf_checkpoint_to_pytorch.py) script.
This script takes as input a TensorFlow checkpoint (three files starting with `bert_model.ckpt`) and the associated configuration file (`bert_config.json`), creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint into the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using `torch.load()` (see examples in `extract_features.py`, `run_classifier.py` and `run_squad.py`).
You only need to run this conversion script **once** to get a PyTorch model. You can then disregard the TensorFlow checkpoint (the three files starting with `bert_model.ckpt`), but be sure to keep the configuration file (`bert_config.json`) and the vocabulary file (`vocab.txt`), as these are needed for the PyTorch model too.
To run this specific conversion script you will need to have TensorFlow and PyTorch installed (`pip install tensorflow`). The rest of the repository only requires PyTorch.
Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model:
```shell
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
python convert_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path $BERT_BASE_DIR/bert_model.ckpt \
--bert_config_file $BERT_BASE_DIR/bert_config.json \
--pytorch_dump_path $BERT_BASE_DIR/pytorch_model.bin
```
You can download Google's pre-trained models for the conversion [here](https://github.com/google-research/bert#pre-trained-models).
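For reference, here is a minimal sketch of loading the resulting save file into the PyTorch model, mirroring what the example scripts do (the paths are illustrative and assume the conversion above has been run):
```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

# Illustrative paths: the configuration file kept from the TensorFlow release
# and the PyTorch save file produced by the conversion script above.
config = BertConfig.from_json_file("uncased_L-12_H-768_A-12/bert_config.json")
model = BertModel(config)
model.load_state_dict(torch.load("uncased_L-12_H-768_A-12/pytorch_model.bin", map_location="cpu"))
model.eval()  # disable dropout for feature extraction
```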
## PyTorch models for BERT
We included three PyTorch models in this repository that you will find in [`modeling.py`](modeling.py):
......@@ -52,10 +49,15 @@ We detail them here. This model takes as inputs:
- `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length] with the word token indices in the vocabulary (see the token preprocessing logic in the scripts `extract_features.py`, `run_classifier.py` and `run_squad.py`),
- `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` token and type 1 corresponds to a `sentence B` token (see the BERT paper for more details).
- `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices selected in [0, 1]. It's a mask to be used when some input sequences are shorter than the max input sequence length in the current batch; it is the attention mask we typically use when a batch contains sentences of varying lengths.
- `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
This model outputs a tuple composed of:
- `encoded_layers`: controlled by the value of the `output_all_encoded_layers` argument:
  - `output_all_encoded_layers=True`: outputs a list of the encoded hidden-states at the end of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each encoded hidden-state being a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
  - `output_all_encoded_layers=False`: outputs only the encoded hidden-states corresponding to the last attention block,
- `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a classifier pretrained on top of the hidden state associated with the first token of the input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
An example of how to use this class is given in the `extract_features.py` script, which can be used to extract the hidden states of the model for a given input.
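For orientation, here is a minimal sketch of calling the model on a toy batch (the token ids are made up; real inputs come from the tokenizer, as in `extract_features.py`):
```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertModel

config = BertConfig.from_json_file("bert_config.json")  # illustrative path
model = BertModel(config)
model.eval()

# Toy batch of two padded sequences of length 4 (ids are illustrative).
input_ids = torch.LongTensor([[31, 51, 99, 0], [15, 5, 0, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1, 1], [0, 1, 1, 1]])
attention_mask = torch.LongTensor([[1, 1, 1, 0], [1, 1, 0, 0]])

# encoded_layers is a list with one [batch_size, sequence_length, hidden_size]
# tensor per attention block (output_all_encoded_layers defaults to True).
encoded_layers, pooled_output = model(input_ids, token_type_ids, attention_mask)
```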
......@@ -76,26 +78,30 @@ The token-level classifier takes as input the full sequence of the last hidden s
An example of how to use this class is given in the `run_squad.py` script, which can be used to fine-tune a token classifier using BERT, for example for the SQuAD task.
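For orientation, here is a minimal sketch of one fine-tuning step with the token-level classifier; the call signature is an assumption based on how `run_squad.py` uses the model:
```python
import torch
from pytorch_pretrained_bert.modeling import BertConfig, BertForQuestionAnswering

config = BertConfig.from_json_file("bert_config.json")  # illustrative path
model = BertForQuestionAnswering(config)

# Toy example: one sequence of length 4 (ids are illustrative).
input_ids = torch.LongTensor([[31, 51, 99, 10]])
token_type_ids = torch.LongTensor([[0, 0, 1, 1]])  # 0 = question, 1 = paragraph
input_mask = torch.LongTensor([[1, 1, 1, 1]])
start_positions = torch.LongTensor([2])  # gold answer span (illustrative)
end_positions = torch.LongTensor([3])

# With start/end positions the model returns the training loss to backpropagate;
# without them it returns the start and end logits used for prediction.
loss = model(input_ids, token_type_ids, input_mask, start_positions, end_positions)
loss.backward()
```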
## Installation, requirements, test
This code was tested on Python 3.5+. The requirements are:
- PyTorch (>= 0.4.1)
- tqdm
To install the dependencies:
````bash
pip install -r ./requirements.txt
````
A series of tests is included in the [tests folder](https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/tests) and can be run using `pytest` (install pytest if needed: `pip install pytest`).
You can run the tests with the command:
```bash
python -m pytest -sv tests/
```
## Training on large batches: gradient accumulation, multi-GPU and distributed training
BERT-base and BERT-large are respectively 110M- and 340M-parameter models, and it can be difficult to fine-tune them on a single GPU with the recommended batch size for good performance (in most cases a batch size of 32).
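A minimal gradient-accumulation sketch (with a stand-in model and data to keep it runnable; this is not the repository's exact training loop):
```python
import torch

# Stand-ins for the real BERT model and dataloader (assumptions for this sketch).
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
dataloader = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(8)]

accumulation_steps = 4  # 8 examples per step * 4 steps = effective batch size of 32
model.zero_grad()
for step, (inputs, labels) in enumerate(dataloader):
    loss = loss_fn(model(inputs), labels)
    (loss / accumulation_steps).backward()  # gradients sum over the small batches
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()   # one update equivalent to a batch 4x larger
        model.zero_grad()
```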
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/bin/sh
python -m pytorch_pretrained_bert "$@"
\ No newline at end of file
......@@ -19,18 +19,17 @@ from __future__ import division
from __future__ import print_function
import argparse
import codecs
import collections
import logging
import json
import re
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import convert_to_unicode, BertTokenizer
from pytorch_pretrained_bert.modeling import BertConfig, BertModel
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
......@@ -171,7 +170,7 @@ def read_examples(input_file):
unique_id = 0
with open(input_file, "r") as reader:
while True:
line = convert_to_unicode(reader.readline())
if not line:
break
line = line.strip()
......@@ -227,13 +226,13 @@ def main():
n_gpu = 1
# Initializes the distributed backend which will take care of synchronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
logger.info("device: {} n_gpu: {} distributed training: {}".format(device, n_gpu, bool(args.local_rank != -1)))
layer_indexes = [int(x) for x in args.layers.split(",")]
bert_config = BertConfig.from_json_file(args.bert_config_file)
tokenizer = BertTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
examples = read_examples(args.input_file)
......
......@@ -30,9 +30,9 @@ import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import printable_text, convert_to_unicode, BertTokenizer
from pytorch_pretrained_bert.modeling import BertConfig, BertForSequenceClassification
from pytorch_pretrained_bert.optimization import BERTAdam
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
......@@ -122,9 +122,9 @@ class MrpcProcessor(DataProcessor):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
text_a = convert_to_unicode(line[3])
text_b = convert_to_unicode(line[4])
label = convert_to_unicode(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -154,10 +154,10 @@ class MnliProcessor(DataProcessor):
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, tokenization.convert_to_unicode(line[0]))
text_a = tokenization.convert_to_unicode(line[8])
text_b = tokenization.convert_to_unicode(line[9])
label = tokenization.convert_to_unicode(line[-1])
guid = "%s-%s" % (set_type, convert_to_unicode(line[0]))
text_a = convert_to_unicode(line[8])
text_b = convert_to_unicode(line[9])
label = convert_to_unicode(line[-1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
return examples
......@@ -185,8 +185,8 @@ class ColaProcessor(DataProcessor):
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
text_a = convert_to_unicode(line[3])
label = convert_to_unicode(line[1])
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
......@@ -273,7 +273,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
logger.info("*** Example ***")
logger.info("guid: %s" % (example.guid))
logger.info("tokens: %s" % " ".join(
[printable_text(x) for x in tokens]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
logger.info(
......@@ -281,8 +281,7 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
logger.info("label: %s (id = %d)" % (example.label, label_id))
features.append(
InputFeatures(input_ids=input_ids,
input_mask=input_mask,
segment_ids=segment_ids,
label_id=label_id))
......@@ -307,7 +306,7 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
def accuracy(out, labels):
outputs = np.argmax(out, axis=1)
return np.sum(outputs == labels)
def copy_optimizer_params_to_model(named_params_model, named_params_optimizer):
""" Utility function for optimize_on_cpu and 16-bits training.
......@@ -497,7 +496,7 @@ def main():
processor = processors[task_name]()
label_list = processor.get_labels()
tokenizer = BertTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
train_examples = None
......
......@@ -25,7 +25,6 @@ import json
import math
import os
import random
from tqdm import tqdm, trange
import numpy as np
......@@ -33,9 +32,9 @@ import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from pytorch_pretrained_bert.tokenization import printable_text, whitespace_tokenize, BasicTokenizer, BertTokenizer
from pytorch_pretrained_bert.modeling import BertConfig, BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BERTAdam
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
......@@ -65,9 +64,9 @@ class SquadExample(object):
def __repr__(self):
s = ""
s += "qas_id: %s" % (tokenization.printable_text(self.qas_id))
s += "qas_id: %s" % (printable_text(self.qas_id))
s += ", question_text: %s" % (
printable_text(self.question_text))
s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
if self.start_position:
s += ", start_position: %d" % (self.start_position)
......@@ -156,7 +155,7 @@ def read_squad_examples(input_file, is_training):
# guaranteed to be preserved.
actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
cleaned_answer_text = " ".join(
whitespace_tokenize(orig_answer_text))
if actual_text.find(cleaned_answer_text) == -1:
logger.warning("Could not find answer: '%s' vs. '%s'",
actual_text, cleaned_answer_text)
......@@ -290,11 +289,11 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
logger.info("example_index: %s" % (example_index))
logger.info("doc_span_index: %s" % (doc_span_index))
logger.info("tokens: %s" % " ".join(
[printable_text(x) for x in tokens]))
logger.info("token_to_orig_map: %s" % " ".join([
"%d:%d" % (x, y) for (x, y) in token_to_orig_map.items()]))
logger.info("token_is_max_context: %s" % " ".join([
"%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
"%d:%s" % (x, y) for (x, y) in token_is_max_context.items()
]))
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
logger.info(
......@@ -306,7 +305,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
logger.info("start_position: %d" % (start_position))
logger.info("end_position: %d" % (end_position))
logger.info(
"answer: %s" % (tokenization.printable_text(answer_text)))
"answer: %s" % (printable_text(answer_text)))
features.append(
InputFeatures(
......@@ -582,7 +581,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
# and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned.
tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
tok_text = " ".join(tokenizer.tokenize(orig_text))
......@@ -606,7 +605,7 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
# We then project the characters in `pred_text` back to `orig_text` using
# the character-to-character alignment.
tok_s_to_ns_map = {}
for (i, tok_index) in tok_ns_to_s_map.items():
tok_s_to_ns_map[tok_index] = i
orig_start_position = None
......@@ -827,7 +826,7 @@ def main():
raise ValueError("Output directory () already exists and is not empty.")
os.makedirs(args.output_dir, exist_ok=True)
tokenizer = BertTokenizer(
vocab_file=args.vocab_file, do_lower_case=args.do_lower_case)
train_examples = None
......
......@@ -463,7 +463,7 @@
],
"source": [
"bert_config = modeling_tensorflow.BertConfig.from_json_file(bert_config_file)\n",
"tokenizer = tokenization.FullTokenizer(\n",
"tokenizer = tokenization.BertTokenizer(\n",
" vocab_file=vocab_file, do_lower_case=True)\n",
"\n",
"eval_examples = read_squad_examples(\n",
......
......@@ -22,8 +22,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:50.559657Z",
"start_time": "2018-11-05T13:58:50.546096Z"
"end_time": "2018-11-15T14:56:48.412622Z",
"start_time": "2018-11-15T14:56:48.400110Z"
}
},
"outputs": [],
......@@ -44,8 +44,8 @@
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:50.574455Z",
"start_time": "2018-11-05T13:58:50.561988Z"
"end_time": "2018-11-15T14:56:49.483829Z",
"start_time": "2018-11-15T14:56:49.471296Z"
}
},
"outputs": [],
......@@ -63,19 +63,39 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:52.202531Z",
"start_time": "2018-11-05T13:58:50.576198Z"
"end_time": "2018-11-15T14:57:51.597932Z",
"start_time": "2018-11-15T14:57:51.549466Z"
}
},
"outputs": [],
"outputs": [
{
"ename": "DuplicateFlagError",
"evalue": "The flag 'input_file' is defined twice. First from *, Second from *. Description from first occurrence: (no help available)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mDuplicateFlagError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-6-86ecffb49060>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mspec\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mspec_from_file_location\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'*'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moriginal_tf_inplem_dir\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m'/extract_features_tensorflow.py'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_from_spec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mspec\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0mspec\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloader\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexec_module\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'extract_features_tensorflow'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/importlib/_bootstrap_external.py\u001b[0m in \u001b[0;36mexec_module\u001b[0;34m(self, module)\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/importlib/_bootstrap.py\u001b[0m in \u001b[0;36m_call_with_frames_removed\u001b[0;34m(f, *args, **kwds)\u001b[0m\n",
"\u001b[0;32m~/Documents/Thomas/Code/HF/BERT/pytorch-pretrained-BERT/tensorflow_code/extract_features_tensorflow.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0mFLAGS\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFLAGS\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 34\u001b[0;31m \u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDEFINE_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"input_file\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 35\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0mflags\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDEFINE_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"output_file\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/tensorflow/python/platform/flags.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;34m'Use of the keyword argument names (flag_name, default_value, '\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m 'docstring) is deprecated, please use (name, default, help) instead.')\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0moriginal_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtf_decorator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmake_decorator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moriginal_function\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_defines.py\u001b[0m in \u001b[0;36mDEFINE_string\u001b[0;34m(name, default, help, flag_values, **args)\u001b[0m\n\u001b[1;32m 239\u001b[0m \u001b[0mparser\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_argument_parser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArgumentParser\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[0mserializer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_argument_parser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mArgumentSerializer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 241\u001b[0;31m \u001b[0mDEFINE\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparser\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhelp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflag_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mserializer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 242\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_defines.py\u001b[0m in \u001b[0;36mDEFINE\u001b[0;34m(parser, name, default, help, flag_values, serializer, module_name, **args)\u001b[0m\n\u001b[1;32m 80\u001b[0m \"\"\"\n\u001b[1;32m 81\u001b[0m DEFINE_flag(_flag.Flag(parser, serializer, name, default, help, **args),\n\u001b[0;32m---> 82\u001b[0;31m flag_values, module_name)\n\u001b[0m\u001b[1;32m 83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_defines.py\u001b[0m in \u001b[0;36mDEFINE_flag\u001b[0;34m(flag, flag_values, module_name)\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;31m# Copying the reference to flag_values prevents pychecker warnings.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0mfv\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflag_values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 104\u001b[0;31m \u001b[0mfv\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mflag\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflag\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 105\u001b[0m \u001b[0;31m# Tell flag_values who's defining the flag.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmodule_name\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda3/envs/bert/lib/python3.6/site-packages/absl/flags/_flagvalues.py\u001b[0m in \u001b[0;36m__setitem__\u001b[0;34m(self, name, flag)\u001b[0m\n\u001b[1;32m 427\u001b[0m \u001b[0;31m# module is simply being imported a subsequent time.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 428\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 429\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0m_exceptions\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDuplicateFlagError\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_flag\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 430\u001b[0m \u001b[0mshort_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflag\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshort_name\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 431\u001b[0m \u001b[0;31m# If a new flag overrides an old one, we need to cleanup the old flag's\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDuplicateFlagError\u001b[0m: The flag 'input_file' is defined twice. First from *, Second from *. Description from first occurrence: (no help available)"
]
}
],
"source": [
"import importlib.util\n",
"import sys\n",
"\n",
"spec = importlib.util.spec_from_file_location('*', original_tf_inplem_dir + '/extract_features.py')\n",
"spec = importlib.util.spec_from_file_location('*', original_tf_inplem_dir + '/extract_features_tensorflow.py')\n",
"module = importlib.util.module_from_spec(spec)\n",
"spec.loader.exec_module(module)\n",
"sys.modules['extract_features_tensorflow'] = module\n",
......@@ -85,11 +105,11 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:52.325822Z",
"start_time": "2018-11-05T13:58:52.205361Z"
"end_time": "2018-11-15T14:58:05.650987Z",
"start_time": "2018-11-15T14:58:05.541620Z"
}
},
"outputs": [
......@@ -122,11 +142,11 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:55.939938Z",
"start_time": "2018-11-05T13:58:52.330202Z"
"end_time": "2018-11-15T14:58:11.562443Z",
"start_time": "2018-11-15T14:58:08.036485Z"
}
},
"outputs": [
......@@ -134,15 +154,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x12839dbf8>) includes params argument, but params are not passed to Estimator.\n",
"WARNING:tensorflow:Using temporary folder as model directory: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u\n",
"INFO:tensorflow:Using config: {'_model_dir': '/var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
"WARNING:tensorflow:Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x11ea7f1e0>) includes params argument, but params are not passed to Estimator.\n",
"WARNING:tensorflow:Using temporary folder as model directory: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9\n",
"INFO:tensorflow:Using config: {'_model_dir': '/var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
"graph_options {\n",
" rewrite_options {\n",
" meta_optimizer_iterations: ONE\n",
" }\n",
"}\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x12b3e1c18>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=1, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x121b163c8>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=1, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n",
"WARNING:tensorflow:Setting TPUConfig.num_shards==1 is an unsupported behavior. Please fix as soon as possible (leaving num_shards as None.\n",
"INFO:tensorflow:_TPUContext: eval_on_tpu True\n",
"WARNING:tensorflow:eval_on_tpu ignored because use_tpu is False.\n"
......@@ -178,11 +198,11 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:01.717585Z",
"start_time": "2018-11-05T13:58:55.941869Z"
"end_time": "2018-11-15T14:58:21.736543Z",
"start_time": "2018-11-15T14:58:16.723829Z"
}
},
"outputs": [
......@@ -190,7 +210,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Could not find trained model in model_dir: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u, running initialization to predict.\n",
"INFO:tensorflow:Could not find trained model in model_dir: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmphs4_nsq9, running initialization to predict.\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Running infer on CPU\n",
"INFO:tensorflow:Done calling model_fn.\n",
......@@ -241,11 +261,11 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:01.769845Z",
"start_time": "2018-11-05T13:59:01.719878Z"
"end_time": "2018-11-15T14:58:23.970714Z",
"start_time": "2018-11-15T14:58:23.931930Z"
}
},
"outputs": [
......@@ -266,7 +286,7 @@
"(128, 768)"
]
},
"execution_count": 7,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
......@@ -282,11 +302,11 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:01.807638Z",
"start_time": "2018-11-05T13:59:01.771422Z"
"end_time": "2018-11-15T14:58:25.547012Z",
"start_time": "2018-11-15T14:58:25.516076Z"
}
},
"outputs": [],
......@@ -303,361 +323,391 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"os.chdir('./examples')"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:02.020918Z",
"start_time": "2018-11-05T13:59:01.810061Z"
"end_time": "2018-11-15T15:03:49.528679Z",
"start_time": "2018-11-15T15:03:49.497697Z"
}
},
"outputs": [],
"source": [
"import extract_features\n",
"import pytorch_pretrained_bert as ppb\n",
"from extract_features import *"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 25,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:02.058211Z",
"start_time": "2018-11-05T13:59:02.022785Z"
"end_time": "2018-11-15T15:21:18.001177Z",
"start_time": "2018-11-15T15:21:17.970369Z"
}
},
"outputs": [],
"source": [
"init_checkpoint_pt = \"../google_models/uncased_L-12_H-768_A-12/pytorch_model.bin\""
"init_checkpoint_pt = \"../../google_models/uncased_L-12_H-768_A-12/\""
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 26,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:03.740561Z",
"start_time": "2018-11-05T13:59:02.059877Z"
"end_time": "2018-11-15T15:21:20.893669Z",
"start_time": "2018-11-15T15:21:18.786623Z"
},
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling - loading archive file ../../google_models/uncased_L-12_H-768_A-12/\n",
"11/15/2018 16:21:18 - INFO - pytorch_pretrained_bert.modeling - Model config {\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
" \"hidden_size\": 768,\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 3072,\n",
" \"max_position_embeddings\": 512,\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 12,\n",
" \"type_vocab_size\": 2,\n",
" \"vocab_size\": 30522\n",
"}\n",
"\n"
]
},
{
"data": {
"text/plain": [
"BertModel(\n",
" (embeddings): BERTEmbeddings(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(30522, 768)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (encoder): BERTEncoder(\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (1): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (2): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (3): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (4): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (5): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (6): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (7): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (8): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (9): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (10): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (11): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BERTPooler(\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
")"
]
},
"execution_count": 11,
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device(\"cpu\")\n",
"model = extract_features.BertModel(bert_config)\n",
"model.load_state_dict(torch.load(init_checkpoint_pt, map_location='cpu'))\n",
"model = ppb.BertModel.from_pretrained(init_checkpoint_pt)\n",
"model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 27,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:03.780145Z",
"start_time": "2018-11-05T13:59:03.742407Z"
"end_time": "2018-11-15T15:21:26.963427Z",
"start_time": "2018-11-15T15:21:26.922494Z"
},
"code_folding": []
},
......@@ -666,301 +716,301 @@
"data": {
"text/plain": [
"BertModel(\n",
" (embeddings): BERTEmbeddings(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(30522, 768)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (encoder): BERTEncoder(\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (1): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (2): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (3): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (4): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (5): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (6): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (7): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (8): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (9): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (10): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (11): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BERTPooler(\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
")"
]
},
"execution_count": 12,
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
......@@ -980,11 +1030,11 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 28,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.233844Z",
"start_time": "2018-11-05T13:59:03.782525Z"
"end_time": "2018-11-15T15:21:30.718724Z",
"start_time": "2018-11-15T15:21:30.329205Z"
}
},
"outputs": [
......@@ -1068,11 +1118,11 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 29,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.278496Z",
"start_time": "2018-11-05T13:59:04.235703Z"
"end_time": "2018-11-15T15:21:35.703615Z",
"start_time": "2018-11-15T15:21:35.666150Z"
}
},
"outputs": [
......@@ -1094,7 +1144,7 @@
"(128, 768)"
]
},
"execution_count": 14,
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
......@@ -1111,11 +1161,11 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 30,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.313952Z",
"start_time": "2018-11-05T13:59:04.280352Z"
"end_time": "2018-11-15T15:21:36.999073Z",
"start_time": "2018-11-15T15:21:36.966762Z"
}
},
"outputs": [
......@@ -1136,11 +1186,11 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 31,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.350048Z",
"start_time": "2018-11-05T13:59:04.316003Z"
"end_time": "2018-11-15T15:21:37.936522Z",
"start_time": "2018-11-15T15:21:37.905269Z"
}
},
"outputs": [
......@@ -1167,11 +1217,11 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 32,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.382430Z",
"start_time": "2018-11-05T13:59:04.351550Z"
"end_time": "2018-11-15T15:21:39.437137Z",
"start_time": "2018-11-15T15:21:39.406150Z"
}
},
"outputs": [],
......@@ -1181,11 +1231,11 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 33,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.428334Z",
"start_time": "2018-11-05T13:59:04.386070Z"
"end_time": "2018-11-15T15:21:40.181870Z",
"start_time": "2018-11-15T15:21:40.137023Z"
}
},
"outputs": [
......
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Comparing TensorFlow (original) and PyTorch models\n",
"\n",
"You can use this small notebook to check the conversion of the model's weights from the TensorFlow model to the PyTorch model. In the following, we compare the weights of the last layer on a simple example (in `input.txt`) but both models returns all the hidden layers so you can check every stage of the model.\n",
"\n",
"To run this notebook, follow these instructions:\n",
"- make sure that your Python environment has both TensorFlow and PyTorch installed,\n",
"- download the original TensorFlow implementation,\n",
"- download a pre-trained TensorFlow model as indicaded in the TensorFlow implementation readme,\n",
"- run the script `convert_tf_checkpoint_to_pytorch.py` as indicated in the `README` to convert the pre-trained TensorFlow model to PyTorch.\n",
"\n",
"If needed change the relative paths indicated in this notebook (at the beggining of Sections 1 and 2) to point to the relevent models and code."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:50.559657Z",
"start_time": "2018-11-05T13:58:50.546096Z"
}
},
"outputs": [],
"source": [
"import os\n",
"os.chdir('../')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1/ TensorFlow code"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:50.574455Z",
"start_time": "2018-11-05T13:58:50.561988Z"
}
},
"outputs": [],
"source": [
"original_tf_inplem_dir = \"./tensorflow_code/\"\n",
"model_dir = \"../google_models/uncased_L-12_H-768_A-12/\"\n",
"\n",
"vocab_file = model_dir + \"vocab.txt\"\n",
"bert_config_file = model_dir + \"bert_config.json\"\n",
"init_checkpoint = model_dir + \"bert_model.ckpt\"\n",
"\n",
"input_file = \"./samples/input.txt\"\n",
"max_seq_length = 128"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:52.202531Z",
"start_time": "2018-11-05T13:58:50.576198Z"
}
},
"outputs": [],
"source": [
"import importlib.util\n",
"import sys\n",
"\n",
"spec = importlib.util.spec_from_file_location('*', original_tf_inplem_dir + '/extract_features.py')\n",
"module = importlib.util.module_from_spec(spec)\n",
"spec.loader.exec_module(module)\n",
"sys.modules['extract_features_tensorflow'] = module\n",
"\n",
"from extract_features_tensorflow import *"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:52.325822Z",
"start_time": "2018-11-05T13:58:52.205361Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Example ***\n",
"INFO:tensorflow:unique_id: 0\n",
"INFO:tensorflow:tokens: [CLS] who was jim henson ? [SEP] jim henson was a puppet ##eer [SEP]\n",
"INFO:tensorflow:input_ids: 101 2040 2001 3958 27227 1029 102 3958 27227 2001 1037 13997 11510 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
"INFO:tensorflow:input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n",
"INFO:tensorflow:input_type_ids: 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n"
]
}
],
"source": [
"layer_indexes = list(range(12))\n",
"bert_config = modeling.BertConfig.from_json_file(bert_config_file)\n",
"tokenizer = tokenization.BertTokenizer(\n",
" vocab_file=vocab_file, do_lower_case=True)\n",
"examples = read_examples(input_file)\n",
"\n",
"features = convert_examples_to_features(\n",
" examples=examples, seq_length=max_seq_length, tokenizer=tokenizer)\n",
"unique_id_to_feature = {}\n",
"for feature in features:\n",
" unique_id_to_feature[feature.unique_id] = feature"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:58:55.939938Z",
"start_time": "2018-11-05T13:58:52.330202Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:Estimator's model_fn (<function model_fn_builder.<locals>.model_fn at 0x12839dbf8>) includes params argument, but params are not passed to Estimator.\n",
"WARNING:tensorflow:Using temporary folder as model directory: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u\n",
"INFO:tensorflow:Using config: {'_model_dir': '/var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true\n",
"graph_options {\n",
" rewrite_options {\n",
" meta_optimizer_iterations: ONE\n",
" }\n",
"}\n",
", '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x12b3e1c18>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=1, num_cores_per_replica=None, per_host_input_for_training=3, tpu_job_name=None, initial_infeed_sleep_secs=None, input_partition_dims=None), '_cluster': None}\n",
"WARNING:tensorflow:Setting TPUConfig.num_shards==1 is an unsupported behavior. Please fix as soon as possible (leaving num_shards as None.\n",
"INFO:tensorflow:_TPUContext: eval_on_tpu True\n",
"WARNING:tensorflow:eval_on_tpu ignored because use_tpu is False.\n"
]
}
],
"source": [
"is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2\n",
"run_config = tf.contrib.tpu.RunConfig(\n",
" master=None,\n",
" tpu_config=tf.contrib.tpu.TPUConfig(\n",
" num_shards=1,\n",
" per_host_input_for_training=is_per_host))\n",
"\n",
"model_fn = model_fn_builder(\n",
" bert_config=bert_config,\n",
" init_checkpoint=init_checkpoint,\n",
" layer_indexes=layer_indexes,\n",
" use_tpu=False,\n",
" use_one_hot_embeddings=False)\n",
"\n",
"# If TPU is not available, this will fall back to normal Estimator on CPU\n",
"# or GPU.\n",
"estimator = tf.contrib.tpu.TPUEstimator(\n",
" use_tpu=False,\n",
" model_fn=model_fn,\n",
" config=run_config,\n",
" predict_batch_size=1)\n",
"\n",
"input_fn = input_fn_builder(\n",
" features=features, seq_length=max_seq_length)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:01.717585Z",
"start_time": "2018-11-05T13:58:55.941869Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"INFO:tensorflow:Could not find trained model in model_dir: /var/folders/yx/cw8n_njx3js5jksyw_qlp8p00000gn/T/tmpdbx_h23u, running initialization to predict.\n",
"INFO:tensorflow:Calling model_fn.\n",
"INFO:tensorflow:Running infer on CPU\n",
"INFO:tensorflow:Done calling model_fn.\n",
"INFO:tensorflow:Graph was finalized.\n",
"INFO:tensorflow:Running local_init_op.\n",
"INFO:tensorflow:Done running local_init_op.\n",
"extracting layer 0\n",
"extracting layer 1\n",
"extracting layer 2\n",
"extracting layer 3\n",
"extracting layer 4\n",
"extracting layer 5\n",
"extracting layer 6\n",
"extracting layer 7\n",
"extracting layer 8\n",
"extracting layer 9\n",
"extracting layer 10\n",
"extracting layer 11\n",
"INFO:tensorflow:prediction_loop marked as finished\n",
"INFO:tensorflow:prediction_loop marked as finished\n"
]
}
],
"source": [
"tensorflow_all_out = []\n",
"for result in estimator.predict(input_fn, yield_single_examples=True):\n",
" unique_id = int(result[\"unique_id\"])\n",
" feature = unique_id_to_feature[unique_id]\n",
" output_json = collections.OrderedDict()\n",
" output_json[\"linex_index\"] = unique_id\n",
" tensorflow_all_out_features = []\n",
" # for (i, token) in enumerate(feature.tokens):\n",
" all_layers = []\n",
" for (j, layer_index) in enumerate(layer_indexes):\n",
" print(\"extracting layer {}\".format(j))\n",
" layer_output = result[\"layer_output_%d\" % j]\n",
" layers = collections.OrderedDict()\n",
" layers[\"index\"] = layer_index\n",
" layers[\"values\"] = layer_output\n",
" all_layers.append(layers)\n",
" tensorflow_out_features = collections.OrderedDict()\n",
" tensorflow_out_features[\"layers\"] = all_layers\n",
" tensorflow_all_out_features.append(tensorflow_out_features)\n",
"\n",
" output_json[\"features\"] = tensorflow_all_out_features\n",
" tensorflow_all_out.append(output_json)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:01.769845Z",
"start_time": "2018-11-05T13:59:01.719878Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"2\n",
"odict_keys(['linex_index', 'features'])\n",
"number of tokens 1\n",
"number of layers 12\n"
]
},
{
"data": {
"text/plain": [
"(128, 768)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(len(tensorflow_all_out))\n",
"print(len(tensorflow_all_out[0]))\n",
"print(tensorflow_all_out[0].keys())\n",
"print(\"number of tokens\", len(tensorflow_all_out[0]['features']))\n",
"print(\"number of layers\", len(tensorflow_all_out[0]['features'][0]['layers']))\n",
"tensorflow_all_out[0]['features'][0]['layers'][0]['values'].shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:01.807638Z",
"start_time": "2018-11-05T13:59:01.771422Z"
}
},
"outputs": [],
"source": [
"tensorflow_outputs = list(tensorflow_all_out[0]['features'][0]['layers'][t]['values'] for t in layer_indexes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2/ PyTorch code"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:02.020918Z",
"start_time": "2018-11-05T13:59:01.810061Z"
}
},
"outputs": [],
"source": [
"import extract_features\n",
"from extract_features import *"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:02.058211Z",
"start_time": "2018-11-05T13:59:02.022785Z"
}
},
"outputs": [],
"source": [
"init_checkpoint_pt = \"../google_models/uncased_L-12_H-768_A-12/pytorch_model.bin\""
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:03.740561Z",
"start_time": "2018-11-05T13:59:02.059877Z"
},
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"BertModel(\n",
" (embeddings): BERTEmbeddings(\n",
" (word_embeddings): Embedding(30522, 768)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (encoder): BERTEncoder(\n",
" (layer): ModuleList(\n",
" (0): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (1): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (2): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (3): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (4): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (5): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (6): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (7): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (8): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (9): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (10): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (11): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BERTPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
")"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device = torch.device(\"cpu\")\n",
"model = extract_features.BertModel(bert_config)\n",
"model.load_state_dict(torch.load(init_checkpoint_pt, map_location='cpu'))\n",
"model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:03.780145Z",
"start_time": "2018-11-05T13:59:03.742407Z"
},
"code_folding": []
},
"outputs": [
{
"data": {
"text/plain": [
"BertModel(\n",
" (embeddings): BERTEmbeddings(\n",
" (word_embeddings): Embedding(30522, 768)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (encoder): BERTEncoder(\n",
" (layer): ModuleList(\n",
" (0): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (1): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (2): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (3): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (4): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (5): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (6): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (7): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (8): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (9): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (10): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (11): BERTLayer(\n",
" (attention): BERTAttention(\n",
" (self): BERTSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BERTSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BERTIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BERTOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BERTLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BERTPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
")"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)\n",
"all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)\n",
"all_input_type_ids = torch.tensor([f.input_type_ids for f in features], dtype=torch.long)\n",
"all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)\n",
"\n",
"eval_data = TensorDataset(all_input_ids, all_input_mask, all_input_type_ids, all_example_index)\n",
"eval_sampler = SequentialSampler(eval_data)\n",
"eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)\n",
"\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.233844Z",
"start_time": "2018-11-05T13:59:03.782525Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[ 101, 2040, 2001, 3958, 27227, 1029, 102, 3958, 27227, 2001,\n",
" 1037, 13997, 11510, 102, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0]])\n",
"tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0]])\n",
"tensor([0])\n",
"layer 0 0\n",
"layer 1 1\n",
"layer 2 2\n",
"layer 3 3\n",
"layer 4 4\n",
"layer 5 5\n",
"layer 6 6\n",
"layer 7 7\n",
"layer 8 8\n",
"layer 9 9\n",
"layer 10 10\n",
"layer 11 11\n"
]
}
],
"source": [
"layer_indexes = list(range(12))\n",
"\n",
"pytorch_all_out = []\n",
"for input_ids, input_mask, input_type_ids, example_indices in eval_dataloader:\n",
" print(input_ids)\n",
" print(input_mask)\n",
" print(example_indices)\n",
" input_ids = input_ids.to(device)\n",
" input_mask = input_mask.to(device)\n",
"\n",
" all_encoder_layers, _ = model(input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)\n",
"\n",
" for b, example_index in enumerate(example_indices):\n",
" feature = features[example_index.item()]\n",
" unique_id = int(feature.unique_id)\n",
" # feature = unique_id_to_feature[unique_id]\n",
" output_json = collections.OrderedDict()\n",
" output_json[\"linex_index\"] = unique_id\n",
" all_out_features = []\n",
" # for (i, token) in enumerate(feature.tokens):\n",
" all_layers = []\n",
" for (j, layer_index) in enumerate(layer_indexes):\n",
" print(\"layer\", j, layer_index)\n",
" layer_output = all_encoder_layers[int(layer_index)].detach().cpu().numpy()\n",
" layer_output = layer_output[b]\n",
" layers = collections.OrderedDict()\n",
" layers[\"index\"] = layer_index\n",
" layer_output = layer_output\n",
" layers[\"values\"] = layer_output if not isinstance(layer_output, (int, float)) else [layer_output]\n",
" all_layers.append(layers)\n",
"\n",
" out_features = collections.OrderedDict()\n",
" out_features[\"layers\"] = all_layers\n",
" all_out_features.append(out_features)\n",
" output_json[\"features\"] = all_out_features\n",
" pytorch_all_out.append(output_json)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.278496Z",
"start_time": "2018-11-05T13:59:04.235703Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1\n",
"2\n",
"odict_keys(['linex_index', 'features'])\n",
"number of tokens 1\n",
"number of layers 12\n",
"hidden_size 128\n"
]
},
{
"data": {
"text/plain": [
"(128, 768)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(len(pytorch_all_out))\n",
"print(len(pytorch_all_out[0]))\n",
"print(pytorch_all_out[0].keys())\n",
"print(\"number of tokens\", len(pytorch_all_out))\n",
"print(\"number of layers\", len(pytorch_all_out[0]['features'][0]['layers']))\n",
"print(\"hidden_size\", len(pytorch_all_out[0]['features'][0]['layers'][0]['values']))\n",
"pytorch_all_out[0]['features'][0]['layers'][0]['values'].shape"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.313952Z",
"start_time": "2018-11-05T13:59:04.280352Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(128, 768)\n",
"(128, 768)\n"
]
}
],
"source": [
"pytorch_outputs = list(pytorch_all_out[0]['features'][0]['layers'][t]['values'] for t in layer_indexes)\n",
"print(pytorch_outputs[0].shape)\n",
"print(pytorch_outputs[1].shape)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.350048Z",
"start_time": "2018-11-05T13:59:04.316003Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(128, 768)\n",
"(128, 768)\n"
]
}
],
"source": [
"print(tensorflow_outputs[0].shape)\n",
"print(tensorflow_outputs[1].shape)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3/ Comparing the standard deviation on the last layer of both models"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.382430Z",
"start_time": "2018-11-05T13:59:04.351550Z"
}
},
"outputs": [],
"source": [
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"ExecuteTime": {
"end_time": "2018-11-05T13:59:04.428334Z",
"start_time": "2018-11-05T13:59:04.386070Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"shape tensorflow layer, shape pytorch layer, standard deviation\n",
"((128, 768), (128, 768), 1.5258875e-07)\n",
"((128, 768), (128, 768), 2.342731e-07)\n",
"((128, 768), (128, 768), 2.801949e-07)\n",
"((128, 768), (128, 768), 3.5904986e-07)\n",
"((128, 768), (128, 768), 4.2842768e-07)\n",
"((128, 768), (128, 768), 5.127951e-07)\n",
"((128, 768), (128, 768), 6.14668e-07)\n",
"((128, 768), (128, 768), 7.063922e-07)\n",
"((128, 768), (128, 768), 7.906173e-07)\n",
"((128, 768), (128, 768), 8.475192e-07)\n",
"((128, 768), (128, 768), 8.975489e-07)\n",
"((128, 768), (128, 768), 4.1671223e-07)\n"
]
}
],
"source": [
"print('shape tensorflow layer, shape pytorch layer, standard deviation')\n",
"print('\\n'.join(list(str((np.array(tensorflow_outputs[i]).shape,\n",
" np.array(pytorch_outputs[i]).shape, \n",
" np.sqrt(np.mean((np.array(tensorflow_outputs[i]) - np.array(pytorch_outputs[i]))**2.0)))) for i in range(12))))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
},
"toc": {
"colors": {
"hover_highlight": "#DAA520",
"running_highlight": "#FF0000",
"selected_highlight": "#FFD700"
},
"moveMenuLeft": true,
"nav_menu": {
"height": "48px",
"width": "252px"
},
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 4,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer
from .modeling import (BertConfig, BertModel, BertForPreTraining,
BertForMaskedLM, BertForNextSentencePrediction,
BertForSequenceClassification, BertForQuestionAnswering)
from .optimization import BERTAdam
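# Note: with the imports above, the package's public API can be pulled in
# directly from the top level, e.g. (illustrative):
#
#   from pytorch_pretrained_bert import BertModel, BertTokenizer, BERTAdam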
# coding: utf8
if __name__ == '__main__':
import sys
try:
from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
except ModuleNotFoundError:
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
"In that case, it requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions.")
raise
if len(sys.argv) != 5:
# pylint: disable=line-too-long
print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
else:
PYTORCH_DUMP_OUTPUT = sys.argv.pop()
TF_CONFIG = sys.argv.pop()
TF_CHECKPOINT = sys.argv.pop()
convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
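# Illustrative invocation of this entry point, matching the usage string above
# (the three paths are placeholders for your own files):
#
#   pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch \
#       /path/to/bert_model.ckpt /path/to/bert_config.json /path/to/pytorch_model.bin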
......@@ -18,66 +18,39 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import re
import argparse
import tensorflow as tf
import torch
import numpy as np
from .modeling import BertConfig, BertForPreTraining
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
config_path = os.path.abspath(bert_config_file)
tf_path = os.path.abspath(tf_checkpoint_path)
print("Converting TensorFlow checkpoint from {} with config at {}".format(tf_path, config_path))
# Load weights from TF model
init_vars = tf.train.list_variables(tf_path)
names = []
arrays = []
for name, shape in init_vars:
print("Loading {} with shape {}".format(name, shape))
array = tf.train.load_variable(path, name)
print("Numpy array shape {}".format(array.shape))
print("Loading TF weight {} with shape {}".format(name, shape))
array = tf.train.load_variable(tf_path, name)
names.append(name)
arrays.append(array)
# Initialise PyTorch model
config = BertConfig.from_json_file(bert_config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = BertForPreTraining(config)
for name, array in zip(names, arrays):
if not name.startswith("bert"):
print("Skipping {}".format(name))
continue
else:
name = name.replace("bert/", "") # skip "bert/"
print("Loading {}".format(name))
name = name.split('/')
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
# which are not required for using pretrained model
if name[-1] in ["adam_v", "adam_m"]:
print("Skipping {}".format("/".join(name)))
continue
pointer = model
......@@ -88,6 +61,10 @@ def convert():
l = [m_name]
if l[0] == 'kernel':
pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias':
pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight')
else:
pointer = getattr(pointer, l[0])
if len(l) >= 2:
......@@ -102,10 +79,34 @@ def convert():
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
print("Initialize PyTorch weight {}".format(name))
pointer.data = torch.from_numpy(array)
# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--tf_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the TensorFlow checkpoint path.")
parser.add_argument("--bert_config_file",
default = None,
type = str,
required = True,
help = "The config json file corresponding to the pre-trained BERT model. \n"
"This specifies the model architecture.")
parser.add_argument("--pytorch_dump_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args()
convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
args.bert_config_file,
args.pytorch_dump_path)
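# The conversion can also be run programmatically; a minimal sketch, assuming
# the package is installed and the three paths point to your own files:
#
#   from pytorch_pretrained_bert.convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
#   convert_tf_checkpoint_to_pytorch("/path/to/bert_model.ckpt",
#                                    "/path/to/bert_config.json",
#                                    "/path/to/pytorch_model.bin")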
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
import os
import logging
import shutil
import tempfile
import json
from urllib.parse import urlparse
from pathlib import Path
from typing import Optional, Tuple, Union, IO, Callable, Set
from hashlib import sha256
from functools import wraps
from tqdm import tqdm
import boto3
from botocore.exceptions import ClientError
import requests
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
Path.home() / '.pytorch_pretrained_bert'))
def url_to_filename(url: str, etag: str = None) -> str:
"""
Convert `url` into a hashed filename in a repeatable way.
If `etag` is specified, append its hash to the url's, delimited
by a period.
"""
url_bytes = url.encode('utf-8')
url_hash = sha256(url_bytes)
filename = url_hash.hexdigest()
if etag:
etag_bytes = etag.encode('utf-8')
etag_hash = sha256(etag_bytes)
filename += '.' + etag_hash.hexdigest()
return filename
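# Illustrative behavior (digests abbreviated): the mapping is deterministic,
# so repeated requests for the same URL reuse the same cache entry:
#
#   url_to_filename("https://example.com/vocab.txt")              -> "<sha256(url)>"
#   url_to_filename("https://example.com/vocab.txt", etag="x-1")  -> "<sha256(url)>.<sha256(etag)>"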
def filename_to_url(filename: str, cache_dir: str = None) -> Tuple[str, str]:
"""
Return the url and etag (which may be ``None``) stored for `filename`.
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise FileNotFoundError("file {} not found".format(cache_path))
meta_path = cache_path + '.json'
if not os.path.exists(meta_path):
raise FileNotFoundError("file {} not found".format(meta_path))
with open(meta_path) as meta_file:
metadata = json.load(meta_file)
url = metadata['url']
etag = metadata['etag']
return url, etag
def cached_path(url_or_filename: Union[str, Path], cache_dir: str = None) -> str:
"""
Given something that might be a URL (or might be a local path),
determine which. If it's a URL, download the file and cache it, and
return the path to the cached file. If it's already a local path,
make sure the file exists and then return the path.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
parsed = urlparse(url_or_filename)
if parsed.scheme in ('http', 'https', 's3'):
# URL, so get it from the cache (downloading if necessary)
return get_from_cache(url_or_filename, cache_dir)
elif os.path.exists(url_or_filename):
# File, and it exists.
return url_or_filename
elif parsed.scheme == '':
# File, but it doesn't exist.
raise FileNotFoundError("file {} not found".format(url_or_filename))
else:
# Something unknown
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
def split_s3_path(url: str) -> Tuple[str, str]:
"""Split a full s3 path into the bucket name and path."""
parsed = urlparse(url)
if not parsed.netloc or not parsed.path:
raise ValueError("bad s3 path {}".format(url))
bucket_name = parsed.netloc
s3_path = parsed.path
# Remove '/' at beginning of path.
if s3_path.startswith("/"):
s3_path = s3_path[1:]
return bucket_name, s3_path
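# A quick sketch with made-up names:
#
#   split_s3_path("s3://my-bucket/models/weights.bin")
#   # -> ("my-bucket", "models/weights.bin")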
def s3_request(func: Callable):
"""
Wrapper function for s3 requests in order to create more helpful error
messages.
"""
@wraps(func)
def wrapper(url: str, *args, **kwargs):
try:
return func(url, *args, **kwargs)
except ClientError as exc:
if int(exc.response["Error"]["Code"]) == 404:
raise FileNotFoundError("file {} not found".format(url))
else:
raise
return wrapper
@s3_request
def s3_etag(url: str) -> Optional[str]:
"""Check ETag on S3 object."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_object = s3_resource.Object(bucket_name, s3_path)
return s3_object.e_tag
@s3_request
def s3_get(url: str, temp_file: IO) -> None:
"""Pull a file directly from S3."""
s3_resource = boto3.resource("s3")
bucket_name, s3_path = split_s3_path(url)
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
def http_get(url: str, temp_file: IO) -> None:
req = requests.get(url, stream=True)
content_length = req.headers.get('Content-Length')
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total)
for chunk in req.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(url: str, cache_dir: str = None) -> str:
"""
Given a URL, look for the corresponding dataset in the local cache.
If it's not there, download it. Then return the path to the cached file.
"""
if cache_dir is None:
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
os.makedirs(cache_dir, exist_ok=True)
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url)
else:
response = requests.head(url, allow_redirects=True)
if response.status_code != 200:
raise IOError("HEAD request failed for url {} with status code {}"
.format(url, response.status_code))
etag = response.headers.get("ETag")
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with tempfile.NamedTemporaryFile() as temp_file:
logger.info("%s not found in cache, downloading to %s", url, temp_file.name)
# GET file object
if url.startswith("s3://"):
s3_get(url, temp_file)
else:
http_get(url, temp_file)
# we are copying the file before closing it, so flush to avoid truncation
temp_file.flush()
# shutil.copyfileobj() starts at the current position, so go to the start
temp_file.seek(0)
logger.info("copying %s to cache at %s", temp_file.name, cache_path)
with open(cache_path, 'wb') as cache_file:
shutil.copyfileobj(temp_file, cache_file)
logger.info("creating metadata file for %s", cache_path)
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w') as meta_file:
json.dump(meta, meta_file)
logger.info("removing temp file %s", temp_file.name)
return cache_path
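# Sketch of the resulting cache layout (names illustrative; see url_to_filename):
#
#   ~/.pytorch_pretrained_bert/<url_hash>.<etag_hash>       <- the cached file itself
#   ~/.pytorch_pretrained_bert/<url_hash>.<etag_hash>.json  <- {"url": ..., "etag": ...}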
def read_set_from_file(filename: str) -> Set[str]:
'''
Extract a de-duped collection (set) of text from a file.
Expected file format is one item per line.
'''
collection = set()
with open(filename, 'r') as file_:
for line in file_:
collection.add(line.rstrip())
return collection
def get_file_extension(path: str, dot=True, lower: bool = True):
ext = os.path.splitext(path)[1]
ext = ext if dot else ext[1:]
return ext.lower() if lower else ext
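# e.g., a sketch:
#   get_file_extension("weights.BIN")                         # -> '.bin'
#   get_file_extension("weights.BIN", dot=False, lower=False) # -> 'BIN'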
......@@ -18,14 +18,35 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import copy
import json
import math
import six
import logging
import tarfile
import tempfile
import shutil
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from six import string_types
from .file_utils import cached_path
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-base-multilingual': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
}
CONFIG_NAME = 'bert_config.json'
WEIGHTS_NAME = 'pytorch_model.bin'
def gelu(x):
"""Implementation of the gelu activation function.
......@@ -46,7 +67,7 @@ class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
def __init__(self,
vocab_size_or_config_json_file,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
......@@ -55,12 +76,12 @@ class BertConfig(object):
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02):
"""Constructs BertConfig.
Args:
vocab_size_or_config_json_file: Vocabulary size of `input_ids` in `BertModel`, or
the path to a json file with the model configuration.
hidden_size: Size of the encoder layers and the pooler layer.
num_hidden_layers: Number of hidden layers in the Transformer encoder.
num_attention_heads: Number of attention heads for each attention layer in
......@@ -81,7 +102,13 @@ class BertConfig(object):
initializer_range: The stddev of the truncated_normal_initializer for
initializing all weight matrices.
"""
if isinstance(vocab_size_or_config_json_file, str):
with open(vocab_size_or_config_json_file, "r") as reader:
json_config = json.loads(reader.read())
for key, value in json_config.items():
self.__dict__[key] = value
elif isinstance(vocab_size_or_config_json_file, int):
self.vocab_size = vocab_size_or_config_json_file
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
......@@ -92,12 +119,15 @@ class BertConfig(object):
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
else:
raise ValueError("First argument must be either a vocabulary size (int)"
"or the path to a pretrained model config file (str)")
@classmethod
def from_dict(cls, json_object):
"""Constructs a `BertConfig` from a Python dictionary of parameters."""
config = BertConfig(vocab_size_or_config_json_file=-1)
for key, value in json_object.items():
config.__dict__[key] = value
return config
......@@ -108,6 +138,9 @@ class BertConfig(object):
text = reader.read()
return cls.from_dict(json.loads(text))
def __repr__(self):
return str(self.to_json_string())
def to_dict(self):
"""Serializes this instance to a Python dictionary."""
output = copy.deepcopy(self.__dict__)
......@@ -118,11 +151,11 @@ class BertConfig(object):
return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
class BertLayerNorm(nn.Module):
def __init__(self, config, variance_epsilon=1e-12):
"""Construct a layernorm module in the TF style (epsilon inside the square root).
"""
super(BertLayerNorm, self).__init__()
self.gamma = nn.Parameter(torch.ones(config.hidden_size))
self.beta = nn.Parameter(torch.zeros(config.hidden_size))
self.variance_epsilon = variance_epsilon
......@@ -133,18 +166,19 @@ class BERTLayerNorm(nn.Module):
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.gamma * x + self.beta
class BertEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.
"""
def __init__(self, config):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, input_ids, token_type_ids=None):
......@@ -164,9 +198,9 @@ class BERTEmbeddings(nn.Module):
return embeddings
class BertSelfAttention(nn.Module):
def __init__(self, config):
super(BertSelfAttention, self).__init__()
if config.hidden_size % config.num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
......@@ -215,11 +249,11 @@ class BERTSelfAttention(nn.Module):
return context_layer
class BertSelfOutput(nn.Module):
def __init__(self, config):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
......@@ -229,11 +263,11 @@ class BERTSelfOutput(nn.Module):
return hidden_states
class BertAttention(nn.Module):
def __init__(self, config):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
......@@ -241,12 +275,12 @@ class BERTAttention(nn.Module):
return attention_output
class BertIntermediate(nn.Module):
def __init__(self, config):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
......@@ -254,11 +288,11 @@ class BERTIntermediate(nn.Module):
return hidden_states
class BertOutput(nn.Module):
def __init__(self, config):
super(BertOutput, self).__init__()
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = BertLayerNorm(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
def forward(self, hidden_states, input_tensor):
......@@ -268,12 +302,12 @@ class BERTOutput(nn.Module):
return hidden_states
class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
......@@ -282,23 +316,26 @@ class BERTLayer(nn.Module):
return layer_output
class BertEncoder(nn.Module):
def __init__(self, config):
super(BertEncoder, self).__init__()
layer = BertLayer(config)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
return all_encoder_layers
class BertPooler(nn.Module):
def __init__(self, config):
super(BertPooler, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()
......@@ -311,9 +348,220 @@ class BERTPooler(nn.Module):
return pooled_output
class BertPredictionHeadTransform(nn.Module):
def __init__(self, config):
super(BertPredictionHeadTransform, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.transform_act_fn = ACT2FN[config.hidden_act] \
if isinstance(config.hidden_act, str) else config.hidden_act
self.LayerNorm = BertLayerNorm(config)
def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states
class BertLMPredictionHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertLMPredictionHead, self).__init__()
self.transform = BertPredictionHeadTransform(config)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
bert_model_embedding_weights.size(0),
bias=False)
self.decoder.weight = bert_model_embedding_weights
self.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))
def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states) + self.bias
return hidden_states
class BertOnlyMLMHead(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertOnlyMLMHead, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
def forward(self, sequence_output):
prediction_scores = self.predictions(sequence_output)
return prediction_scores
class BertOnlyNSPHead(nn.Module):
def __init__(self, config):
super(BertOnlyNSPHead, self).__init__()
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, pooled_output):
seq_relationship_score = self.seq_relationship(pooled_output)
return seq_relationship_score
class BertPreTrainingHeads(nn.Module):
def __init__(self, config, bert_model_embedding_weights):
super(BertPreTrainingHeads, self).__init__()
self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
self.seq_relationship = nn.Linear(config.hidden_size, 2)
def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
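# Shape sketch (B = batch size, S = sequence length, V = vocab size):
#   prediction_scores:      [B, S, V]  logits over the vocabulary for each position
#   seq_relationship_score: [B, 2]     is-next vs. random-next logits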
class PreTrainedBertModel(nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedBertModel, self).__init__()
if not isinstance(config, BertConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
self.__class__.__name__, self.__class__.__name__
))
self.config = config
def init_bert_weights(self, module):
""" Initialize the weights.
"""
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.beta.data.normal_(mean=0.0, std=self.config.initializer_range)
module.gamma.data.normal_(mean=0.0, std=self.config.initializer_range)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
@classmethod
def from_pretrained(cls, pretrained_model_name, *inputs, **kwargs):
"""
Instantiate a PreTrainedBertModel from a pre-trained model file.
Download and cache the pre-trained model file if needed.
Params:
pretrained_model_name: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-base-multilingual`
. `bert-base-chinese`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
if pretrained_model_name in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name]
else:
archive_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_archive_file = cached_path(archive_file)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
', '.join(PRETRAINED_MODEL_ARCHIVE_MAP.keys()),
pretrained_model_name))
return None
if resolved_archive_file == archive_file:
logger.info("loading archive file {}".format(archive_file))
else:
logger.info("loading archive file {} from cache at {}".format(
archive_file, resolved_archive_file))
tempdir = None
if os.path.isdir(resolved_archive_file):
serialization_dir = resolved_archive_file
else:
# Extract archive to temp dir
tempdir = tempfile.mkdtemp()
logger.info("extracting archive file {} to temp dir {}".format(
resolved_archive_file, tempdir))
with tarfile.open(resolved_archive_file, 'r:gz') as archive:
archive.extractall(tempdir)
serialization_dir = tempdir
# Load config
config_file = os.path.join(serialization_dir, CONFIG_NAME)
config = BertConfig.from_json_file(config_file)
logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(config, *inputs, **kwargs)
weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
state_dict = torch.load(weights_path)
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if tempdir:
# Clean up temp dir
shutil.rmtree(tempdir)
return model
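# Example usage, a sketch (any name from PRETRAINED_MODEL_ARCHIVE_MAP works):
#
#   model = BertModel.from_pretrained('bert-base-uncased')
#   model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)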
class BertModel(PreTrainedBertModel):
"""BERT model ("Bidirectional Embedding Representations from a Transformer").
Params:
config: a BertConfig class instance with the configuration to build a new model
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
Outputs: Tuple of (encoded_layers, pooled_output)
`encoded_layers`: controlled by the `output_all_encoded_layers` argument:
- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
to the last attention block,
`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
classifier pretrained on top of the hidden state associated to the first token of the
input (`[CLS]`) to train on the Next-Sentence task (see BERT's paper).
Example usage:
```python
# Already been converted into WordPiece token ids
......@@ -328,18 +576,14 @@ class BertModel(nn.Module):
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
......@@ -361,16 +605,239 @@ class BertModel(nn.Module):
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output
class BertForPreTraining(PreTrainedBertModel):
"""BERT model with pre-training heads.
This module comprises the BERT model followed by the two pre-training heads:
- the masked language modeling head, and
- the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked); the loss
is only computed for the labels set in [0, ..., vocab_size]
`next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
Outputs:
if `masked_lm_labels` and `next_sentence_label` are not `None`:
Outputs the total_loss which is the sum of the masked language modeling loss and the next
sentence classification loss.
if `masked_lm_labels` or `next_sentence_label` is `None`:
Outputs a tuple comprising
- the masked language modeling logits, and
- the next sentence classification logits.
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
model = BertForPreTraining(config)
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForPreTraining, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False)
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
# flatten so that CrossEntropyLoss sees [N, C] logits against [N] targets
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
else:
return prediction_scores, seq_relationship_score
class BertForMaskedLM(PreTrainedBertModel):
"""BERT model with the masked language modeling head.
This module comprises the BERT model followed by the masked language modeling head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked); the loss
is only computed for the labels set in [0, ..., vocab_size]
Outputs:
if `masked_lm_labels` is not `None`:
Outputs the masked language modeling loss.
if `masked_lm_labels` is `None`:
Outputs the masked language modeling logits.
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
model = BertForMaskedLM(config)
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForMaskedLM, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False)
prediction_scores = self.cls(sequence_output)
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
# flatten so that CrossEntropyLoss sees [N, C] logits against [N] targets
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
return masked_lm_loss
else:
return prediction_scores
class BertForNextSentencePrediction(PreTrainedBertModel):
"""BERT model with next sentence prediction head.
This module comprises the BERT model followed by the next sentence classification head.
Params:
config: a BertConfig class instance with the configuration to build a new model.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`next_sentence_label`: next sentence classification loss: torch.LongTensor of shape [batch_size]
with indices selected in [0, 1].
0 => next sentence is the continuation, 1 => next sentence is a random sentence.
Outputs:
if `next_sentence_label` is not `None`:
Outputs the next sentence classification loss.
if `next_sentence_label` is `None`:
Outputs the next sentence classification logits.
Example usage:
```python
# Already been converted into WordPiece token ids
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=512,
num_hidden_layers=8, num_attention_heads=6, intermediate_size=1024)
model = BertForNextSentencePrediction(config)
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
super(BertForNextSentencePrediction, self).__init__(config)
self.bert = BertModel(config)
self.cls = BertOnlyNSPHead(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False)
seq_relationship_score = self.cls(pooled_output)
if next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
next_sentence_loss = loss_fct(seq_relationship_score, next_sentence_label)
return next_sentence_loss
else:
return seq_relationship_score
class BertForSequenceClassification(PreTrainedBertModel):
"""BERT model for classification.
This module is composed of the BERT model with a linear layer on top of
the pooled output.
Params:
`config`: a BertConfig class instance with the configuration to build a new model.
`num_labels`: the number of classes for the classifier. Default = 2.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`labels`: labels for the classification output: torch.LongTensor of shape [batch_size]
with indices selected in [0, ..., num_labels - 1].
Outputs:
if `labels` is not `None`:
Outputs the CrossEntropy classification loss of the output with the labels.
if `labels` is `None`:
Outputs the classification logits.
Example usage:
```python
# Already been converted into WordPiece token ids
......@@ -387,26 +854,15 @@ class BertForSequenceClassification(nn.Module):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels=2):
super(BertForSequenceClassification, self).__init__(config)
self.bert = BertModel(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
......@@ -417,11 +873,48 @@ class BertForSequenceClassification(nn.Module):
else:
return logits
class BertForQuestionAnswering(PreTrainedBertModel):
"""BERT model for Question Answering (span extraction).
This module is composed of the BERT model with a linear layer on top of
the sequence output that computes start_logits and end_logits.
Params:
`config`: either
- a BertConfig class instance with the configuration to build a new model, or
- a str with the name of a pre-trained model to load selected in the list of:
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-base-multilingual`
. `bert-base-chinese`
The pre-trained model will be downloaded and cached if needed.
Inputs:
`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
with the word token indices in the vocabulary (see the token preprocessing logic in the scripts
`extract_features.py`, `run_classifier.py` and `run_squad.py`)
`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
a `sentence B` token (see BERT paper for more details).
`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
input sequence length in the current batch. It's the mask that we typically use for attention when
a batch has varying length sentences.
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and positions outside of the sequence are not taken
into account for computing the loss.
`end_positions`: position of the last token for the labeled span: torch.LongTensor of shape [batch_size].
Positions are clamped to the length of the sequence and positions outside of the sequence are not taken
into account for computing the loss.
Outputs:
if `start_positions` and `end_positions` are not `None`:
Outputs the total_loss which is the sum of the CrossEntropy loss for the start and end token positions.
if `start_positions` or `end_positions` is `None`:
Outputs a tuple of start_logits, end_logits which are the logits respectively for the start and end
position tokens.
Example usage:
```python
# Already been converted into WordPiece token ids
......@@ -437,27 +930,15 @@ class BertForQuestionAnswering(nn.Module):
```
"""
def __init__(self, config):
super(BertForQuestionAnswering, self).__init__(config)
self.bert = BertModel(config)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
......
......@@ -20,27 +20,32 @@ from __future__ import print_function
import collections
import unicodedata
import six
import os
import logging
from .file_utils import cached_path
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
level = logging.INFO)
logger = logging.getLogger(__name__)
PRETRAINED_VOCAB_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt",
'bert-base-multilingual': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-vocab.txt",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt",
}
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
......@@ -48,22 +53,12 @@ def printable_text(text):
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
......@@ -81,14 +76,6 @@ def load_vocab(vocab_file):
return vocab
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a peice of text."""
text = text.strip()
......@@ -98,11 +85,16 @@ def whitespace_tokenize(text):
return tokens
class BertTokenizer(object):
"""Runs end-to-end tokenization: punctuation splitting + wordpiece"""
def __init__(self, vocab_file, do_lower_case=True):
if not os.path.isfile(vocab_file):
raise ValueError(
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
self.vocab = load_vocab(vocab_file)
self.ids_to_tokens = collections.OrderedDict(
[(ids, tok) for tok, ids in self.vocab.items()])
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
......@@ -111,11 +103,52 @@ class FullTokenizer(object):
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
"""Converts a sequence of tokens into ids using the vocab."""
ids = []
for token in tokens:
ids.append(self.vocab[token])
return ids
def convert_ids_to_tokens(self, ids):
"""Converts a sequence of ids in wordpiece tokens using the vocab."""
tokens = []
for i in ids:
tokens.append(self.ids_to_tokens[i])
return tokens
@classmethod
def from_pretrained(cls, pretrained_model_name, do_lower_case=True):
"""
Instantiate a BertTokenizer from a pre-trained vocabulary file.
Download and cache the vocabulary file if needed.
"""
if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
else:
vocab_file = pretrained_model_name
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file)
if resolved_vocab_file == vocab_file:
logger.info("loading vocabulary file {}".format(vocab_file))
else:
logger.info("loading vocabulary file {} from cache at {}".format(
vocab_file, resolved_vocab_file))
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, do_lower_case)
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
pretrained_model_name))
tokenizer = None
return tokenizer
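# Example usage, a sketch:
#
#   tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#   tokens = tokenizer.tokenize(u"Hello, how are you?")
#   ids = tokenizer.convert_tokens_to_ids(tokens)
#   assert tokenizer.convert_ids_to_tokens(ids) == tokens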
class BasicTokenizer(object):
......
# This installs PyTorch for CUDA 8 only. If you are using a newer version,
# please visit http://pytorch.org/ and install the relevant version.
torch>=0.4.1,<0.5.0
# progress bars in model download and training scripts
tqdm>=4.19
# Accessing files from S3 directly.
boto3
# Used for downloading models over HTTP
requests>=2.18
from setuptools import find_packages, setup
setup(
name="pytorch_pretrained_bert",
version="0.1.0",
author="Thomas Wolf, Victor Sanh, Tim Rault, Google AI Language Team Authors",
author_email="thomas@huggingface.co",
description="PyTorch version of Google AI BERT model with script to load Google pre-trained models",
long_description=open("README.md", "r").read(),
long_description_content_type="text/markdown",
keywords='BERT NLP deep learning google',
license='Apache',
url="https://github.com/huggingface/pytorch-pretrained-BERT",
packages=find_packages(exclude=["*.tests", "*.tests.*",
"tests.*", "tests"]),
install_requires=['numpy',
'torch>=0.4.1',
'boto3',
'requests>=2.18',
'tqdm>=4.19'],
scripts=["bin/pytorch_pretrained_bert"],
python_requires='>=3.5.0',
tests_require=['pytest'],
classifiers=[
'Intended Audience :: Science/Research',
'Development Status :: 3 - Alpha',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
],
)
......@@ -34,7 +34,7 @@ class TokenizationTest(unittest.TestCase):
vocab_file = vocab_writer.name
tokenizer = tokenization.BertTokenizer(vocab_file)
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
......