"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8edfaaa81b9995cedea2f8805e4c18c2b6cb5bfc"
Commit bfd6f6b2 authored by Ailing Zhang

fix from_pretrained positional args

parent ae4c9fee
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import (
        BertModel,
        BertForNextSentencePrediction,
        BertForMaskedLM,
        BertForMultipleChoice,
        BertForPreTraining,
        BertForQuestionAnswering,
        BertForSequenceClassification,
        BertForTokenClassification,
        )
dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']
# A lot of models share the same param doc. Use a decorator
# to save typing
bert_docstring = """
    Params:
        pretrained_model_name_or_path: either:
            - a str with the name of a pre-trained model to load
                . `bert-base-uncased`
                . `bert-large-uncased`
                . `bert-base-cased`
                . `bert-large-cased`
                . `bert-base-multilingual-uncased`
                . `bert-base-multilingual-cased`
                . `bert-base-chinese`
            - a path or url to a pretrained model archive containing:
                . `bert_config.json` a configuration file for the model
                . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining
                    instance
            - a path or url to a pretrained model archive containing:
                . `bert_config.json` a configuration file for the model
                . `model.chkpt` a TensorFlow checkpoint
        from_tf: should we load the weights from a locally saved TensorFlow
            checkpoint
        cache_dir: an optional path to a folder in which the pre-trained models
            will be cached.
        state_dict: an optional state dictionary
            (collections.OrderedDict object) to use instead of Google
            pre-trained models
        *inputs, **kwargs: additional input for the specific Bert class
            (ex: num_labels for BertForSequenceClassification)
"""


def _append_from_pretrained_docstring(docstr):
    def docstring_decorator(fn):
        fn.__doc__ = fn.__doc__ + docstr
        return fn
    return docstring_decorator
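
# Illustrative sketch (not part of the original file): the decorator simply
# concatenates the shared Params block onto an entrypoint's own docstring.
# `_docstring_demo` is a hypothetical function used only for this demonstration.
@_append_from_pretrained_docstring(bert_docstring)
def _docstring_demo(*args, **kwargs):
    """Hypothetical entrypoint used only to demonstrate the decorator."""
    pass

assert _docstring_demo.__doc__.endswith(bert_docstring)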
def bertTokenizer(*args, **kwargs):
    """
    ...
    Example:
        >>> sentence = 'Hello, World!'
        >>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
        >>> toks = tokenizer.tokenize(sentence)
        ['Hello', '##,', 'World', '##!']
        >>> ids = tokenizer.convert_tokens_to_ids(toks)
        ...
    """
    ...
    return tokenizer
@_append_from_pretrained_docstring(bert_docstring)
def bertModel(*args, **kwargs):
    """
    BertModel is the basic BERT Transformer model with a layer of summed token,
    position and sequence embeddings followed by a series of identical
    self-attention blocks (12 for BERT-base, 24 for BERT-large).
    """
    model = BertModel.from_pretrained(*args, **kwargs)
    return model
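
# Usage sketch (not part of this commit): loading the bertModel entrypoint
# through torch.hub, mirroring the repo/branch string from the bertTokenizer
# docstring above. `_bert_model_demo` is a hypothetical helper and the token
# ids are arbitrary toy values used only for illustration.
def _bert_model_demo():
    import torch
    model = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf',
                           'bertModel', 'bert-base-cased')
    model.eval()
    input_ids = torch.tensor([[101, 7592, 1010, 2088, 102]])  # toy token ids
    with torch.no_grad():
        # pytorch_pretrained_bert's BertModel returns (encoded_layers, pooled_output)
        encoded_layers, pooled_output = model(input_ids)
    return pooled_output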
@_append_from_pretrained_docstring(bert_docstring)
def bertForNextSentencePrediction(*args, **kwargs):
    """
    BERT model with next sentence prediction head.
    This module comprises the BERT model followed by the next sentence
    classification head.
    """
    model = BertForNextSentencePrediction.from_pretrained(*args, **kwargs)
    return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForPreTraining(*args, **kwargs):
    """
    BERT model with pre-training heads.
    This module comprises the BERT model followed by the two pre-training heads
        - the masked language modeling head, and
        - the next sentence classification head.
    """
    model = BertForPreTraining.from_pretrained(*args, **kwargs)
    return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForMaskedLM(*args, **kwargs):
    """
    BertForMaskedLM includes the BertModel Transformer followed by the
    (possibly) pre-trained masked language modeling head.
    """
    model = BertForMaskedLM.from_pretrained(*args, **kwargs)
    return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForSequenceClassification(*args, **kwargs):
    """
    BertForSequenceClassification is a fine-tuning model that includes
    BertModel and a sequence-level (sequence or pair of sequences) classifier
    on top of the BertModel.
    The sequence-level classifier is a linear layer that takes as input the
    last hidden state of the first character in the input sequence
    (see Figures 3a and 3b in the BERT paper).
    """
    model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
    return model
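
# Sketch (not part of this commit): per the shared Params docstring, extra
# keyword arguments such as num_labels are forwarded through from_pretrained
# to the underlying class, here BertForSequenceClassification.
# `_sequence_classifier_demo` is a hypothetical helper for illustration only.
def _sequence_classifier_demo():
    import torch
    classifier = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf',
                                'bertForSequenceClassification',
                                'bert-base-cased', num_labels=2)
    classifier.eval()
    input_ids = torch.tensor([[101, 7592, 102]])  # arbitrary toy token ids
    with torch.no_grad():
        logits = classifier(input_ids)  # shape (1, num_labels) when no labels are given
    return logits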
@_append_from_pretrained_docstring(bert_docstring)
def bertForMultipleChoice(*args, **kwargs):
    """
    BertForMultipleChoice is a fine-tuning model that includes BertModel and a
    linear layer on top of the BertModel.
    """
    model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
    return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForQuestionAnswering(*args, **kwargs):
    """
    BertForQuestionAnswering is a fine-tuning model that includes BertModel
    with a token-level classifier on top of the full sequence of last hidden
    states.
    """
    model = BertForQuestionAnswering.from_pretrained(*args, **kwargs)
    return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForTokenClassification(*args, **kwargs):
    """
    BertForTokenClassification is a fine-tuning model that includes BertModel
    and a token-level classifier on top of the BertModel.
    The token-level classifier is a linear layer that takes as input the last
    hidden state of the sequence.
    """
    model = BertForTokenClassification.from_pretrained(*args, **kwargs)
    return model
class BertPreTrainedModel(nn.Module):
    ...
            module.bias.data.zero_()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
        """
        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.
        ...
        *inputs, **kwargs: additional input for the specific Bert class
            (ex: num_labels for BertForSequenceClassification)
        """
        # The former positional parameters with defaults are now pulled out of
        # **kwargs so they keep working as keyword arguments.
        state_dict = kwargs.get('state_dict', None)
        kwargs.pop('state_dict', None)
        cache_dir = kwargs.get('cache_dir', None)
        kwargs.pop('cache_dir', None)
        from_tf = kwargs.get('from_tf', False)
        kwargs.pop('from_tf', None)
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            ...
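
# Calling-pattern sketch (not part of the diff): torch.hub entrypoints forward
# everything as *args/**kwargs, so the options that used to be positional
# parameters with defaults (state_dict, cache_dir, from_tf) must now arrive
# via **kwargs; any remaining keywords still reach the model constructor.
# The cache path below is a placeholder.
from pytorch_pretrained_bert.modeling import (
    BertForQuestionAnswering,
    BertForSequenceClassification,
)

qa_model = BertForQuestionAnswering.from_pretrained(
    'bert-base-uncased',
    cache_dir='/tmp/bert_cache',      # popped from **kwargs inside from_pretrained
)

classifier = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,                     # stays in **kwargs and is consumed by the class
    cache_dir='/tmp/bert_cache',      # placeholder path
)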