Unverified Commit 4bafc43b authored by Xu Song's avatar Xu Song Committed by GitHub
Browse files

Fix param error (#9273)

TypeError: forward() got an unexpected keyword argument 'token_type_ids'
parent 58e8a761
......@@ -130,7 +130,7 @@ def load_tf_weights_in_bert_generation(
class BertGenerationEmbeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings."""
"""Construct the embeddings from word and position embeddings."""
def __init__(self, config):
super().__init__()
......@@ -468,7 +468,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
>>> config.is_decoder = True
>>> model = BertGenerationDecoder.from_pretrained('google/bert_for_seq_generation_L-24_bbc_encoder', config=config)
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> inputs = tokenizer("Hello, my dog is cute", return_token_type_ids=False, return_tensors="pt")
>>> outputs = model(**inputs)
>>> prediction_logits = outputs.logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment