Unverified Commit da5bb292 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

send model to the correct device (#18800)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent f1fd4606
......@@ -403,6 +403,7 @@ class EncoderDecoderMixin:
)
model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
model.to(torch_device)
model.train()
model.gradient_checkpointing_enable()
model.config.decoder_start_token_id = 0
......
......@@ -331,6 +331,7 @@ class EncoderDecoderMixin:
)
model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
model.to(torch_device)
model.train()
model.gradient_checkpointing_enable()
model.config.decoder_start_token_id = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment