Unverified commit aea7c5b0, authored by yssjtu, committed by GitHub

T5ForConditionalGeneration: enable using past_key_values together with labels in training (#13805)

* Enable using past_key_values together with labels when training T5ForConditionalGeneration

* test

* Enable past_key_values in T5ForConditionalGeneration while training.

* delete comments
parent dac77981
@@ -1593,15 +1593,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             # get decoder inputs from shifting lm labels to the right
             decoder_input_ids = self._shift_right(labels)
-        # If decoding with past key value states, only the last tokens
-        # should be given as an input
-        if past_key_values is not None:
-            assert labels is None, "Decoder should not use cached key value states when training."
-            if decoder_input_ids is not None:
-                decoder_input_ids = decoder_input_ids[:, -1:]
-            if decoder_inputs_embeds is not None:
-                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
         # Set device for model parallelism
         if self.model_parallel:
             torch.cuda.set_device(self.decoder.first_device)
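The hunk above touches two steps: deriving decoder inputs by shifting the labels right, and the now-removed truncation to the last token whenever past_key_values was supplied. The following is a simplified pure-Python sketch of both, offered as an illustration and not as the library's code. `shift_right` and `keep_last_token` are hypothetical helpers that work on nested lists instead of tensors, and the token ids, start token, and pad token are made-up values.

```python
from typing import List

IGNORE_INDEX = -100  # label value conventionally ignored by the loss


def shift_right(
    labels: List[List[int]], decoder_start_token_id: int, pad_token_id: int
) -> List[List[int]]:
    """Build decoder inputs from labels: prepend the start token, drop the last label."""
    shifted = []
    for row in labels:
        new_row = [decoder_start_token_id] + row[:-1]
        # Ignored label positions are not valid token ids, so map them to padding.
        shifted.append([pad_token_id if tok == IGNORE_INDEX else tok for tok in new_row])
    return shifted


def keep_last_token(decoder_input_ids: List[List[int]]) -> List[List[int]]:
    """Mimic the removed slicing `decoder_input_ids[:, -1:]` on nested lists."""
    return [row[-1:] for row in decoder_input_ids]


# Hypothetical batch with one sequence; -100 marks a position excluded from the loss.
labels = [[37, 423, -100, 1]]
decoder_input_ids = shift_right(labels, decoder_start_token_id=0, pad_token_id=0)
truncated = keep_last_token(decoder_input_ids)

print(decoder_input_ids)  # [[0, 37, 423, 0]]
print(truncated)          # [[0]]
```

Before this commit, supplying both labels and past_key_values tripped the assertion. With the block removed, the forward pass in this hunk no longer truncates, so the shifted sequence is passed on at full length.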