chenpangpang / transformers

Commit bfd6f6b2
Authored Apr 17, 2019 by Ailing Zhang

fix from_pretrained positional args

Parent: ae4c9fee
Showing 2 changed files with 115 additions and 109 deletions:

    hubconf.py                            +107  -107
    pytorch_pretrained_bert/modeling.py     +8    -2
hubconf.py @ bfd6f6b2
 from pytorch_pretrained_bert.tokenization import BertTokenizer
-from pytorch_pretrained_bert.modeling import (
-        BertForNextSentencePrediction,
-        BertForMaskedLM,
-        BertForMultipleChoice,
-        BertForPreTraining,
-        BertForQuestionAnswering,
-        BertForSequenceClassification,
-        )
+from pytorch_pretrained_bert.modeling import (
+        BertModel,
+        BertForNextSentencePrediction,
+        BertForMaskedLM,
+        BertForMultipleChoice,
+        BertForPreTraining,
+        BertForQuestionAnswering,
+        BertForSequenceClassification,
+        BertForTokenClassification,
+        )

 dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']
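For context: torch.hub reads this module-level `dependencies` list from hubconf.py and verifies that each named package can be imported before running any entry point. A simplified sketch of that check, not the actual torch.hub source (the helper name is hypothetical):

    import importlib.util

    def _check_dependencies(dependencies):
        # Roughly what torch.hub does with the `dependencies` list
        # declared in hubconf.py: fail fast if a package is missing.
        missing = [pkg for pkg in dependencies
                   if importlib.util.find_spec(pkg) is None]
        if missing:
            raise RuntimeError('Missing dependencies: {}'.format(', '.join(missing)))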
+# A lot of models share the same param doc. Use a decorator
+# to save typing
+bert_docstring = """
+    Params:
+        pretrained_model_name_or_path: either:
+            - a str with the name of a pre-trained model to load
+                . `bert-base-uncased`
+                . `bert-large-uncased`
+                . `bert-base-cased`
+                . `bert-large-cased`
+                . `bert-base-multilingual-uncased`
+                . `bert-base-multilingual-cased`
+                . `bert-base-chinese`
+            - a path or url to a pretrained model archive containing:
+                . `bert_config.json` a configuration file for the model
+                . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining
+                  instance
+            - a path or url to a pretrained model archive containing:
+                . `bert_config.json` a configuration file for the model
+                . `model.chkpt` a TensorFlow checkpoint
+        from_tf: should we load the weights from a locally saved TensorFlow
+                 checkpoint
+        cache_dir: an optional path to a folder in which the pre-trained models
+                   will be cached.
+        state_dict: an optional state dictionary
+                    (collections.OrderedDict object) to use instead of Google
+                    pre-trained models
+        *inputs, **kwargs: additional input for the specific Bert class
+            (ex: num_labels for BertForSequenceClassification)
+"""
+
+
+def _append_from_pretrained_docstring(docstr):
+    def docstring_decorator(fn):
+        fn.__doc__ = fn.__doc__ + docstr
+        return fn
+    return docstring_decorator
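The decorator simply concatenates the shared parameter documentation onto each entry point's own summary, so `help()` on any decorated entry point shows both. A quick illustration with a hypothetical `toy` entry point (not part of this diff):

    @_append_from_pretrained_docstring(bert_docstring)
    def toy(*args, **kwargs):
        """A one-line summary specific to this entry point."""

    # toy.__doc__ is now the one-line summary followed by the shared
    # Params section from bert_docstring.
    print(toy.__doc__)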
 def bertTokenizer(*args, **kwargs):
     """
...
@@ -43,7 +84,7 @@ def bertTokenizer(*args, **kwargs):
     Example:
         >>> sentence = 'Hello, World!'
-        >>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'BertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
+        >>> tokenizer = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
         >>> toks = tokenizer.tokenize(sentence)
         ['Hello', '##,', 'World', '##!']
         >>> ids = tokenizer.convert_tokens_to_ids(toks)
...
@@ -53,135 +94,94 @@ def bertTokenizer(*args, **kwargs):
     return tokenizer
+@_append_from_pretrained_docstring(bert_docstring)
+def bertModel(*args, **kwargs):
+    """
+    BertModel is the basic BERT Transformer model with a layer of summed token,
+    position and sequence embeddings followed by a series of identical
+    self-attention blocks (12 for BERT-base, 24 for BERT-large).
+    """
+    model = BertModel.from_pretrained(*args, **kwargs)
+    return model
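Defined this way, every positional argument after the entry-point name is forwarded into `*args` and ends up in `from_pretrained`. A usage sketch mirroring the tokenizer example above (download progress and output omitted):

    import torch

    # 'bert-base-cased' is forwarded as pretrained_model_name_or_path
    # into BertModel.from_pretrained; force_reload is consumed by torch.hub.
    model = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf',
                           'bertModel', 'bert-base-cased', force_reload=False)
    model.eval()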
+@_append_from_pretrained_docstring(bert_docstring)
 def bertForNextSentencePrediction(*args, **kwargs):
-    """BERT model with next sentence prediction head.
-    This module comprises the BERT model followed by the next sentence classification head.
-    Params:
-        pretrained_model_name_or_path: either:
-            - a str with the name of a pre-trained model to load selected in the list of:
-                . `bert-base-uncased`
-                . `bert-large-uncased`
-                . `bert-base-cased`
-                . `bert-large-cased`
-                . `bert-base-multilingual-uncased`
-                . `bert-base-multilingual-cased`
-                . `bert-base-chinese`
-            - a path or url to a pretrained model archive containing:
-                . `bert_config.json` a configuration file for the model
-                . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
-            - a path or url to a pretrained model archive containing:
-                . `bert_config.json` a configuration file for the model
-                . `model.chkpt` a TensorFlow checkpoint
-        from_tf: should we load the weights from a locally saved TensorFlow checkpoint
-        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
-        state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
-        *inputs, **kwargs: additional input for the specific Bert class
-            (ex: num_labels for BertForSequenceClassification)
-    """
+    """
+    BERT model with next sentence prediction head.
+    This module comprises the BERT model followed by the next sentence
+    classification head.
+    """
     model = BertForNextSentencePrediction.from_pretrained(*args, **kwargs)
     return model
+@_append_from_pretrained_docstring(bert_docstring)
 def bertForPreTraining(*args, **kwargs):
-    """BERT model with pre-training heads.
-    This module comprises the BERT model followed by the two pre-training heads:
-        - the masked language modeling head, and
-        - the next sentence classification head.
-    Params:
-        pretrained_model_name_or_path: either:
-            - a str with the name of a pre-trained model to load selected in the list of:
-                . `bert-base-uncased`
-                . `bert-large-uncased`
-                . `bert-base-cased`
-                . `bert-large-cased`
-                . `bert-base-multilingual-uncased`
-                . `bert-base-multilingual-cased`
-                . `bert-base-chinese`
-            - a path or url to a pretrained model archive containing:
-                . `bert_config.json` a configuration file for the model
-                . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
-            - a path or url to a pretrained model archive containing:
-                . `bert_config.json` a configuration file for the model
-                . `model.chkpt` a TensorFlow checkpoint
-        from_tf: should we load the weights from a locally saved TensorFlow checkpoint
-        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
-        state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
-        *inputs, **kwargs: additional input for the specific Bert class
-            (ex: num_labels for BertForSequenceClassification)
-    """
+    """
+    BERT model with pre-training heads.
+    This module comprises the BERT model followed by the two pre-training heads
+        - the masked language modeling head, and
+        - the next sentence classification head.
+    """
     model = BertForPreTraining.from_pretrained(*args, **kwargs)
     return model
+@_append_from_pretrained_docstring(bert_docstring)
 def bertForMaskedLM(*args, **kwargs):
-    """
-    BertForMaskedLM includes the BertModel Transformer followed by the (possibly)
-    pre-trained masked language modeling head.
-    Params:
-        pretrained_model_name_or_path: either:
-            - a str with the name of a pre-trained model to load selected in the list of:
-                . `bert-base-uncased`
-                . `bert-large-uncased`
-                . `bert-base-cased`
-                . `bert-large-cased`
-                . `bert-base-multilingual-uncased`
-                . `bert-base-multilingual-cased`
-                . `bert-base-chinese`
-            - a path or url to a pretrained model archive containing:
-                . `bert_config.json` a configuration file for the model
-                . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
-            - a path or url to a pretrained model archive containing:
-                . `bert_config.json` a configuration file for the model
-                . `model.chkpt` a TensorFlow checkpoint
-        from_tf: should we load the weights from a locally saved TensorFlow checkpoint
-        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
-        state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
-        *inputs, **kwargs: additional input for the specific Bert class
-            (ex: num_labels for BertForSequenceClassification)
-    """
+    """
+    BertForMaskedLM includes the BertModel Transformer followed by the
+    (possibly) pre-trained masked language modeling head.
+    """
     model = BertForMaskedLM.from_pretrained(*args, **kwargs)
     return model
-#def bertForSequenceClassification(*args, **kwargs):
-#    model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
-#    return model
+@_append_from_pretrained_docstring(bert_docstring)
+def bertForSequenceClassification(*args, **kwargs):
+    """
+    BertForSequenceClassification is a fine-tuning model that includes
+    BertModel and a sequence-level (sequence or pair of sequences) classifier
+    on top of the BertModel.
+    The sequence-level classifier is a linear layer that takes as input the
+    last hidden state of the first character in the input sequence
+    (see Figures 3a and 3b in the BERT paper).
+    """
+    model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
+    return model
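The shared docstring notes that class-specific arguments such as `num_labels` travel through `*inputs, **kwargs`. After this commit they pass through the hub wrapper and `from_pretrained` untouched, so a call like the following (a sketch, reusing the repo tag from the docstring examples) reaches the BertForSequenceClassification constructor:

    import torch

    # num_labels is not consumed by the wrapper or the loading logic;
    # it is forwarded through **kwargs to BertForSequenceClassification.
    model = torch.hub.load('ailzhang/pytorch-pretrained-BERT:hubconf',
                           'bertForSequenceClassification', 'bert-base-cased',
                           num_labels=2, force_reload=False)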
-#def bertForMultipleChoice(*args, **kwargs):
-#    model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
-#    return model
+@_append_from_pretrained_docstring(bert_docstring)
+def bertForMultipleChoice(*args, **kwargs):
+    """
+    BertForMultipleChoice is a fine-tuning model that includes BertModel and a
+    linear layer on top of the BertModel.
+    """
+    model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
+    return model
+@_append_from_pretrained_docstring(bert_docstring)
 def bertForQuestionAnswering(*args, **kwargs):
-    """
-    BertForQuestionAnswering is a fine-tuning model that includes BertModel with
-    token-level classifiers on top of the full sequence of last hidden states.
-    Params:
-        pretrained_model_name_or_path: either:
-            - a str with the name of a pre-trained model to load selected in the list of:
-                . `bert-base-uncased`
-                . `bert-large-uncased`
-                . `bert-base-cased`
-                . `bert-large-cased`
-                . `bert-base-multilingual-uncased`
-                . `bert-base-multilingual-cased`
-                . `bert-base-chinese`
-            - a path or url to a pretrained model archive containing:
-                . `bert_config.json` a configuration file for the model
-                . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
-            - a path or url to a pretrained model archive containing:
-                . `bert_config.json` a configuration file for the model
-                . `model.chkpt` a TensorFlow checkpoint
-        from_tf: should we load the weights from a locally saved TensorFlow checkpoint
-        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
-        state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of Google pre-trained models
-        *inputs, **kwargs: additional input for the specific Bert class
-            (ex: num_labels for BertForSequenceClassification)
-    """
+    """
+    BertForQuestionAnswering is a fine-tuning model that includes BertModel
+    with token-level classifiers on top of the full sequence of last hidden
+    states.
+    """
     model = BertForQuestionAnswering.from_pretrained(*args, **kwargs)
     return model
+@_append_from_pretrained_docstring(bert_docstring)
+def bertForTokenClassification(*args, **kwargs):
+    """
+    BertForTokenClassification is a fine-tuning model that includes BertModel
+    and a token-level classifier on top of the BertModel.
+    The token-level classifier is a linear layer that takes as input the last
+    hidden state of the sequence.
+    """
+    model = BertForTokenClassification.from_pretrained(*args, **kwargs)
+    return model
pytorch_pretrained_bert/modeling.py @ bfd6f6b2
@@ -519,8 +519,7 @@ class BertPreTrainedModel(nn.Module):
                 module.bias.data.zero_()

     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
-                        from_tf=False, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         """
         Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
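This signature change is the fix named in the commit message: with the old parameter order, extra positional arguments supplied through torch.hub landed in `state_dict`, `cache_dir`, and `from_tf` rather than in `*inputs`. A minimal reproduction of the hazard with toy classes (not the library code):

    class Old:
        @classmethod
        def from_pretrained(cls, name, state_dict=None, cache_dir=None,
                            from_tf=False, *inputs, **kwargs):
            return state_dict, inputs

    class New:
        @classmethod
        def from_pretrained(cls, name, *inputs, **kwargs):
            return kwargs.pop('state_dict', None), inputs

    # A positional argument meant for the model class is silently
    # swallowed by state_dict under the old signature:
    print(Old.from_pretrained('bert-base-cased', 2))  # (2, ())
    print(New.from_pretrained('bert-base-cased', 2))  # (None, (2,))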
@@ -547,6 +546,13 @@ class BertPreTrainedModel(nn.Module):
             *inputs, **kwargs: additional input for the specific Bert class
                 (ex: num_labels for BertForSequenceClassification)
         """
+        state_dict = kwargs.get('state_dict', None)
+        kwargs.pop('state_dict', None)
+        cache_dir = kwargs.get('cache_dir', None)
+        kwargs.pop('cache_dir', None)
+        from_tf = kwargs.get('from_tf', False)
+        kwargs.pop('from_tf', None)
         if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
             archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
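A side note on the pattern above: `dict.pop(key, default)` already returns the removed value, so each `get`/`pop` pair could be collapsed into one call with identical behavior. A small sketch of the one-step form (the example dict is hypothetical):

    kwargs = {'cache_dir': '/tmp/bert', 'num_labels': 2}

    # Equivalent one-step form of the get-then-pop pairs in the diff.
    state_dict = kwargs.pop('state_dict', None)  # -> None (key absent)
    cache_dir = kwargs.pop('cache_dir', None)    # -> '/tmp/bert'
    from_tf = kwargs.pop('from_tf', False)       # -> False (key absent)

    print(kwargs)  # {'num_labels': 2} is all that reaches the model class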