"vscode:/vscode.git/clone" did not exist on "fa765202402499486efd1cb3484c5e70555479c2"
Commit d9d387af authored by thomwolf's avatar thomwolf
Browse files

clean up

parent b7141a1b
...@@ -127,12 +127,8 @@ class PreTrainedSeq2seq(nn.Module): ...@@ -127,12 +127,8 @@ class PreTrainedSeq2seq(nn.Module):
decoder = decoder_model decoder = decoder_model
else: else:
kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc... kwargs.update(decoder_kwargs) # Replace encoder kwargs with decoder specific kwargs like config, state_dict, etc...
kwargs['is_decoder'] = True # Make sure the decoder will be an decoder kwargs['is_decoder'] = True # Make sure the decoder will be a decoder
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs) decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs)
else:
raise ValueError("Unrecognized model identifier in {}. Should contains one of "
"'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
"'xlm', 'roberta'".format(decoder_pretrained_model_name_or_path))
model = cls(encoder, decoder) model = cls(encoder, decoder)
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment