Commit d9d387af authored by thomwolf's avatar thomwolf
Browse files

clean up

parent b7141a1b
......@@ -127,12 +127,8 @@ class PreTrainedSeq2seq(nn.Module):
decoder = decoder_model
else:
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
kwargs['is_decoder'] = True # Make sure the decoder will be an decoder
kwargs['is_decoder'] = True # Make sure the decoder will be a decoder
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
else:
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(decoder_pretrained_model_name_or_path))
model = cls(encoder, decoder)
return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment