Commit c503a1c1 (unverified), authored by Patrick von Platen, committed by GitHub
Browse files

[ProphetNet] Bart-like Refactor (#10501)

* first step to refactor

* make all fast tests pass

* make all slow tests pass

* save intermediate

* correct cache

* finish PR

* make fp16 work
parent 6290169e
...@@ -92,6 +92,8 @@ class ProphetNetConfig(PretrainedConfig): ...@@ -92,6 +92,8 @@ class ProphetNetConfig(PretrainedConfig):
smoothing is performed. smoothing is performed.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models). Whether or not the model should return the last key/values attentions (not used by all models).
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of a slower backward pass.
""" """
model_type = "prophetnet" model_type = "prophetnet"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
...@@ -119,6 +121,7 @@ class ProphetNetConfig(PretrainedConfig): ...@@ -119,6 +121,7 @@ class ProphetNetConfig(PretrainedConfig):
num_buckets=32, num_buckets=32,
relative_max_distance=128, relative_max_distance=128,
disable_ngram_loss=False, disable_ngram_loss=False,
gradient_checkpointing=False,
eps=0.0, eps=0.0,
use_cache=True, use_cache=True,
pad_token_id=0, pad_token_id=0,
...@@ -161,6 +164,9 @@ class ProphetNetConfig(PretrainedConfig): ...@@ -161,6 +164,9 @@ class ProphetNetConfig(PretrainedConfig):
self.use_cache = use_cache self.use_cache = use_cache
# 4 Training Args (should be removed soon)
self.gradient_checkpointing = gradient_checkpointing
@property @property
def num_attention_heads(self) -> int: def num_attention_heads(self) -> int:
return self.num_encoder_attention_heads return self.num_encoder_attention_heads
......
...@@ -243,7 +243,7 @@ class ProphetNetModelTester: ...@@ -243,7 +243,7 @@ class ProphetNetModelTester:
# There should be `num_layers` key value embeddings stored in decoder_past # There should be `num_layers` key value embeddings stored in decoder_past
self.parent.assertEqual(len(decoder_past), config.num_decoder_layers) self.parent.assertEqual(len(decoder_past), config.num_decoder_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
self.parent.assertEqual(len(decoder_past[0]), 2) # cross-attention + uni-directional self-attention self.parent.assertEqual(len(decoder_past[0]), 4) # cross-attention + uni-directional self-attention
def create_and_check_with_lm_head( def create_and_check_with_lm_head(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment