Unverified Commit 9a0a8c1c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

add examples to doc (#4045)

parent fa49b9af
...@@ -32,13 +32,30 @@ class EncoderDecoderConfig(PretrainedConfig): ...@@ -32,13 +32,30 @@ class EncoderDecoderConfig(PretrainedConfig):
and can be used to control the model outputs. and can be used to control the model outputs.
See the documentation for :class:`~transformers.PretrainedConfig` for more information. See the documentation for :class:`~transformers.PretrainedConfig` for more information.
Args:
Arguments: kwargs (`optional`):
kwargs: (`optional`) Remaining dictionary of keyword arguments. Notably: Remaining dictionary of keyword arguments. Notably:
encoder (:class:`PretrainedConfig`, optional, defaults to `None`): encoder (:class:`PretrainedConfig`, optional, defaults to `None`):
An instance of a configuration object that defines the encoder config. An instance of a configuration object that defines the encoder config.
encoder (:class:`PretrainedConfig`, optional, defaults to `None`): encoder (:class:`PretrainedConfig`, optional, defaults to `None`):
An instance of a configuration object that defines the decoder config. An instance of a configuration object that defines the decoder config.
Example::
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel
# Initializing a BERT bert-base-uncased style configuration
config_encoder = BertConfig()
config_decoder = BertConfig()
config = EncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
# Initializing a Bert2Bert model from the bert-base-uncased style configurations
model = EncoderDecoderModel(config=config)
# Accessing the model configuration
config_encoder = model.config.encoder
config_decoder = model.config.decoder
""" """
model_type = "encoder_decoder" model_type = "encoder_decoder"
......
...@@ -125,6 +125,8 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -125,6 +125,8 @@ class EncoderDecoderModel(PreTrainedModel):
Examples:: Examples::
from tranformers import EncoderDecoder
model = EncoderDecoder.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert model = EncoderDecoder.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
""" """
...@@ -230,6 +232,25 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -230,6 +232,25 @@ class EncoderDecoderModel(PreTrainedModel):
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function. - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function. - With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
Examples::
from transformers import EncoderDecoderModel, BertTokenizer
import torch
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
# forward
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
# training
loss, outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, lm_labels=input_ids)[:2]
# generation
generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
""" """
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")} kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment