"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "50dd314d939a86f3a81e19af01459f449fbaeeca"
Unverified Commit afb50c66 authored by LSinev's avatar LSinev Committed by GitHub
Browse files

Fix GPT2DoubleHeadsModel to work with model.generate() (#6601)

* Fix passing token_type_ids during GPT2DoubleHeadsModel.generate() if used

and for GPT2LMHeadModel too

* Update tests to check token_type_ids usage in GPT2 models
parent 04d8136b
...@@ -144,6 +144,10 @@ class GenerationMixin: ...@@ -144,6 +144,10 @@ class GenerationMixin:
) )
input_ids = input_ids.index_select(0, expanded_return_idx) input_ids = input_ids.index_select(0, expanded_return_idx)
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
if attention_mask is not None: if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
...@@ -194,6 +198,11 @@ class GenerationMixin: ...@@ -194,6 +198,11 @@ class GenerationMixin:
else: else:
model_kwargs["past"] = None model_kwargs["past"] = None
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
# update attention mask # update attention mask
if not is_encoder_decoder: if not is_encoder_decoder:
if "attention_mask" in model_kwargs: if "attention_mask" in model_kwargs:
......
...@@ -708,9 +708,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -708,9 +708,12 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
return self.lm_head return self.lm_head
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs # only last token for inputs_ids if past is defined in kwargs
if past: if past:
input_ids = input_ids[:, -1].unsqueeze(-1) input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None) attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None) position_ids = kwargs.get("position_ids", None)
...@@ -729,6 +732,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ...@@ -729,6 +732,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
"use_cache": kwargs.get("use_cache"), "use_cache": kwargs.get("use_cache"),
"position_ids": position_ids, "position_ids": position_ids,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"token_type_ids": token_type_ids,
} }
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
...@@ -836,14 +840,32 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ...@@ -836,14 +840,32 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
return self.lm_head return self.lm_head
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs # only last token for inputs_ids if past is defined in kwargs
if past: if past:
input_ids = input_ids[:, -1].unsqueeze(-1) input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"past_key_values": past, "past_key_values": past,
"use_cache": kwargs.get("use_cache"), "use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
} }
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
......
...@@ -469,10 +469,84 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -469,10 +469,84 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
] ]
inputs = tokenizer(sentences, return_tensors="pt", padding=True) inputs = tokenizer(sentences, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"].to(torch_device)
token_type_ids = torch.cat(
[
input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0),
input_ids.new_full((input_ids.shape[0], 1), 500),
],
dim=-1,
)
outputs = model.generate(
input_ids=input_ids,
attention_mask=inputs["attention_mask"].to(torch_device),
)
outputs_tt = model.generate(
input_ids=input_ids,
attention_mask=inputs["attention_mask"].to(torch_device),
token_type_ids=token_type_ids,
)
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
output_non_padded = model.generate(input_ids=inputs_non_padded)
num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
expected_output_sentence = [
"Hello, my dog is a little bit of a mess. I'm not sure if he's going",
"Today, I'm going to be doing a lot of research on this. I",
]
self.assertListEqual(expected_output_sentence, batch_out_sentence)
self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
@slow
def test_batch_generation_2heads(self):
model = GPT2DoubleHeadsModel.from_pretrained("gpt2")
model.to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.padding_side = "left"
# This tokenizer has no pad token, so we have to set it in some way
# Define PAD Token = EOS Token = 50256
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
# use different length sentences to test batching
sentences = [
"Hello, my dog is a little",
"Today, I",
]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"].to(torch_device)
token_type_ids = torch.cat(
[
input_ids.new_full((input_ids.shape[0], input_ids.shape[1] - 1), 0),
input_ids.new_full((input_ids.shape[0], 1), 500),
],
dim=-1,
)
outputs = model.generate( outputs = model.generate(
input_ids=inputs["input_ids"].to(torch_device), input_ids=input_ids,
attention_mask=inputs["attention_mask"].to(torch_device),
)
outputs_tt = model.generate(
input_ids=input_ids,
attention_mask=inputs["attention_mask"].to(torch_device), attention_mask=inputs["attention_mask"].to(torch_device),
token_type_ids=token_type_ids,
) )
inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
...@@ -483,6 +557,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -483,6 +557,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings) output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True) batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
batch_out_sentence_tt = tokenizer.batch_decode(outputs_tt, skip_special_tokens=True)
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True) non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True) padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
...@@ -491,6 +566,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -491,6 +566,7 @@ class GPT2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
"Today, I'm going to be doing a lot of research on this. I", "Today, I'm going to be doing a lot of research on this. I",
] ]
self.assertListEqual(expected_output_sentence, batch_out_sentence) self.assertListEqual(expected_output_sentence, batch_out_sentence)
self.assertTrue(batch_out_sentence_tt != batch_out_sentence) # token_type_ids should change output
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence]) self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
@slow @slow
...@@ -540,11 +616,23 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase): ...@@ -540,11 +616,23 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
model.to(torch_device) model.to(torch_device)
torch.manual_seed(0) torch.manual_seed(0)
input_ids = tokenizer("Today is a nice day and", return_tensors="pt").input_ids.to(torch_device) tokenized = tokenizer("Today is a nice day and", return_tensors="pt", return_token_type_ids=True)
input_ids = tokenized.input_ids.to(torch_device)
output_ids = model.generate(input_ids, do_sample=True) output_ids = model.generate(input_ids, do_sample=True)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True) output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)
token_type_ids = tokenized.token_type_ids.to(torch_device)
output_seq = model.generate(input_ids=input_ids, do_sample=True, num_return_sequences=5)
output_seq_tt = model.generate(
input_ids=input_ids, token_type_ids=token_type_ids, do_sample=True, num_return_sequences=5
)
output_seq_strs = tokenizer.batch_decode(output_seq, skip_special_tokens=True)
output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
EXPECTED_OUTPUT_STR = ( EXPECTED_OUTPUT_STR = (
"Today is a nice day and if you don't know anything about the state of play during your holiday" "Today is a nice day and if you don't know anything about the state of play during your holiday"
) )
self.assertEqual(output_str, EXPECTED_OUTPUT_STR) self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
self.assertTrue(
all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
) # token_type_ids should change output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment