Unverified Commit 8f2723ca authored by Sylvain Gugger, committed by GitHub

Output attention takes an s (#6903)

* Fix output_attention -> output_attentions

* Formatting

* One unsaved file
parent 485da722
@@ -28,10 +28,10 @@ def config(*args, **kwargs):
     config = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased')  # Download configuration from S3 and cache.
     config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
     config = torch.hub.load('huggingface/transformers', 'config', './test/bert_saved_model/my_configuration.json')
-    config = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased', output_attention=True, foo=False)
-    assert config.output_attention == True
-    config, unused_kwargs = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased', output_attention=True, foo=False, return_unused_kwargs=True)
-    assert config.output_attention == True
+    config = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased', output_attentions=True, foo=False)
+    assert config.output_attentions == True
+    config, unused_kwargs = torch.hub.load('huggingface/transformers', 'config', 'bert-base-uncased', output_attentions=True, foo=False, return_unused_kwargs=True)
+    assert config.output_attentions == True
     assert unused_kwargs == {'foo': False}
     """
@@ -61,8 +61,8 @@ def model(*args, **kwargs):
     model = torch.hub.load('huggingface/transformers', 'model', 'bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = torch.hub.load('huggingface/transformers', 'model', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    model = torch.hub.load('huggingface/transformers', 'model', 'bert-base-uncased', output_attention=True)  # Update configuration during loading
-    assert model.config.output_attention == True
+    model = torch.hub.load('huggingface/transformers', 'model', 'bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = torch.hub.load('huggingface/transformers', 'model', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -80,8 +80,8 @@ def modelWithLMHead(*args, **kwargs):
     model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased', output_attention=True)  # Update configuration during loading
-    assert model.config.output_attention == True
+    model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', 'bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = torch.hub.load('huggingface/transformers', 'modelWithLMHead', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -98,8 +98,8 @@ def modelForSequenceClassification(*args, **kwargs):
     model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'bert-base-uncased', output_attention=True)  # Update configuration during loading
-    assert model.config.output_attention == True
+    model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', 'bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = torch.hub.load('huggingface/transformers', 'modelForSequenceClassification', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -117,8 +117,8 @@ def modelForQuestionAnswering(*args, **kwargs):
     model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', 'bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', './test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', 'bert-base-uncased', output_attention=True)  # Update configuration during loading
-    assert model.config.output_attention == True
+    model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', 'bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = torch.hub.load('huggingface/transformers', 'modelForQuestionAnswering', './tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
...
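Taken together, the hub entry points above let the renamed flag be exercised end to end. A hedged sketch follows, assuming network access, the companion `tokenizer` hub entry point, and the 3.x-era tuple outputs (shapes follow bert-base-uncased); it is illustrative, not part of the commit:

```python
import torch

# Hedged sketch: 'tokenizer' is the companion hub entry point in the same
# hubconf; the output tuple layout assumes the 3.x-era API.
tokenizer = torch.hub.load('huggingface/transformers', 'tokenizer', 'bert-base-uncased')
model = torch.hub.load('huggingface/transformers', 'model', 'bert-base-uncased', output_attentions=True)
assert model.config.output_attentions == True

input_ids = torch.tensor([tokenizer.encode("Hello world")])  # (1, seq_length)
outputs = model(input_ids)

# With the flag set, the per-layer attention probabilities come back as the
# last element of the output tuple.
attentions = outputs[-1]    # tuple of num_hidden_layers tensors
print(attentions[0].shape)  # (1, num_heads, seq_length, seq_length)
```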
@@ -261,11 +261,11 @@ class AutoConfig:
     config = AutoConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
     config = AutoConfig.from_pretrained('./test/bert_saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
     config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json')
-    config = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
-    assert config.output_attention == True
-    config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True,
+    config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
+    assert config.output_attentions == True
+    config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True,
         foo=False, return_unused_kwargs=True)
-    assert config.output_attention == True
+    assert config.output_attentions == True
     assert unused_kwargs == {'foo': False}
     """
...
@@ -300,11 +300,11 @@ class PretrainedConfig(object):
     config = BertConfig.from_pretrained('bert-base-uncased')  # Download configuration from S3 and cache.
     config = BertConfig.from_pretrained('./test/saved_model/')  # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')`
     config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json')
-    config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
-    assert config.output_attention == True
-    config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True,
+    config = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
+    assert config.output_attentions == True
+    config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attentions=True,
         foo=False, return_unused_kwargs=True)
-    assert config.output_attention == True
+    assert config.output_attentions == True
     assert unused_kwargs == {'foo': False}
     """
...
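The `return_unused_kwargs` behaviour these docstring examples demonstrate boils down to a simple split of the kwargs dict. A standalone sketch of that logic, using nothing beyond plain Python (`update_config` and `DummyConfig` are illustrative names, not library code):

```python
# Illustrative sketch of from_pretrained-style kwarg handling: keys matching
# existing config attributes override them, everything else is collected as
# "unused" and optionally returned to the caller.
def update_config(config, return_unused_kwargs=False, **kwargs):
    unused = {}
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)  # e.g. output_attentions=True
        else:
            unused[key] = value          # e.g. foo=False
    return (config, unused) if return_unused_kwargs else config

class DummyConfig:
    output_attentions = False

config, unused = update_config(DummyConfig(), output_attentions=True,
                               foo=False, return_unused_kwargs=True)
assert config.output_attentions is True
assert unused == {'foo': False}
```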
@@ -122,7 +122,7 @@ class ModelCard:
     modelcard = ModelCard.from_pretrained('bert-base-uncased')  # Download model card from S3 and cache.
     modelcard = ModelCard.from_pretrained('./test/saved_model/')  # E.g. model card was saved using `save_pretrained('./test/saved_model/')`
     modelcard = ModelCard.from_pretrained('./test/saved_model/modelcard.json')
-    modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attention=True, foo=False)
+    modelcard = ModelCard.from_pretrained('bert-base-uncased', output_attentions=True, foo=False)
     """
     cache_dir = kwargs.pop("cache_dir", None)
...
@@ -508,6 +508,7 @@ class AutoModel:
     Examples::
     model = AutoModel.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
+    model = AutoModel.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
     assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
@@ -657,7 +658,8 @@ class AutoModelForPreTraining:
     model = AutoModelForPreTraining.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = AutoModelForPreTraining.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    assert model.config.output_attention == True
+    model = AutoModelForPreTraining.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = AutoModelForPreTraining.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -812,7 +814,8 @@ class AutoModelWithLMHead:
     model = AutoModelWithLMHead.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = AutoModelWithLMHead.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    assert model.config.output_attention == True
+    model = AutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = AutoModelWithLMHead.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -953,7 +956,8 @@ class AutoModelForCausalLM:
     model = AutoModelForCausalLM.from_pretrained('gpt2')  # Download model and configuration from S3 and cache.
     model = AutoModelForCausalLM.from_pretrained('./test/gpt2_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    assert model.config.output_attention == True
+    model = AutoModelForCausalLM.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/gpt2_tf_model_config.json')
     model = AutoModelForCausalLM.from_pretrained('./tf_model/gpt2_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -1096,7 +1100,8 @@ class AutoModelForMaskedLM:
     model = AutoModelForMaskedLM.from_pretrained('bert')  # Download model and configuration from S3 and cache.
     model = AutoModelForMaskedLM.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    assert model.config.output_attention == True
+    model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = AutoModelForMaskedLM.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -1229,7 +1234,8 @@ class AutoModelForSeq2SeqLM:
     model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')  # Download model and configuration from S3 and cache.
     model = AutoModelForSeq2SeqLM.from_pretrained('./test/t5_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    assert model.config.output_attention == True
+    model = AutoModelForSeq2SeqLM.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/t5_tf_model_config.json')
     model = AutoModelForSeq2SeqLM.from_pretrained('./tf_model/t5_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -1381,7 +1387,8 @@ class AutoModelForSequenceClassification:
     model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = AutoModelForSequenceClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    assert model.config.output_attention == True
+    model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = AutoModelForSequenceClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -1525,7 +1532,8 @@ class AutoModelForQuestionAnswering:
     model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = AutoModelForQuestionAnswering.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    assert model.config.output_attention == True
+    model = AutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = AutoModelForQuestionAnswering.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -1677,7 +1685,8 @@ class AutoModelForTokenClassification:
     model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = AutoModelForTokenClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    assert model.config.output_attention == True
+    model = AutoModelForTokenClassification.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = AutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
...
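All of the Auto* docstrings above describe the same two loading paths for configuration kwargs. A short sketch of both, assuming the 3.x-era `AutoConfig`/`AutoModel` API (network access required):

```python
from transformers import AutoConfig, AutoModel

# Path 1: no explicit config -- kwargs are first routed into the config,
# and whatever is left over goes to the model's __init__.
model = AutoModel.from_pretrained('bert-base-uncased', output_attentions=True)
assert model.config.output_attentions == True

# Path 2: explicit config -- all configuration updates happen up front and
# from_pretrained takes the config as-is.
config = AutoConfig.from_pretrained('bert-base-uncased', output_attentions=True)
model = AutoModel.from_pretrained('bert-base-uncased', config=config)
assert model.config.output_attentions == True
```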
@@ -270,7 +270,7 @@ class TransformerBlock(nn.Module):
     )
     if output_attentions:
         sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
-    else:  # To handle these `output_attention` or `output_hidden_states` cases returning tuples
+    else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
         assert type(sa_output) == tuple
         sa_output = sa_output[0]
     sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)
...
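The `else` branch touched here exists because the attention sublayer returns a tuple whose arity depends on the flag. A self-contained sketch of that unpacking pattern, with illustrative names (`attention_sublayer` is not the DistilBERT code, just the shape of it):

```python
import torch

# Illustrative stand-in for the attention sublayer: returns (context, weights)
# when the flag is set, and a 1-tuple otherwise, so the caller must handle both.
def attention_sublayer(x, output_attentions=False):
    weights = torch.softmax(x @ x.transpose(-1, -2), dim=-1)  # (bs, seq, seq)
    context = weights @ x                                     # (bs, seq, dim)
    return (context, weights) if output_attentions else (context,)

x = torch.randn(2, 5, 8)

sa_output = attention_sublayer(x, output_attentions=True)
sa_output, sa_weights = sa_output  # flag on: unpack output and weights

sa_output = attention_sublayer(x)
assert type(sa_output) == tuple    # flag off: still a tuple, as the assert above checks
sa_output = sa_output[0]
```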
@@ -196,7 +196,7 @@ class EncoderDecoderModel(PreTrainedModel):
     All remaning positional arguments will be passed to the underlying model's ``__init__`` method
     kwargs: (`optional`) Remaining dictionary of keyword arguments.
-        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``).
+        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``).
         - To update the encoder configuration, use the prefix `encoder_` for each configuration parameter
         - To update the decoder configuration, use the prefix `decoder_` for each configuration parameter
         - To update the parent model configuration, do not use a prefix for each configuration parameter
...
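The prefix routing described in those bullet points can be sketched independently of the library (all names below are illustrative, not the `EncoderDecoderModel` implementation):

```python
# Illustrative sketch of prefix-based kwarg routing: encoder_/decoder_
# prefixed keys go to the respective sub-config, the rest to the parent.
def split_by_prefix(**kwargs):
    encoder_kwargs = {k[len('encoder_'):]: v for k, v in kwargs.items()
                      if k.startswith('encoder_')}
    decoder_kwargs = {k[len('decoder_'):]: v for k, v in kwargs.items()
                      if k.startswith('decoder_')}
    parent_kwargs = {k: v for k, v in kwargs.items()
                     if not k.startswith(('encoder_', 'decoder_'))}
    return encoder_kwargs, decoder_kwargs, parent_kwargs

enc, dec, parent = split_by_prefix(encoder_output_attentions=True,
                                   decoder_num_layers=6,
                                   output_attentions=True)
assert enc == {'output_attentions': True}
assert dec == {'num_layers': 6}
assert parent == {'output_attentions': True}
```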
@@ -442,7 +442,7 @@ class TFAutoModel(object):
     Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
     kwargs: (`optional`) Remaining dictionary of keyword arguments:
-        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
+        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
         - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
         - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.TFPretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
@@ -451,8 +451,8 @@ class TFAutoModel(object):
     model = TFAutoModel.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = TFAutoModel.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    model = TFAutoModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
-    assert model.config.output_attention == True
+    model = TFAutoModel.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = TFAutoModel.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
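On the TF side the renamed flag behaves the same way. A hedged sketch, assuming the 3.x-era tuple outputs of `TFAutoModel` and network access:

```python
import tensorflow as tf
from transformers import BertTokenizer, TFAutoModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFAutoModel.from_pretrained('bert-base-uncased', output_attentions=True)
assert model.config.output_attentions == True

input_ids = tf.constant([tokenizer.encode("Hello world")])  # (1, seq_length)
outputs = model(input_ids)

# As in PyTorch, the per-layer attention tensors are appended to the
# output tuple when the flag is set.
attentions = outputs[-1]    # tuple of num_hidden_layers tensors
print(attentions[0].shape)  # (1, num_heads, seq_length, seq_length)
```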
@@ -588,7 +588,7 @@ class TFAutoModelForPreTraining(object):
     Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
     kwargs: (`optional`) Remaining dictionary of keyword arguments:
         Can be used to update the configuration object (after it being loaded) and initiate the model.
-        (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or
+        (e.g. ``output_attentions=True``). Behave differently depending on whether a `config` is provided or
         automatically loaded:
         - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
@@ -604,8 +604,8 @@ class TFAutoModelForPreTraining(object):
     model = TFAutoModelForPreTraining.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = TFAutoModelForPreTraining.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    model = TFAutoModelForPreTraining.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
-    assert model.config.output_attention == True
+    model = TFAutoModelForPreTraining.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = TFAutoModelForPreTraining.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -767,7 +767,7 @@ class TFAutoModelWithLMHead(object):
     Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
     kwargs: (`optional`) Remaining dictionary of keyword arguments:
-        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
+        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
         - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
         - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.TFPretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
@@ -776,8 +776,8 @@ class TFAutoModelWithLMHead(object):
     model = TFAutoModelWithLMHead.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = TFAutoModelWithLMHead.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    model = TFAutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
-    assert model.config.output_attention == True
+    model = TFAutoModelWithLMHead.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = TFAutoModelWithLMHead.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
@@ -921,7 +921,7 @@ class TFAutoModelForMultipleChoice:
     Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
     kwargs: (`optional`) Remaining dictionary of keyword arguments:
-        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
+        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
         - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
         - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.TFPretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
@@ -930,8 +930,8 @@ class TFAutoModelForMultipleChoice:
     model = TFAutoModelFormultipleChoice.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = TFAutoModelFormultipleChoice.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    model = TFAutoModelFormultipleChoice.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
-    assert model.config.output_attention == True
+    model = TFAutoModelFormultipleChoice.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = TFAutoModelFormultipleChoice.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
@@ -1068,7 +1068,8 @@ class TFAutoModelForCausalLM:
     model = TFAutoModelForCausalLM.from_pretrained('gpt2')  # Download model and configuration from S3 and cache.
     model = TFAutoModelForCausalLM.from_pretrained('./test/gpt2_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    assert model.config.output_attention == True
+    model = TFAutoModelForCausalLM.from_pretrained('gpt2', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/gpt2_tf_model_config.json')
     model = TFAutoModelForCausalLM.from_pretrained('./tf_model/gpt2_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -1208,9 +1209,10 @@ class TFAutoModelForMaskedLM:
     Examples::
-    model = TFAutoModelForMaskedLM.from_pretrained('bert')  # Download model and configuration from S3 and cache.
+    model = TFAutoModelForMaskedLM.from_pretrained(('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = TFAutoModelForMaskedLM.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    assert model.config.output_attention == True
+    model = TFAutoModelForMaskedLM.from_pretrained(('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = TFAutoModelForMaskedLM.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -1337,7 +1339,8 @@ class TFAutoModelForSeq2SeqLM:
     model = TFAutoModelForSeq2SeqLM.from_pretrained('t5-base')  # Download model and configuration from S3 and cache.
     model = TFAutoModelForSeq2SeqLM.from_pretrained('./test/t5_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    assert model.config.output_attention == True
+    model = TFAutoModelForSeq2SeqLM.from_pretrained('t5-base', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/t5_tf_model_config.json')
     model = TFAutoModelForSeq2SeqLM.from_pretrained('./tf_model/t5_tf_checkpoint.ckpt.index', from_tf=True, config=config)
@@ -1488,7 +1491,7 @@ class TFAutoModelForSequenceClassification(object):
     Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
     kwargs: (`optional`) Remaining dictionary of keyword arguments:
-        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
+        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
         - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
         - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.TFPretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
@@ -1497,8 +1500,8 @@ class TFAutoModelForSequenceClassification(object):
     model = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = TFAutoModelForSequenceClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    model = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
-    assert model.config.output_attention == True
+    model = TFAutoModelForSequenceClassification.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = TFAutoModelForSequenceClassification.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
@@ -1652,7 +1655,7 @@ class TFAutoModelForQuestionAnswering(object):
     Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
     kwargs: (`optional`) Remaining dictionary of keyword arguments:
-        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
+        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
         - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
         - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.TFPretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
@@ -1661,8 +1664,8 @@ class TFAutoModelForQuestionAnswering(object):
     model = TFAutoModelForQuestionAnswering.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = TFAutoModelForQuestionAnswering.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    model = TFAutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
-    assert model.config.output_attention == True
+    model = TFAutoModelForQuestionAnswering.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = TFAutoModelForQuestionAnswering.from_pretrained('./pt_model/bert_pytorch_model.bin', from_pt=True, config=config)
@@ -1785,7 +1788,7 @@ class TFAutoModelForTokenClassification:
     Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
     kwargs: (`optional`) Remaining dictionary of keyword arguments:
-        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
+        Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attentions=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
         - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
         - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.TFPretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
@@ -1794,8 +1797,8 @@ class TFAutoModelForTokenClassification:
     model = TFAutoModelForTokenClassification.from_pretrained('bert-base-uncased')  # Download model and configuration from S3 and cache.
     model = TFAutoModelForTokenClassification.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-    model = TFAutoModelForTokenClassification.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
-    assert model.config.output_attention == True
+    model = TFAutoModelForTokenClassification.from_pretrained('bert-base-uncased', output_attentions=True)  # Update configuration during loading
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower)
     config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
     model = TFAutoModelForTokenClassification.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
...
@@ -344,7 +344,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
     sa_output = self.attention(x, x, x, attn_mask, head_mask, output_attentions, training=training)
     if output_attentions:
         sa_output, sa_weights = sa_output  # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
-    else:  # To handle these `output_attention` or `output_hidden_states` cases returning tuples
+    else:  # To handle these `output_attentions` or `output_hidden_states` cases returning tuples
         # assert type(sa_output) == tuple
         sa_output = sa_output[0]
     sa_output = self.sa_layer_norm(sa_output + x)  # (bs, seq_length, dim)
...
@@ -486,7 +486,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
     kwargs (remaining dictionary of keyword arguments, `optional`):
         Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
-        :obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
+        :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
         automatically loaded:
         - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
@@ -506,8 +506,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
     # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
     model = TFBertModel.from_pretrained('./test/saved_model/')
     # Update configuration during loading.
-    model = TFBertModel.from_pretrained('bert-base-uncased', output_attention=True)
-    assert model.config.output_attention == True
+    model = TFBertModel.from_pretrained('bert-base-uncased', output_attentions=True)
+    assert model.config.output_attentions == True
     # Loading from a Pytorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
     config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
     model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)
...
@@ -732,7 +732,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
     our S3 (faster). Should be set to :obj:`False` for checkpoints larger than 20GB.
     kwargs (remaining dictionary of keyword arguments, `optional`):
         Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
-        :obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
+        :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
         automatically loaded:
         - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
@@ -752,8 +752,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
     # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
     model = BertModel.from_pretrained('./test/saved_model/')
     # Update configuration during loading.
-    model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)
-    assert model.config.output_attention == True
+    model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
+    assert model.config.output_attentions == True
     # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
     config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
     model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
...
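Finally, note that at this point in the library the flag can also be passed per forward call rather than baked into the config. A hedged sketch assuming the 3.x-era `BertModel` API (network access required):

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased')  # config flag left at False

input_ids = torch.tensor([tokenizer.encode("Hello world")])
outputs = model(input_ids, output_attentions=True)      # request attentions for this call only

attentions = outputs[-1]
assert len(attentions) == model.config.num_hidden_layers
```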