"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d3046dad809b7b10019b142ae20b49fb58d21c28"
Unverified commit 38a4bf79, authored by Raushan Turganbay and committed by GitHub

Encoder-decoder models: move embedding scale to nn.Module (#30410)



* move scaling to nn.Module

* let the test be here for now (need to fix)

* failing tests

* last failing models

* Revert commit 4c14817f38

* clean-up

* oops forgot

* codestyle

* raise NotImplemented when possible

* Update tests/test_modeling_common.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* skip tests in respective modeling files

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 9d31b32e
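The pattern applied across every file in the diff below is the same: instead of multiplying the embedding output by `embed_scale` at each call site in `forward`, the scale is stored on a small `nn.Embedding` subclass that applies it inside its own `forward`. The following is a minimal, self-contained sketch of that idea, not the exact classes from the diff; the names `ScaledWordEmbedding`, `vocab_size`, `d_model`, and `pad_id` are illustrative. It checks that the new module-level scaling matches the old `nn.Embedding(...) * scale` path.

import torch
from torch import nn


class ScaledWordEmbedding(nn.Embedding):
    """nn.Embedding that multiplies its output by a constant scale (illustrative sketch)."""

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.embed_scale = embed_scale  # plain float attribute, not a parameter or buffer

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return super().forward(input_ids) * self.embed_scale


# Old call-site scaling vs. new module-level scaling: both give identical embeddings.
vocab_size, d_model, pad_id = 100, 16, 0
scale = d_model**0.5

scaled = ScaledWordEmbedding(vocab_size, d_model, pad_id, embed_scale=scale)
plain = nn.Embedding(vocab_size, d_model, pad_id)
plain.weight = scaled.weight  # share weights so the two outputs are comparable

input_ids = torch.tensor([[1, 2, 3]])
assert torch.equal(scaled(input_ids), plain(input_ids) * scale)

Because `embed_scale` defaults to 1.0, models configured with `scale_embedding=False` go through the same code path and simply multiply by 1.0.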
src/transformers/models/bart/modeling_bart.py

@@ -132,6 +132,19 @@ class BartLearnedPositionalEmbedding(nn.Embedding):
         return super().forward(positions + self.offset)

+class BartScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 class BartAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -1056,9 +1069,11 @@ class BartEncoder(BartPreTrainedModel):
         embed_dim = config.d_model
         self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = BartScaledWordEmbedding(
+            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -1146,7 +1161,7 @@ class BartEncoder(BartPreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         embed_pos = self.embed_positions(input)
         embed_pos = embed_pos.to(inputs_embeds.device)

@@ -1238,9 +1253,11 @@ class BartDecoder(BartPreTrainedModel):
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = config.pad_token_id
         self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.embed_tokens = BartScaledWordEmbedding(
+            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -1369,7 +1386,7 @@ class BartDecoder(BartPreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input)

         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
......
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py

@@ -90,6 +90,20 @@ class BigBirdPegasusLearnedPositionalEmbedding(nn.Embedding):
         return super().forward(positions)

+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BigBirdPegasus
+class BigBirdPegasusScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 # Copied from transformers.models.big_bird.modeling_big_bird.BigBirdSelfAttention with BigBird->BigBirdPegasus
 class BigBirdPegasusSelfAttention(nn.Module):
     def __init__(self, config):

@@ -1749,9 +1763,11 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
         embed_dim = config.d_model
         self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = BigBirdPegasusScaledWordEmbedding(
+            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -1827,7 +1843,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         embed_pos = self.embed_positions(input_shape)

@@ -2042,9 +2058,11 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = config.pad_token_id
         self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.embed_tokens = BigBirdPegasusScaledWordEmbedding(
+            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -2168,7 +2186,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length

@@ -2292,7 +2310,10 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
         super().__init__(config)

         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
-        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        self.shared = BigBirdPegasusScaledWordEmbedding(
+            vocab_size, config.d_model, padding_idx, embed_scale=embed_scale
+        )

         self.encoder = BigBirdPegasusEncoder(config, self.shared)
         self.decoder = BigBirdPegasusDecoder(config, self.shared)
......
src/transformers/models/biogpt/modeling_biogpt.py

@@ -75,6 +75,20 @@ class BioGptLearnedPositionalEmbedding(nn.Embedding):
         return super().forward(positions + self.offset)

+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BioGpt
+class BioGptScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BioGpt
 class BioGptAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -423,9 +437,11 @@ class BioGptModel(BioGptPreTrainedModel):
         self.dropout = config.hidden_dropout_prob
         self.embed_dim = config.hidden_size
         self.padding_idx = config.pad_token_id
-        self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, self.embed_dim, self.padding_idx)
+        self.embed_tokens = BioGptScaledWordEmbedding(
+            config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
+        )

         self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
         self.layers = nn.ModuleList([BioGptDecoderLayer(config) for _ in range(config.num_hidden_layers)])

@@ -482,7 +498,7 @@ class BioGptModel(BioGptPreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input)

         if attention_mask is None:
             attention_mask = torch.ones(
......
src/transformers/models/blenderbot/modeling_blenderbot.py

@@ -90,6 +90,20 @@ class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
         return super().forward(positions)

+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->Blenderbot
+class BlenderbotScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot
 class BlenderbotAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -632,12 +646,14 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
         embed_dim = config.d_model
         self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

         if embed_tokens is not None:
             self.embed_tokens = embed_tokens
         else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens = BlenderbotScaledWordEmbedding(
+                config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+            )

         self.embed_positions = BlenderbotLearnedPositionalEmbedding(
             config.max_position_embeddings,

@@ -715,7 +731,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         embed_pos = self.embed_positions(input_shape)

@@ -799,12 +815,14 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = config.pad_token_id
         self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

         if embed_tokens is not None:
             self.embed_tokens = embed_tokens
         else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+            self.embed_tokens = BlenderbotScaledWordEmbedding(
+                config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+            )

         self.embed_positions = BlenderbotLearnedPositionalEmbedding(
             config.max_position_embeddings,

@@ -926,7 +944,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
......
src/transformers/models/bridgetower/modeling_bridgetower.py

@@ -1325,6 +1325,11 @@ class BridgeTowerModel(BridgeTowerPreTrainedModel):
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None

+        if inputs_embeds is not None and input_ids is None:
+            raise NotImplementedError(
+                "BridgeTowerModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
+            )
+
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         image_token_type_idx = image_token_type_idx if image_token_type_idx else 1
         input_shape = input_ids.size()
......
src/transformers/models/funnel/modeling_funnel.py

@@ -972,8 +972,7 @@ class FunnelBaseModel(FunnelPreTrainedModel):
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

         # TODO: deal with head_mask
-        if inputs_embeds is None:
-            inputs_embeds = self.embeddings(input_ids)
+        inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)

         encoder_outputs = self.encoder(
             inputs_embeds,

@@ -1048,8 +1047,7 @@ class FunnelModel(FunnelPreTrainedModel):
             token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

         # TODO: deal with head_mask
-        if inputs_embeds is None:
-            inputs_embeds = self.embeddings(input_ids)
+        inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)

         encoder_outputs = self.encoder(
             inputs_embeds,
......
src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py

@@ -920,6 +920,10 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
         device = self.position_embeddings.weight.device
         if input_ids is None:
             input_ids = torch.zeros([1, 1]).int().to(device)  # dummy for input_ids was None
+        if inputs_embeds is not None:
+            raise NotImplementedError(
+                "GPTSanJapaneseModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
+            )
         num_pasts_contexts = 0
         num_batch = input_ids.shape[0]
         pasts_or_spout_value = None
......
src/transformers/models/m2m_100/modeling_m2m_100.py

@@ -87,6 +87,20 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
     return incremental_indices.long() + padding_idx

+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->M2M100
+class M2M100ScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 class M2M100SinusoidalPositionalEmbedding(nn.Module):
     """This module produces sinusoidal positional embeddings of any length."""

@@ -886,9 +900,11 @@ class M2M100Encoder(M2M100PreTrainedModel):
         embed_dim = config.d_model
         self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = M2M100ScaledWordEmbedding(
+            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -971,7 +987,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         embed_pos = self.embed_positions(input_ids, inputs_embeds)
         embed_pos = embed_pos.to(inputs_embeds.device)

@@ -1061,9 +1077,11 @@ class M2M100Decoder(M2M100PreTrainedModel):
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = config.pad_token_id
         self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.embed_tokens = M2M100ScaledWordEmbedding(
+            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -1183,7 +1201,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         if self._use_flash_attention_2:
             # 2d mask is passed through the layers

@@ -1321,7 +1339,8 @@ class M2M100Model(M2M100PreTrainedModel):
         super().__init__(config)

         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
-        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        self.shared = M2M100ScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)

         self.encoder = M2M100Encoder(config, self.shared)
         self.decoder = M2M100Decoder(config, self.shared)
......
src/transformers/models/mbart/modeling_mbart.py

@@ -118,6 +118,20 @@ class MBartLearnedPositionalEmbedding(nn.Embedding):
         return super().forward(positions + self.offset)

+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->MBart
+class MBartScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart
 class MBartAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -919,9 +933,11 @@ class MBartEncoder(MBartPreTrainedModel):
         embed_dim = config.d_model
         self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = MBartScaledWordEmbedding(
+            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -1009,7 +1025,7 @@ class MBartEncoder(MBartPreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         embed_pos = self.embed_positions(input)

@@ -1097,9 +1113,11 @@ class MBartDecoder(MBartPreTrainedModel):
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = config.pad_token_id
         self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.embed_tokens = MBartScaledWordEmbedding(
+            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -1227,7 +1245,7 @@ class MBartDecoder(MBartPreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         if self._use_flash_attention_2:
             # 2d mask is passed through the layers
......
src/transformers/models/nllb_moe/modeling_nllb_moe.py

@@ -133,6 +133,20 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T
     return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)

+# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->NllbMoe
+class NllbMoeScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
 class NllbMoeSinusoidalPositionalEmbedding(nn.Module):
     """This module produces sinusoidal positional embeddings of any length."""

@@ -992,9 +1006,11 @@ class NllbMoeEncoder(NllbMoePreTrainedModel):
         embed_dim = config.d_model
         self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = NllbMoeScaledWordEmbedding(
+            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -1085,7 +1101,7 @@ class NllbMoeEncoder(NllbMoePreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         embed_pos = self.embed_positions(input_ids, inputs_embeds)
         embed_pos = embed_pos.to(inputs_embeds.device)

@@ -1178,9 +1194,11 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = config.pad_token_id
         self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.embed_tokens = NllbMoeScaledWordEmbedding(
+            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -1309,7 +1327,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         # create causal mask
         # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]

@@ -1458,7 +1476,8 @@ class NllbMoeModel(NllbMoePreTrainedModel):
         super().__init__(config)

         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
-        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        self.shared = NllbMoeScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)

         self.encoder = NllbMoeEncoder(config, self.shared)
         self.decoder = NllbMoeDecoder(config, self.shared)
......
src/transformers/models/pegasus_x/modeling_pegasus_x.py

@@ -87,6 +87,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
     return shifted_input_ids

+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PegasusX
+class PegasusXScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 class PegasusXSinusoidalPositionalEmbedding(nn.Module):
     """This module produces sinusoidal positional embeddings of any length."""

@@ -880,13 +894,16 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
         self.layerdrop = config.encoder_layerdrop

         embed_dim = config.d_model
+        padding_idx = config.pad_token_id
         self.max_source_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

         if embed_tokens is not None:
             self.embed_tokens = embed_tokens
         else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
+            self.embed_tokens = PegasusXScaledWordEmbedding(
+                config.vocab_size, embed_dim, padding_idx, embed_scale=embed_scale
+            )

         self.embed_global = nn.Embedding(config.num_global_tokens, embed_dim)
         self.embed_positions = PegasusXSinusoidalPositionalEmbedding(embed_dim)

@@ -988,7 +1005,7 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         embed_pos = self.embed_positions(inputs_embeds)

@@ -1086,12 +1103,15 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
         self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        padding_idx = config.pad_token_id

         if embed_tokens is not None:
             self.embed_tokens = embed_tokens
         else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
+            self.embed_tokens = PegasusXScaledWordEmbedding(
+                config.vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale
+            )

         self.embed_positions = PegasusXSinusoidalPositionalEmbedding(config.d_model)
         self.layers = nn.ModuleList([PegasusXDecoderLayer(config) for _ in range(config.decoder_layers)])

@@ -1196,7 +1216,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length

@@ -1307,7 +1327,11 @@ class PegasusXModel(PegasusXPreTrainedModel):
         super().__init__(config)

         vocab_size = config.vocab_size
-        self.shared = nn.Embedding(vocab_size, config.d_model)
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        padding_idx = config.pad_token_id
+        self.shared = PegasusXScaledWordEmbedding(
+            vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale
+        )

         self.encoder = PegasusXEncoder(config, self.shared)
         self.decoder = PegasusXDecoder(config, self.shared)
......
src/transformers/models/plbart/modeling_plbart.py

@@ -102,6 +102,20 @@ class PLBartLearnedPositionalEmbedding(nn.Embedding):
         return super().forward(positions + self.offset)

+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PLBart
+class PLBartScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PLBart
 class PLBartAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -658,9 +672,11 @@ class PLBartEncoder(PLBartPreTrainedModel):
         embed_dim = config.d_model
         self.padding_idx = config.pad_token_id
         self.max_source_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+        self.embed_tokens = PLBartScaledWordEmbedding(
+            config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -748,7 +764,7 @@ class PLBartEncoder(PLBartPreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         embed_pos = self.embed_positions(input)
         embed_pos = embed_pos.to(inputs_embeds.device)

@@ -841,9 +857,11 @@ class PLBartDecoder(PLBartPreTrainedModel):
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = config.pad_token_id
         self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+        self.embed_tokens = PLBartScaledWordEmbedding(
+            config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+        )

         if embed_tokens is not None:
             self.embed_tokens.weight = embed_tokens.weight

@@ -972,7 +990,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input)

         if self._use_flash_attention_2:
             # 2d mask is passed through the layers

@@ -1122,7 +1140,8 @@ class PLBartModel(PLBartPreTrainedModel):
         super().__init__(config)

         padding_idx, vocab_size = config.pad_token_id, config.vocab_size
-        self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)

         self.encoder = PLBartEncoder(config, self.shared)
         self.decoder = PLBartDecoder(config, self.shared)
......
src/transformers/models/seamless_m4t/modeling_seamless_m4t.py

@@ -989,6 +989,20 @@ class SeamlessM4TConformerAdapter(nn.Module):
 ############ TEXT / UNITS related code ################

+# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4T
+class SeamlessM4TScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
 class SeamlessM4TSinusoidalPositionalEmbedding(nn.Module):
     """This module produces sinusoidal positional embeddings of any length."""

@@ -1631,9 +1645,11 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
         self.max_source_positions = config.max_position_embeddings

         if not self.is_t2u_encoder:
-            self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+            embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens = SeamlessM4TScaledWordEmbedding(
+                config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+            )

             if embed_tokens is not None:
                 self.embed_tokens.weight = embed_tokens.weight

@@ -1726,7 +1742,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         if not self.is_t2u_encoder:
             embed_pos = self.embed_positions(input)

@@ -1809,14 +1825,18 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

         if embed_tokens is not None:
             # if embed_tokens defined, use its shape instead
-            self.embed_tokens = nn.Embedding(embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx)
+            self.embed_tokens = SeamlessM4TScaledWordEmbedding(
+                embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale
+            )
             self.embed_tokens.weight = embed_tokens.weight
         else:
-            self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
+            self.embed_tokens = SeamlessM4TScaledWordEmbedding(
+                self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
+            )

         self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding(
             self.max_target_positions,

@@ -1935,7 +1955,7 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
......
src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py

@@ -946,6 +946,20 @@ class SeamlessM4Tv2ConformerAdapter(nn.Module):
 ############ TEXT / UNITS related code ################

+# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4Tv2
+class SeamlessM4Tv2ScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 # Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
 class SeamlessM4Tv2SinusoidalPositionalEmbedding(nn.Module):
     """This module produces sinusoidal positional embeddings of any length."""

@@ -1753,9 +1767,11 @@ class SeamlessM4Tv2Encoder(SeamlessM4Tv2PreTrainedModel):
         self.max_source_positions = config.max_position_embeddings

         if not self.is_t2u_encoder:
-            self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
+            embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

-            self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
+            self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding(
+                config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
+            )

             if embed_tokens is not None:
                 self.embed_tokens.weight = embed_tokens.weight

@@ -1848,7 +1864,7 @@ class SeamlessM4Tv2Encoder(SeamlessM4Tv2PreTrainedModel):
             raise ValueError("You have to specify either input_ids or inputs_embeds")

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         if not self.is_t2u_encoder:
             embed_pos = self.embed_positions(input)

@@ -1932,14 +1948,18 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
         self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

         if embed_tokens is not None:
             # if embed_tokens defined, use its shape instead
-            self.embed_tokens = nn.Embedding(embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx)
+            self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding(
+                embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale
+            )
             self.embed_tokens.weight = embed_tokens.weight
         else:
-            self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
+            self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding(
+                self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
+            )

         self.embed_positions = SeamlessM4Tv2SinusoidalPositionalEmbedding(
             self.max_target_positions,

@@ -2058,7 +2078,7 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
......
src/transformers/models/trocr/modeling_trocr.py

@@ -63,6 +63,20 @@ class TrOCRLearnedPositionalEmbedding(nn.Embedding):
         return super().forward(positions + self.offset)

+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->TrOCR
+class TrOCRScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale

 class TrOCRSinusoidalPositionalEmbedding(nn.Module):
     """This module produces sinusoidal positional embeddings of any length."""

@@ -451,9 +465,11 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
         self.dropout = config.dropout
         self.layerdrop = config.decoder_layerdrop
         self.padding_idx = config.pad_token_id
-        self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embed_tokens = TrOCRScaledWordEmbedding(
+            config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
+        )

         if config.use_learned_position_embeddings:
             self.embed_positions = TrOCRLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)

@@ -584,7 +600,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
         past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

         if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

         if self.config.use_learned_position_embeddings:
             embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
@@ -127,6 +127,20 @@ XGLM_INPUTS_DOCSTRING = r"""
"""

+# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->XGLM
+class XGLMScaledWordEmbedding(nn.Embedding):
+    """
+    This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
+    """
+
+    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+        super().__init__(num_embeddings, embedding_dim, padding_idx)
+        self.embed_scale = embed_scale
+
+    def forward(self, input_ids: torch.Tensor):
+        return super().forward(input_ids) * self.embed_scale
+
class XGLMSinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length."""
@@ -490,12 +504,14 @@ class XGLMModel(XGLMPreTrainedModel):
        self.layerdrop = config.layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
-        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+        embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
-            self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
+            self.embed_tokens = XGLMScaledWordEmbedding(
+                config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
+            )

        self.embed_positions = XGLMSinusoidalPositionalEmbedding(
            config.max_position_embeddings,
@@ -568,7 +584,7 @@ class XGLMModel(XGLMPreTrainedModel):
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
+            inputs_embeds = self.embed_tokens(input_ids)

        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, input_shape, inputs_embeds, past_key_values_length
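One detail worth noting before the test changes: the constructors handle shared embeddings differently. SeamlessM4Tv2 wraps a passed-in embed_tokens in a fresh SeamlessM4Tv2ScaledWordEmbedding and ties the weights, while XGLM stores a passed-in module unchanged and only builds an XGLMScaledWordEmbedding when it owns the embedding, so in the shared case the scaling is whatever the supplied module applies.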
@@ -167,6 +167,10 @@ class AlignVisionModelTest(ModelTesterMixin, unittest.TestCase):
    def test_inputs_embeds(self):
        pass

+    @unittest.skip(reason="AlignVisionModel does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
    @unittest.skip(reason="AlignVisionModel does not support input and output embeddings")
    def test_model_common_attributes(self):
        pass
@@ -379,6 +383,10 @@ class AlignTextModelTest(ModelTesterMixin, unittest.TestCase):
    def test_inputs_embeds(self):
        pass

+    @unittest.skip(reason="Align does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
    @unittest.skip(reason="AlignTextModel has no base class and is not available in MODEL_MAPPING")
    def test_save_load_fast_init_from_base(self):
        pass
@@ -473,6 +481,10 @@ class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    def test_inputs_embeds(self):
        pass

+    @unittest.skip(reason="Align does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
    @unittest.skip(reason="Retain_grad is tested in individual model tests")
    def test_retain_grad_hidden_states_attentions(self):
        pass
@@ -579,6 +579,29 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
        with torch.no_grad():
            model(**inputs)[0]

+    # override as the input arg is called "input_embeds", not "inputs_embeds"
+    def test_inputs_embeds_matches_input_ids(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+            with torch.no_grad():
+                out_ids = model(**inputs)[0]
+
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+
+            wte = model.get_input_embeddings()
+            inputs["input_embeds"] = wte(input_ids)
+            with torch.no_grad():
+                out_embeds = model(**inputs)[0]
+
+            self.assertTrue(torch.allclose(out_embeds, out_ids))
+
    @require_torch_fp16
    def test_generate_fp16(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
@@ -645,6 +668,29 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
        with torch.no_grad():
            model(**inputs)[0]

+    # override as the input arg is called "input_embeds", not "inputs_embeds"
+    def test_inputs_embeds_matches_input_ids(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            model = model_class(config)
+            model.to(torch_device)
+            model.eval()
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+            with torch.no_grad():
+                out_ids = model(**inputs)[0]
+
+            input_ids = inputs["input_ids"]
+            del inputs["input_ids"]
+
+            wte = model.get_input_embeddings()
+            inputs["input_embeds"] = wte(input_ids)
+            with torch.no_grad():
+                out_embeds = model(**inputs)[0]
+
+            self.assertTrue(torch.allclose(out_embeds, out_ids))
+
    @require_torch_fp16
    def test_generate_fp16(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
@@ -709,6 +755,10 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
        with torch.no_grad():
            model(**inputs)[0]

+    @unittest.skip("FineModel relies on codebook idx and does not return same logits")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
    @require_torch_fp16
    def test_generate_fp16(self):
        config, input_dict = self.model_tester.prepare_config_and_inputs()
@@ -506,6 +506,10 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
    def test_inputs_embeds(self):
        pass

+    @unittest.skip(reason="Bridge Tower does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+

# We will verify our results on an image of cute cats
def prepare_img():
@@ -502,6 +502,10 @@ class CanineModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
        # ViT does not use inputs_embeds
        pass

+    @unittest.skip(reason="Canine Tower does not use inputs_embeds")
+    def test_inputs_embeds_matches_input_ids(self):
+        pass
+
    @unittest.skip("CANINE does not have a get_input_embeddings() method.")
    def test_model_common_attributes(self):
        pass
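For reference, the skips and overrides above all point at the new shared check in tests/test_modeling_common.py. Its body is not part of this excerpt, but judging from the Bark overrides it presumably looks roughly like the following sketch, written as a ModelTesterMixin-style method with the standard `inputs_embeds` argument in place of Bark's `input_embeds` (structure assumed, not the actual shared-test code):

# Hedged reconstruction of the shared check, shown as a test-mixin method like the
# Bark overrides above; the real implementation lives in tests/test_modeling_common.py.
def test_inputs_embeds_matches_input_ids(self):
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    for model_class in self.all_model_classes:
        model = model_class(config)
        model.to(torch_device)
        model.eval()

        inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
        with torch.no_grad():
            out_ids = model(**inputs)[0]

        input_ids = inputs.pop("input_ids")
        # With the scale folded into the embedding module, this lookup already includes
        # the sqrt(d_model) factor, so models with config.scale_embedding=True can pass.
        inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
        with torch.no_grad():
            out_embeds = model(**inputs)[0]

        self.assertTrue(torch.allclose(out_embeds, out_ids))

Models whose inputs are not plain token IDs (Align, BridgeTower, CANINE, BarkFine) simply skip the check, while the Bark semantic and coarse tests override it only to rename the embedding argument.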