Unverified commit 80468251, authored by Dan Jones and committed by GitHub

Change BartLearnedPositionalEmbedding's forward method signature to support Opacus training (#18486)

* changing BartLearnedPositionalEmbedding forward signature and references to it

* removing debugging dead code (thanks style checker)

* blackened modeling_bart file

* removing copy inconsistencies via make fix-copies

* changing references to copied signatures in Bart variants

* make fix-copies once more

* using expand over repeat (thanks @michaelbenayoun)

* expand instead of repeat for all model copies
Co-authored-by: Daniel Jones <jonesdaniel@microsoft.com>
parent 3f0707b2
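
For context on why the signature matters: Opacus computes per-sample gradients by attaching hooks to standard `nn.Module` layers such as `nn.Embedding`, and those hooks only see values that actually flow through `forward()`. With the old signature the positional embedding received a `torch.Size`, so no tensor passed through the layer in the autograd graph and the hooks had nothing to record; passing the `input_ids` tensor keeps the layer inside the traced computation. Below is a minimal sketch of the per-sample gradient check this change enables, assuming recent `opacus` and `transformers` releases; the tiny config and shapes are illustrative only, and tied embedding weights may need extra handling in practice:

    import torch
    from opacus import GradSampleModule
    from transformers import BartConfig, BartForConditionalGeneration

    # A tiny randomly initialized BART so the sketch runs quickly.
    config = BartConfig(
        vocab_size=1000, d_model=64, max_position_embeddings=64,
        encoder_layers=1, decoder_layers=1,
        encoder_attention_heads=2, decoder_attention_heads=2,
        encoder_ffn_dim=128, decoder_ffn_dim=128,
    )
    model = GradSampleModule(BartForConditionalGeneration(config))

    input_ids = torch.randint(0, 1000, (4, 16))  # [bsz, seqlen]
    labels = torch.randint(0, 1000, (4, 16))
    model(input_ids=input_ids, labels=labels).loss.backward()

    # With the tensor-based signature, hooked parameters carry a
    # .grad_sample tensor with one gradient slice per batch example.
    for name, p in model.named_parameters():
        grad_sample = getattr(p, "grad_sample", None)
        if grad_sample is not None:
            assert grad_sample.shape[0] == 4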
@@ -128,12 +128,14 @@ class BartLearnedPositionalEmbedding(nn.Embedding):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
 
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)
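
A note on the `expand`-over-`repeat` suggestion credited to @michaelbenayoun in the commit message: both calls produce a [bsz, seq_len] position matrix, but `Tensor.expand` returns a broadcast view with stride 0 along the batch dimension, so no per-row copies are materialized, whereas `repeat` would allocate bsz copies. A quick standalone illustration (variable names are mine, not from the patch):

    import torch

    bsz, seq_len = 4, 16
    positions = torch.arange(seq_len, dtype=torch.long)

    expanded = positions.expand(bsz, -1)  # broadcast view, no data copied
    repeated = positions.repeat(bsz, 1)   # materializes bsz copies of the row

    assert expanded.shape == repeated.shape == (bsz, seq_len)
    assert expanded.stride(0) == 0        # all rows alias the same storage
    assert repeated.stride(0) == seq_len  # each row has its own storage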
@@ -788,17 +790,17 @@ class BartEncoder(BartPretrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input_ids.shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
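
One subtlety in the hunk above: the `inputs_embeds[:, :, -1]` slice is a shape carrier, not a value. It takes the last hidden channel purely to obtain a [bsz, seqlen] tensor on the right device, from which the positional embedding reads only `.shape[:2]`. A toy demonstration (dimensions made up):

    import torch

    bsz, seq_len, hidden = 2, 8, 32
    inputs_embeds = torch.randn(bsz, seq_len, hidden)

    # The slice keeps only the batch/sequence geometry; its values are
    # never consumed by the learned positional embedding.
    dummy = inputs_embeds[:, :, -1]
    assert dummy.shape == (bsz, seq_len)
    assert dummy.device == inputs_embeds.device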
@@ -1015,10 +1017,12 @@ class BartDecoder(BartPretrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
@@ -1026,7 +1030,7 @@ class BartDecoder(BartPretrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input) * self.embed_scale
 
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1038,7 +1042,7 @@ class BartDecoder(BartPretrainedModel):
         encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
 
         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
...
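
The hunks that follow apply the same signature change to the Bart-derived copies; per the commit message, they were propagated to MBart, MVP, PLBart, and TrOCR with make fix-copies.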
@@ -134,12 +134,14 @@ class MBartLearnedPositionalEmbedding(nn.Embedding):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
 
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)
@@ -783,17 +785,18 @@ class MBartEncoder(MBartPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1013,10 +1016,12 @@ class MBartDecoder(MBartPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.size()
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
@@ -1036,7 +1041,7 @@ class MBartDecoder(MBartPreTrainedModel):
         encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
 
         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
...
@@ -134,12 +134,14 @@ class MvpLearnedPositionalEmbedding(nn.Embedding):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
 
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)
@@ -895,17 +897,19 @@ class MvpEncoder(MvpPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -1144,10 +1148,12 @@ class MvpDecoder(MvpPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input_ids.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
@@ -1167,7 +1173,7 @@ class MvpDecoder(MvpPreTrainedModel):
         encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
 
         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
...
@@ -131,12 +131,14 @@ class PLBartLearnedPositionalEmbedding(nn.Embedding):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
 
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)
@@ -759,17 +761,17 @@ class PLBartEncoder(PLBartPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input_ids.shape[-1])
         elif inputs_embeds is not None:
-            input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
-        embed_pos = self.embed_positions(input_shape)
+        embed_pos = self.embed_positions(input)
 
         hidden_states = inputs_embeds + embed_pos
         hidden_states = self.layernorm_embedding(hidden_states)
@@ -987,10 +989,12 @@ class PLBartDecoder(PLBartPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
+            input = input_ids
+            input_shape = input.shape
             input_ids = input_ids.view(-1, input_shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
@@ -998,7 +1002,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input) * self.embed_scale
 
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -1010,7 +1014,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
         encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
 
         # embed positions
-        positions = self.embed_positions(input_shape, past_key_values_length)
+        positions = self.embed_positions(input, past_key_values_length)
 
         hidden_states = inputs_embeds + positions
         hidden_states = self.layernorm_embedding(hidden_states)
...
@@ -87,12 +87,14 @@ class TrOCRLearnedPositionalEmbedding(nn.Embedding):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
 
-    def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0):
-        """`input_ids_shape` is expected to be [bsz x seqlen]."""
-        bsz, seq_len = input_ids_shape[:2]
+    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
+        """`input_ids' shape is expected to be [bsz x seqlen]."""
+
+        bsz, seq_len = input_ids.shape[:2]
         positions = torch.arange(
             past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
-        )
+        ).expand(bsz, -1)
+
         return super().forward(positions + self.offset)
@@ -626,10 +628,11 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
         elif input_ids is not None:
-            input_shape = input_ids.size()
-            input_ids = input_ids.view(-1, input_shape[-1])
+            input = input_ids
+            input_ids = input_ids.view(-1, input.shape[-1])
         elif inputs_embeds is not None:
             input_shape = inputs_embeds.size()[:-1]
+            input = inputs_embeds[:, :, -1]
         else:
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
@@ -640,7 +643,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
 
         if self.config.use_learned_position_embeddings:
-            embed_pos = self.embed_positions(input_shape, past_key_values_length=past_key_values_length)
+            embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
         else:
             embed_pos = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
@@ -651,6 +654,8 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
+        input_shape = input.shape
+
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
         )
...