Unverified Commit 607611f2 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

up (#13448)

parent 6b29bff8
...@@ -201,6 +201,7 @@ class EncoderDecoderMixin: ...@@ -201,6 +201,7 @@ class EncoderDecoderMixin:
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
enc_dec_model.save_pretrained(tmpdirname) enc_dec_model.save_pretrained(tmpdirname)
enc_dec_model = EncoderDecoderModel.from_pretrained(tmpdirname) enc_dec_model = EncoderDecoderModel.from_pretrained(tmpdirname)
enc_dec_model.to(torch_device)
after_outputs = enc_dec_model( after_outputs = enc_dec_model(
input_ids=input_ids, input_ids=input_ids,
...@@ -245,6 +246,7 @@ class EncoderDecoderMixin: ...@@ -245,6 +246,7 @@ class EncoderDecoderMixin:
encoder_pretrained_model_name_or_path=encoder_tmp_dirname, encoder_pretrained_model_name_or_path=encoder_tmp_dirname,
decoder_pretrained_model_name_or_path=decoder_tmp_dirname, decoder_pretrained_model_name_or_path=decoder_tmp_dirname,
) )
enc_dec_model.to(torch_device)
after_outputs = enc_dec_model( after_outputs = enc_dec_model(
input_ids=input_ids, input_ids=input_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment