chenpangpang / transformers · Commits · f56b8033

Commit f56b8033, authored Jun 26, 2019 by thomwolf

    more versatile loading

Parent: 4d47f498

Showing 2 changed files with 57 additions and 6 deletions (+57 -6):

    pytorch_pretrained_bert/model_utils.py    +9  -4
    pytorch_pretrained_bert/modeling_xlm.py   +48 -2
pytorch_pretrained_bert/model_utils.py  (view file @ f56b8033)
@@ -255,7 +255,7 @@ class PreTrainedModel(nn.Module):
         state_dict = torch.load(resolved_archive_file, map_location='cpu')
         if from_tf:
             # Directly load from a TensorFlow checkpoint
-            return load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
+            return cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
 
         # Load from a PyTorch state_dict
         missing_keys = []
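The one-line change above makes `from_pretrained` look up the TensorFlow-conversion helper on the class (`cls.load_tf_weights`) instead of a free `load_tf_weights` name. A minimal, hypothetical sketch of the pattern this enables (the class and function names below are illustrative, not the library's):

```python
import torch.nn as nn

def load_tf_weights_in_my_model(model, config, tf_checkpoint_path):
    # Hypothetical converter: would map TensorFlow checkpoint variables onto the
    # PyTorch modules of `model` (conversion body omitted in this sketch).
    return model

class MyPreTrainedModel(nn.Module):
    # Illustrative stand-in for a PreTrainedModel subclass: the converter is attached
    # as a class attribute, so shared loading code can call
    # `cls.load_tf_weights(model, config, path)` as in the hunk above, without
    # importing any model-specific function.
    load_tf_weights = load_tf_weights_in_my_model
    base_model_prefix = 'transformer'
```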
@@ -275,10 +275,15 @@ class PreTrainedModel(nn.Module):
...
@@ -275,10 +275,15 @@ class PreTrainedModel(nn.Module):
if
child
is
not
None
:
if
child
is
not
None
:
load
(
child
,
prefix
+
name
+
'.'
)
load
(
child
,
prefix
+
name
+
'.'
)
# Be able to load base models as well as derived models (with heads)
start_prefix
=
''
start_prefix
=
''
model_to_load
=
model
if
not
hasattr
(
model
,
cls
.
base_model_prefix
)
and
any
(
s
.
startswith
(
cls
.
base_model_prefix
)
for
s
in
state_dict
.
keys
()):
if
not
hasattr
(
model
,
cls
.
base_model_prefix
)
and
any
(
s
.
startswith
(
cls
.
base_model_prefix
)
for
s
in
state_dict
.
keys
()):
start_prefix
=
cls
.
base_model_prefix
+
'.'
# Used to be able to load base models as well as derived modesl (with heads)
start_prefix
=
cls
.
base_model_prefix
+
'.'
load
(
model
,
prefix
=
start_prefix
)
if
hasattr
(
model
,
cls
.
base_model_prefix
)
and
not
any
(
s
.
startswith
(
cls
.
base_model_prefix
)
for
s
in
state_dict
.
keys
()):
model_to_load
=
getattr
(
model
,
cls
.
base_model_prefix
)
load
(
model_to_load
,
prefix
=
start_prefix
)
if
len
(
missing_keys
)
>
0
:
if
len
(
missing_keys
)
>
0
:
logger
.
info
(
"Weights of {} not initialized from pretrained model: {}"
.
format
(
logger
.
info
(
"Weights of {} not initialized from pretrained model: {}"
.
format
(
model
.
__class__
.
__name__
,
missing_keys
))
model
.
__class__
.
__name__
,
missing_keys
))
...
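To see what the new `model_to_load` / `start_prefix` logic buys, here is a small, self-contained sketch in plain PyTorch (not the library code) of the two mismatched cases it now covers, assuming a `base_model_prefix` of `'transformer'`:

```python
import torch.nn as nn

class Base(nn.Module):                      # stands in for a bare base model
    def __init__(self):
        super(Base, self).__init__()
        self.embed = nn.Embedding(10, 4)

class WithHead(nn.Module):                  # stands in for a derived model with a head
    base_model_prefix = 'transformer'
    def __init__(self):
        super(WithHead, self).__init__()
        self.transformer = Base()
        self.head = nn.Linear(4, 2)

prefix = WithHead.base_model_prefix + '.'

# Case 1: the checkpoint comes from a derived model but the target is a bare base model.
# Its keys carry the 'transformer.' prefix, so loading starts below that prefix
# (this is the role of `start_prefix` in the diff).
derived_sd = WithHead().state_dict()        # keys: 'transformer.embed.weight', 'head.weight', ...
base = Base()
base.load_state_dict({k[len(prefix):]: v for k, v in derived_sd.items()
                      if k.startswith(prefix)})

# Case 2: the checkpoint comes from a base model but the target is a derived model.
# Its keys carry no prefix, so the weights go into the base submodule
# (this is the role of `model_to_load = getattr(model, cls.base_model_prefix)`).
base_sd = Base().state_dict()               # keys: 'embed.weight'
derived = WithHead()
getattr(derived, WithHead.base_model_prefix).load_state_dict(base_sd)
```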
@@ -289,7 +294,7 @@ class PreTrainedModel(nn.Module):
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                                model.__class__.__name__, "\n\t".join(error_msgs)))
-        if hasattr(model, tie_weights):
+        if hasattr(model, 'tie_weights'):
             model.tie_weights()  # make sure word embedding weights are still tied
         return model
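The last hunk also fixes `hasattr(model, tie_weights)`, which referenced an undefined name, so that the method is looked up by its string name. A minimal sketch of the kind of `tie_weights` method this guard expects, assuming the usual tying of input embeddings and output projection (illustrative model, not from the library):

```python
import torch.nn as nn

class TinyLM(nn.Module):
    # Illustrative model: the input embedding and the output projection share one Parameter.
    def __init__(self, vocab_size=100, hidden_size=16):
        super(TinyLM, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)

    def tie_weights(self):
        # Re-point the output projection at the embedding matrix so the two stay tied
        # even after a state_dict has been loaded into the model.
        self.decoder.weight = self.embeddings.weight

model = TinyLM()
if hasattr(model, 'tie_weights'):   # the corrected string-based check from the hunk above
    model.tie_weights()             # make sure word embedding weights are still tied
assert model.decoder.weight is model.embeddings.weight
```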
pytorch_pretrained_bert/modeling_xlm.py  (view file @ f56b8033)
@@ -430,8 +430,54 @@ class XLMModel(XLMPreTrainedModel):
                   'asm_cutoffs', 'asm_div_value']
 
     def __init__(self, params, output_attentions=False, keep_multihead_output=False):  #, dico, is_encoder, with_output):
-        """
-        Transformer model (encoder or decoder).
+        """XLM model ("Bidirectional Embedding Representations from a Transformer").
+
+        Params:
+            `config`: a BertConfig class instance with the configuration to build a new model
+            `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
+            `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
+                This can be used to compute head importance metrics. Default: False
+
+        Inputs:
+            `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
+                with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
+                `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
+            `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
+                types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
+                a `sentence B` token (see BERT paper for more details).
+            `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
+                selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
+                input sequence length in the current batch. It's the mask that we typically use for attention when
+                a batch has varying length sentences.
+            `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
+            `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
+                It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
+
+        Outputs: Tuple of (encoded_layers, pooled_output)
+            `encoded_layers`: controled by `output_all_encoded_layers` argument:
+                - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
+                    of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
+                    encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
+                - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
+                    to the last attention block of shape [batch_size, sequence_length, hidden_size],
+            `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
+                classifier pretrained on top of the hidden state associated to the first character of the
+                input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
+
+        Example usage:
+        ```python
+        # Already been converted into WordPiece token ids
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+            num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+
+        model = modeling.BertModel(config=config)
+        all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
+        ```
         """
         super(XLMModel, self).__init__(params)
         self.output_attentions = output_attentions