Unverified commit c6c9db3d, authored by Yih-Dar and committed by GitHub
Browse files

Fix gradient checkpoint test in encoder-decoder (#20017)


Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent a6b77598
......@@ -618,8 +618,10 @@ class EncoderDecoderMixin:
)
model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
model.train()
model.to(torch_device)
model.gradient_checkpointing_enable()
model.train()
model.config.decoder_start_token_id = 0
model.config.pad_token_id = 0
......@@ -629,6 +631,8 @@ class EncoderDecoderMixin:
"labels": inputs_dict["labels"],
"decoder_input_ids": inputs_dict["decoder_input_ids"],
}
model_inputs = {k: v.to(torch_device) for k, v in model_inputs.items()}
loss = model(**model_inputs).loss
loss.backward()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment