Commit f56b8033 authored by thomwolf

more versatile loading

parent 4d47f498
@@ -255,7 +255,7 @@ class PreTrainedModel(nn.Module):
state_dict = torch.load(resolved_archive_file, map_location='cpu')
if from_tf:
    # Directly load from a TensorFlow checkpoint
-     return load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
+     return cls.load_tf_weights(model, config, resolved_archive_file[:-6])  # Remove the '.index'
# Load from a PyTorch state_dict
missing_keys = []
@@ -275,10 +275,15 @@ class PreTrainedModel(nn.Module):
if child is not None:
    load(child, prefix + name + '.')
+ # Be able to load base models as well as derived models (with heads)
start_prefix = ''
+ model_to_load = model
if not hasattr(model, cls.base_model_prefix) and any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
-     start_prefix = cls.base_model_prefix + '.' # Used to be able to load base models as well as derived modesl (with heads)
- load(model, prefix=start_prefix)
+     start_prefix = cls.base_model_prefix + '.'
+ if hasattr(model, cls.base_model_prefix) and not any(s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
+     model_to_load = getattr(model, cls.base_model_prefix)
+ load(model_to_load, prefix=start_prefix)
if len(missing_keys) > 0:
    logger.info("Weights of {} not initialized from pretrained model: {}".format(
        model.__class__.__name__, missing_keys))
@@ -289,7 +294,7 @@ class PreTrainedModel(nn.Module):
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
    model.__class__.__name__, "\n\t".join(error_msgs)))
- if hasattr(model, tie_weights):
+ if hasattr(model, 'tie_weights'):
    model.tie_weights()  # make sure word embedding weights are still tied
return model
...
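The net effect of the new branches above is that `from_pretrained` can now load in both directions: a checkpoint saved from a model with a head into a bare base model (by prepending `cls.base_model_prefix + '.'` to the keys it looks up), and a checkpoint saved from a bare base model into a model with a head (by loading into `getattr(model, cls.base_model_prefix)` and leaving the head randomly initialized). Below is a minimal standalone sketch of that prefix logic; `ToyBase`, `ToyWithHead` and the `'transformer'` prefix are hypothetical stand-ins, and the sketch rewrites keys before a plain `load_state_dict` rather than threading `start_prefix` through a recursive loader as the real method does.

```python
# Hypothetical sketch of the prefix handling -- not the library's implementation.
import torch
import torch.nn as nn

base_model_prefix = 'transformer'   # assumed prefix, mirrors cls.base_model_prefix

class ToyBase(nn.Module):
    """Stand-in for a base model (the bare transformer)."""
    def __init__(self):
        super(ToyBase, self).__init__()
        self.embed = nn.Embedding(10, 4)

class ToyWithHead(nn.Module):
    """Stand-in for a derived model: base model plus a task head."""
    def __init__(self):
        super(ToyWithHead, self).__init__()
        self.transformer = ToyBase()   # attribute name matches base_model_prefix
        self.head = nn.Linear(4, 2)

def versatile_load(model, state_dict):
    start_prefix = ''
    model_to_load = model
    has_base_attr = hasattr(model, base_model_prefix)
    ckpt_is_prefixed = any(k.startswith(base_model_prefix) for k in state_dict)
    # Head checkpoint -> bare base model: look the keys up under the prefix.
    if not has_base_attr and ckpt_is_prefixed:
        start_prefix = base_model_prefix + '.'
    # Base checkpoint -> model with a head: load into the base sub-module only.
    if has_base_attr and not ckpt_is_prefixed:
        model_to_load = getattr(model, base_model_prefix)
    if start_prefix:
        # Simplification: strip the prefix here instead of passing it to a recursive loader.
        state_dict = {k[len(start_prefix):]: v
                      for k, v in state_dict.items() if k.startswith(start_prefix)}
    return model_to_load.load_state_dict(state_dict, strict=False)

# Base checkpoint into a derived model: the head stays randomly initialized.
print(versatile_load(ToyWithHead(), ToyBase().state_dict()))
# Derived-model checkpoint into a bare base model: head weights are simply dropped.
print(versatile_load(ToyBase(), ToyWithHead().state_dict()))
```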
@@ -430,8 +430,54 @@ class XLMModel(XLMPreTrainedModel):
    'asm_cutoffs', 'asm_div_value']
def __init__(self, params, output_attentions=False, keep_multihead_output=False):  #, dico, is_encoder, with_output):
-     """
-     Transformer model (encoder or decoder).
-     """
+     """XLM model ("Bidirectional Embedding Representations from a Transformer").
+     Params:
+         `config`: a BertConfig class instance with the configuration to build a new model
+         `output_attentions`: If True, also output attentions weights computed by the model at each layer. Default: False
+         `keep_multihead_output`: If True, saves output of the multi-head attention module with its gradient.
+             This can be used to compute head importance metrics. Default: False
+     Inputs:
+         `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
+             with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts
+             `run_bert_extract_features.py`, `run_bert_classifier.py` and `run_bert_squad.py`)
+         `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
+             types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
+             a `sentence B` token (see BERT paper for more details).
+         `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
+             selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
+             input sequence length in the current batch. It's the mask that we typically use for attention when
+             a batch has varying length sentences.
+         `output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.
+         `head_mask`: an optional torch.Tensor of shape [num_heads] or [num_layers, num_heads] with indices between 0 and 1.
+             It's a mask to be used to nullify some heads of the transformer. 1.0 => head is fully masked, 0.0 => head is not masked.
+     Outputs: Tuple of (encoded_layers, pooled_output)
+         `encoded_layers`: controled by `output_all_encoded_layers` argument:
+             - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the end
+                 of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
+                 encoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
+             - `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states corresponding
+                 to the last attention block of shape [batch_size, sequence_length, hidden_size],
+         `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
+             classifier pretrained on top of the hidden state associated to the first character of the
+             input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
+     Example usage:
+     ```python
+     # Already been converted into WordPiece token ids
+     input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+     input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+     token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+     config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
+         num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
+     model = modeling.BertModel(config=config)
+     all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
+     ```
+     """
    super(XLMModel, self).__init__(params)
    self.output_attentions = output_attentions
...
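The new docstring above describes a `head_mask` input of shape `[num_heads]` or `[num_layers, num_heads]`, but its embedded example never builds one. A small illustrative sketch of both shapes, following the masking convention stated in the docstring (1.0 masks a head); the layer and head counts and the final call are assumptions for illustration, not values from this diff.

```python
import torch

num_layers, num_heads = 12, 16   # hypothetical sizes, not taken from this commit

# 1D mask of shape [num_heads]: the same heads are masked in every layer.
head_mask_1d = torch.zeros(num_heads)
head_mask_1d[[0, 5]] = 1.0       # per the docstring convention, 1.0 => head is fully masked

# 2D mask of shape [num_layers, num_heads]: per-layer control.
head_mask_2d = torch.zeros(num_layers, num_heads)
head_mask_2d[3, 7] = 1.0         # mask only head 7 of layer 3

# Either tensor would be passed as the `head_mask` argument documented above,
# e.g. model(input_ids, head_mask=head_mask_2d)   # illustrative call
```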