"tests/L1/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "c142714ba4a3485036aab2e0ef9d87aa67827d46"
Commit 2397f958 authored by thomwolf

updating examples and doc

parent c490f5ce
@@ -131,11 +131,8 @@ This package comprises the following classes that can be imported in Python and
 - Tokenizer for **OpenAI GPT-2** (using byte-level Byte-Pair-Encoding) (in the [`tokenization_gpt2.py`](./pytorch_transformers/tokenization_gpt2.py) file):
   - `GPT2Tokenizer` - perform byte-level Byte-Pair-Encoding (BPE) tokenization.
-- Optimizer for **BERT** (in the [`optimization.py`](./pytorch_transformers/optimization.py) file):
-  - `BertAdam` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
-- Optimizer for **OpenAI GPT** (in the [`optimization_openai.py`](./pytorch_transformers/optimization_openai.py) file):
-  - `OpenAIAdam` - OpenAI GPT version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
+- Optimizer (in the [`optimization.py`](./pytorch_transformers/optimization.py) file):
+  - `AdamW` - Version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
 - Configuration classes for BERT, OpenAI GPT and Transformer-XL (in the respective [`modeling.py`](./pytorch_transformers/modeling.py), [`modeling_openai.py`](./pytorch_transformers/modeling_openai.py), [`modeling_transfo_xl.py`](./pytorch_transformers/modeling_transfo_xl.py) files):
   - `BertConfig` - Configuration class to store the configuration of a `BertModel` with utilities to read and write from JSON configuration files.
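For illustration (not part of this commit): a minimal sketch of the JSON read/write utilities mentioned for `BertConfig` above, using `from_json_file` and `to_json_string`; the local file names are made up.

```python
# Minimal sketch: round-tripping a BertConfig through JSON.
# "bert_config.json" is a hypothetical local file, not something shipped by this commit.
from pytorch_transformers.modeling_bert import BertConfig  # import path as used by the example scripts in this diff

config = BertConfig.from_json_file("bert_config.json")     # read a configuration from a JSON file
print(config.hidden_size, config.num_hidden_layers)

with open("bert_config_copy.json", "w", encoding="utf-8") as f:
    f.write(config.to_json_string())                        # serialize the configuration back to JSON
```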
@@ -1104,12 +1101,11 @@ Please refer to [`tokenization_gpt2.py`](./pytorch_transformers/tokenization_gpt
 ### Optimizers
-#### `BertAdam`
+#### `AdamW`
-`BertAdam` is a `torch.optimizer` adapted to be closer to the optimizer used in the TensorFlow implementation of Bert. The differences with PyTorch Adam optimizer are the following:
+`AdamW` is a `torch.optimizer` adapted to be closer to the optimizer used in the TensorFlow implementation of Bert. The differences with PyTorch Adam optimizer are the following:
-- BertAdam implements weight decay fix,
-- BertAdam doesn't compensate for bias as in the regular Adam optimizer.
+- AdamW implements weight decay fix,
 The optimizer accepts the following arguments:
@@ -1127,13 +1123,6 @@ The optimizer accepts the following arguments:
 - `weight_decay:` Weight decay. Default : `0.01`
 - `max_grad_norm` : Maximum norm for the gradients (`-1` means no clipping). Default : `1.0`
-#### `OpenAIAdam`
-`OpenAIAdam` is similar to `BertAdam`.
-The differences with `BertAdam` is that `OpenAIAdam` compensate for bias as in the regular Adam optimizer.
-`OpenAIAdam` accepts the same arguments as `BertAdam`.
 #### Learning Rate Schedules
 The `.optimization` module also provides additional schedules in the form of schedule objects that inherit from `_LRSchedule`.
......
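For illustration (not part of this commit): a minimal sketch of constructing the renamed optimizer with the arguments documented above, mirroring the `AdamW(...)` calls in the example scripts further down this diff. The tiny model and the step count are placeholders.

```python
# Sketch mirroring the updated example scripts: AdamW called with the arguments
# documented above (lr, warmup, t_total, weight_decay, max_grad_norm).
# The Linear model and the step count are placeholders, not from this commit.
import torch
from pytorch_transformers.optimization import AdamW

model = torch.nn.Linear(768, 2)        # stand-in for a real pretrained model
num_train_optimization_steps = 1000    # placeholder total number of training steps

no_decay = ['bias']                    # parameters excluded from weight decay
grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]

optimizer = AdamW(grouped_parameters,
                  lr=3e-5,
                  warmup=0.1,          # warmup proportion, as in the example scripts
                  t_total=num_train_optimization_steps,
                  max_grad_norm=1.0)
```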
@@ -60,10 +60,10 @@ This PyTorch implementation of Transformer-XL is an adaptation of the original `
 This PyTorch implementation of OpenAI GPT-2 is an adaptation of the `OpenAI's implementation <https://github.com/openai/gpt-2>`__ and is provided with `OpenAI's pre-trained model <https://github.com/openai/gpt-2>`__ and a command-line interface that was used to convert the TensorFlow checkpoint in PyTorch.
 **Facebook Research's XLM** was released together with the paper `Cross-lingual Language Model Pretraining <https://arxiv.org/abs/1901.07291>`__ by Guillaume Lample and Alexis Conneau.
-This PyTorch implementation of XLM is an adaptation of the original `PyTorch implementation <https://github.com/facebookresearch/XLM>`__. TODO Lysandre filled
+This PyTorch implementation of XLM is an adaptation of the original `PyTorch implementation <https://github.com/facebookresearch/XLM>`__.
 **Google's XLNet** was released together with the paper `XLNet: Generalized Autoregressive Pretraining for Language Understanding <https://arxiv.org/abs/1906.08237>`__ by Zhilin Yang\*, Zihang Dai\*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov and Quoc V. Le.
-This PyTorch implementation of XLM is an adaptation of the `Tensorflow implementation <https://github.com/zihangdai/xlnet>`__. TODO Lysandre filled
+This PyTorch implementation of XLM is an adaptation of the `Tensorflow implementation <https://github.com/zihangdai/xlnet>`__.
 Content
@@ -91,7 +91,7 @@ Content
 * - `Migration <./migration.html>`__
   - Migrating from ``pytorch_pretrained_BERT`` (v0.6) to ``pytorch_transformers`` (v1.0)
 * - `Bertology <./bertology.html>`__
-  - TODO Lysandre didn't know how to fill
+  - Exploring the internals of the pretrained models.
 * - `TorchScript <./torchscript.html>`__
   - Convert a model to TorchScript for use in other programming languages
@@ -115,8 +115,6 @@ Content
 * - `XLNet <./model_doc/xlnet.html>`__
   - XLNet Models, Tokenizers and optimizers
-TODO Lysandre filled: might need an introduction for both parts. Is it even necessary, since there is a summary? Up to you Thom.
 Overview
 --------
@@ -219,17 +217,10 @@ TODO Lysandre filled: I filled in XLM and XLNet. I didn't do the Tokenizers beca
 *
-  Optimizer for **BERT** (in the `optimization.py <./_modules/pytorch_transformers/optimization.html>`__ file):
-  * ``BertAdam`` - Bert version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
-*
-  Optimizer for **OpenAI GPT** (in the `optimization_openai.py <./_modules/pytorch_transformers/optimization_openai.html>`__ file):
-  * ``OpenAIAdam`` - OpenAI GPT version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
+  Optimizer (in the `optimization.py <./_modules/pytorch_transformers/optimization.html>`__ file):
+  * ``AdamW`` - Version of Adam algorithm with weight decay fix, warmup and linear decay of the learning rate.
 *
......
@@ -15,10 +15,10 @@ BERT
     :members:
-``BertAdam``
+``AdamW``
 ~~~~~~~~~~~~~~~~
-.. autoclass:: pytorch_transformers.BertAdam
+.. autoclass:: pytorch_transformers.AdamW
     :members:
 ``BertModel``
......
@@ -15,13 +15,6 @@ OpenAI GPT
     :members:
-``OpenAIAdam``
-~~~~~~~~~~~~~~~~~~
-.. autoclass:: pytorch_transformers.OpenAIAdam
-    :members:
 ``OpenAIGPTModel``
 ~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -236,7 +236,7 @@ Learning Rate Schedules
 The ``.optimization`` module also provides additional schedules in the form of schedule objects that inherit from ``_LRSchedule``.
 All ``_LRSchedule`` subclasses accept ``warmup`` and ``t_total`` arguments at construction.
-When an ``_LRSchedule`` object is passed into ``BertAdam`` or ``OpenAIAdam``\ ,
+When an ``_LRSchedule`` object is passed into ``AdamW``\ ,
 the ``warmup`` and ``t_total`` arguments on the optimizer are ignored and the ones in the ``_LRSchedule`` object are used.
 An overview of the implemented schedules:
......
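For illustration (not part of this commit): a sketch of the behaviour described above. It assumes `AdamW` keeps `BertAdam`'s `schedule` argument so that an `_LRSchedule` object can be passed directly; the model and numbers are placeholders.

```python
# Sketch only: passing an _LRSchedule object instead of warmup/t_total kwargs.
# Assumption: AdamW keeps BertAdam's `schedule` argument (not confirmed by this diff).
# With a schedule object, warmup/t_total given on the optimizer would be ignored.
import torch
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule

model = torch.nn.Linear(768, 2)                              # placeholder model
schedule = WarmupLinearSchedule(warmup=0.1, t_total=1000)    # same kwargs as in the example scripts

optimizer = AdamW(model.parameters(), lr=3e-5, schedule=schedule)
```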
@@ -16,7 +16,7 @@ from tqdm import tqdm
 from pytorch_transformers import WEIGHTS_NAME, CONFIG_NAME
 from pytorch_transformers.modeling_bert import BertForPreTraining
 from pytorch_transformers.tokenization_bert import BertTokenizer
-from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
+from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
 InputFeatures = namedtuple("InputFeatures", "input_ids input_mask segment_ids lm_label_ids is_next")
@@ -273,7 +273,7 @@ def main():
         warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
                                              t_total=num_train_optimization_steps)
     else:
-        optimizer = BertAdam(optimizer_grouped_parameters,
+        optimizer = AdamW(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              warmup=args.warmup_proportion,
                              t_total=num_train_optimization_steps)
......
@@ -14,7 +14,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Generation with GPT/GPT-2/Transformer-XL/XLNet models
+""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/Transformer-XL/XLNet)
 """
 from __future__ import absolute_import, division, print_function, unicode_literals
......
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Finetuning a classification model (Bert, XLM, XLNet,...) on GLUE."""
+""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet)."""
 from __future__ import absolute_import, division, print_function
@@ -230,6 +230,9 @@ def evaluate(args, model, tokenizer, prefix=""):
             logger.info(" %s = %s", key, str(result[key]))
             writer.write("%s = %s\n" % (key, str(result[key])))
+    if args.local_rank in [-1, 0]:
+        tb_writer.close()
     return results
@@ -242,7 +245,7 @@ def load_and_cache_examples(args, task, tokenizer, evaluate=False):
         list(filter(None, args.model_name.split('/'))).pop(),
         str(args.max_seq_length),
         str(task)))
-    if os.path.exists(cached_features_file) and not args.overwrite_cache:
+    if os.path.exists(cached_features_file):
         logger.info("Loading features from cached file %s", cached_features_file)
         features = torch.load(cached_features_file)
     else:
@@ -410,7 +413,7 @@ def main():
     if args.local_rank == 0:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab
-    # Distributed and parrallel training
+    # Distributed and parallel training
     model.to(args.device)
     if args.local_rank != -1:
         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
......
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Finetuning a question-answering model (Bert, XLM, XLNet,...) on SQuAD."""
+""" Finetuning the library models for question-answering on SQuAD (Bert, XLM, XLNet)."""
 from __future__ import absolute_import, division, print_function
@@ -21,7 +21,7 @@ import argparse
 import logging
 import os
 import random
-from io import open
+import glob
 import numpy as np
 import torch
@@ -43,6 +43,9 @@ from pytorch_transformers import AdamW, WarmupLinearSchedule
 from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
+# The follwing import is the official SQuAD evaluation script (2.0).
+# You can remove it from the dependencies if you are using this script outside of the library
+# We've added it here for automated tests (see examples/test_examples.py file)
 from utils_squad_evaluate import EVAL_OPTS, main as evaluate_on_squad
 logger = logging.getLogger(__name__)
@@ -123,7 +126,7 @@ def train(args, train_dataset, model, tokenizer):
             loss = ouputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
             if args.n_gpu > 1:
-                loss = loss.mean()  # mean() to average on multi-gpu parallel training
+                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
             if args.gradient_accumulation_steps > 1:
                 loss = loss / args.gradient_accumulation_steps
@@ -169,6 +172,9 @@ def train(args, train_dataset, model, tokenizer):
             train_iterator.close()
             break
+    if args.local_rank in [-1, 0]:
+        tb_writer.close()
     return global_step, tr_loss / global_step
@@ -208,16 +214,16 @@ def evaluate(args, model, tokenizer, prefix=""):
                                     start_logits=start_logits,
                                     end_logits=end_logits))
+    # Compute predictions
     output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
     output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
     output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
-    all_predictions = write_predictions(examples, features, all_results,
-                                        args.n_best_size, args.max_answer_length,
-                                        args.do_lower_case, output_prediction_file,
-                                        output_nbest_file, output_null_log_odds_file,
-                                        args.verbose_logging, args.version_2_with_negative,
-                                        args.null_score_diff_threshold)
+    write_predictions(examples, features, all_results, args.n_best_size, args.max_answer_length,
+                      args.do_lower_case, output_prediction_file, output_nbest_file,
+                      output_null_log_odds_file, args.verbose_logging,
+                      args.version_2_with_negative, args.null_score_diff_threshold)
+    # Evaluate with the official SQuAD script
     evaluate_options = EVAL_OPTS(data_file=args.predict_file,
                                  pred_file=output_prediction_file,
                                  na_prob_file=output_null_log_odds_file)
@@ -432,7 +438,7 @@ def main():
     logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
-    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
+    # Save the trained model and the tokenizer
     if args.local_rank == -1 or torch.distributed.get_rank() == 0:
         # Create output directory if needed
         if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
@@ -454,22 +460,30 @@ def main():
         model.to(args.device)
-    # Evaluation
+    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
             checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
-            logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
+            logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
+            # Reload the model
             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
+            # Evaluate
            result = evaluate(args, model, tokenizer, prefix=global_step)
            result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
            results.update(result)
     logger.info("Results: {}".format(results))
     return results
......
@@ -40,7 +40,7 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                               TensorDataset)
 from pytorch_transformers import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
-                                  OpenAIAdam, cached_path, WEIGHTS_NAME, CONFIG_NAME)
+                                  AdamW, cached_path, WEIGHTS_NAME, CONFIG_NAME)
 ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
@@ -191,7 +191,7 @@ def main():
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
     ]
     num_train_optimization_steps = len(train_dataloader) * args.num_train_epochs
-    optimizer = OpenAIAdam(optimizer_grouped_parameters,
+    optimizer = AdamW(optimizer_grouped_parameters,
                            lr=args.learning_rate,
                            warmup=args.warmup_proportion,
                            max_grad_norm=args.max_grad_norm,
......
@@ -34,7 +34,7 @@ from tqdm import tqdm, trange
 from pytorch_transformers.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
 from pytorch_transformers.modeling_bert import BertForMultipleChoice, BertConfig
-from pytorch_transformers.optimization import BertAdam, WarmupLinearSchedule
+from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule
 from pytorch_transformers.tokenization_bert import BertTokenizer
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
......
@@ -91,7 +91,6 @@ class ExamplesTests(unittest.TestCase):
         self.assertGreaterEqual(result['f1'], 30)
         self.assertGreaterEqual(result['exact'], 30)
     def test_generation(self):
         stream_handler = logging.StreamHandler(sys.stdout)
         logger.addHandler(stream_handler)
......
@@ -36,6 +36,13 @@ WEIGHTS_NAME = "pytorch_model.bin"
 TF_WEIGHTS_NAME = 'model.ckpt'
+def add_start_docstrings(*docstr):
+    def docstring_decorator(fn):
+        fn.__doc__ = ''.join(docstr) + fn.__doc__
+        return fn
+    return docstring_decorator
 class PretrainedConfig(object):
     """ An abstract class to handle dowloading a model pretrained config.
     """
......
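For illustration (not part of this commit): how the `add_start_docstrings` helper added above can be used to prepend a shared paragraph to a function's docstring. The decorated function and the strings are made up.

```python
# Usage sketch for the add_start_docstrings helper added in this diff.
# The decorated function and docstring text below are made-up examples.
def add_start_docstrings(*docstr):               # copied from the hunk above
    def docstring_decorator(fn):
        fn.__doc__ = ''.join(docstr) + fn.__doc__
        return fn
    return docstring_decorator

SHARED_INTRO = "Shared introduction prepended to every model docstring.\n\n"

@add_start_docstrings(SHARED_INTRO)
def forward(input_ids):
    """Model-specific details follow the shared introduction."""
    return input_ids

print(forward.__doc__)   # shared introduction followed by the model-specific docstring
```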