Unverified commit fb36c273, authored by guillaume-be and committed by GitHub
Browse files

Allow text generation for ProphetNetForCausalLM (#9707)

* Moved ProphetNetForCausalLM's parent initialization after config update

* Added unit tests for generation for ProphetNetForCausalLM
parent 910aa896
...@@ -1883,11 +1883,11 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): ...@@ -1883,11 +1883,11 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
) )
class ProphetNetForCausalLM(ProphetNetPreTrainedModel): class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
def __init__(self, config): def __init__(self, config):
super().__init__(config)
# set config for CLM # set config for CLM
config = copy.deepcopy(config) config = copy.deepcopy(config)
config.is_decoder = True config.is_decoder = True
config.is_encoder_decoder = False config.is_encoder_decoder = False
super().__init__(config)
self.prophetnet = ProphetNetDecoderWrapper(config) self.prophetnet = ProphetNetDecoderWrapper(config)
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
......
...@@ -302,6 +302,24 @@ class ProphetNetModelTester: ...@@ -302,6 +302,24 @@ class ProphetNetModelTester:
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True) output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache)) self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
def create_and_check_decoder_generate_with_past_key_value_states(
    self,
    config,
    input_ids,
    decoder_input_ids,
    attention_mask,
    decoder_attention_mask,
    lm_labels,
):
    """Check that the past-key-value cache does not change decoder-only generation.

    Samples twice from the same RNG seed with ``ProphetNetForCausalLM`` — once
    with the cache disabled and once with it enabled (the default) — and asserts
    that both runs produce exactly the same token sequence.
    """
    causal_lm = ProphetNetForCausalLM(config=config).to(torch_device).eval()
    prompt = input_ids[:1]

    def _seeded_generate(**extra_kwargs):
        # Re-seed before each call so do_sample=True draws identical tokens.
        torch.manual_seed(0)
        return causal_lm.generate(prompt, num_beams=2, max_length=10, do_sample=True, **extra_kwargs)

    tokens_without_cache = _seeded_generate(use_cache=False)
    tokens_with_cache = _seeded_generate()
    self.parent.assertTrue(torch.all(tokens_with_cache == tokens_without_cache))
def create_and_check_model_fp16_forward( def create_and_check_model_fp16_forward(
self, self,
config, config,
...@@ -911,6 +929,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test ...@@ -911,6 +929,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs) self.model_tester.create_and_check_generate_with_past_key_value_states(*config_and_inputs)
def test_encoder_decoder_model_generate(self):
    # Exercises text generation for the decoder-only ProphetNetForCausalLM path
    # (via the tester's decoder generate-with-cache check).
    # NOTE(review): the name says "encoder_decoder" but the checker invoked is the
    # decoder/causal-LM one — looks like a copy-paste name. If another method in
    # this class already uses this exact name, this definition silently shadows
    # it; confirm and consider renaming (e.g. test_decoder_model_generate).
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_decoder_generate_with_past_key_value_states(*config_and_inputs)
def test_attn_mask_model(self): def test_attn_mask_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_model_with_attn_mask(*config_and_inputs) self.model_tester.check_model_with_attn_mask(*config_and_inputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment