Unverified Commit 38a4bf79 authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Encoder-decoder models: move embedding scale to nn.Module (#30410)



* move scaling to nn.Module

* let the test be here for now (need to fix)

* failing tests

* last failing models

* Revert commit 4c14817f38

* clean-up

* oops forgot

* codestyle

* raise NotImplemented when possible

* Update tests/test_modeling_common.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* skip tests in respective modeling files

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 9d31b32e
......@@ -132,6 +132,19 @@ class BartLearnedPositionalEmbedding(nn.Embedding):
return super().forward(positions + self.offset)
class BartScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
class BartAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
......@@ -1056,9 +1069,11 @@ class BartEncoder(BartPreTrainedModel):
embed_dim = config.d_model
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens = BartScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -1146,7 +1161,7 @@ class BartEncoder(BartPreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input)
embed_pos = embed_pos.to(inputs_embeds.device)
......@@ -1238,9 +1253,11 @@ class BartDecoder(BartPreTrainedModel):
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_tokens = BartScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -1369,7 +1386,7 @@ class BartDecoder(BartPreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale
inputs_embeds = self.embed_tokens(input)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
......
......@@ -90,6 +90,20 @@ class BigBirdPegasusLearnedPositionalEmbedding(nn.Embedding):
return super().forward(positions)
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BigBirdPegasus
class BigBirdPegasusScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
# Copied from transformers.models.big_bird.modeling_big_bird.BigBirdSelfAttention with BigBird->BigBirdPegasus
class BigBirdPegasusSelfAttention(nn.Module):
def __init__(self, config):
......@@ -1749,9 +1763,11 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
embed_dim = config.d_model
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens = BigBirdPegasusScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -1827,7 +1843,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input_shape)
......@@ -2042,9 +2058,11 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_tokens = BigBirdPegasusScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -2168,7 +2186,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
......@@ -2292,7 +2310,10 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
super().__init__(config)
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.shared = BigBirdPegasusScaledWordEmbedding(
vocab_size, config.d_model, padding_idx, embed_scale=embed_scale
)
self.encoder = BigBirdPegasusEncoder(config, self.shared)
self.decoder = BigBirdPegasusDecoder(config, self.shared)
......
......@@ -75,6 +75,20 @@ class BioGptLearnedPositionalEmbedding(nn.Embedding):
return super().forward(positions + self.offset)
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BioGpt
class BioGptScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BioGpt
class BioGptAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
......@@ -423,9 +437,11 @@ class BioGptModel(BioGptPreTrainedModel):
self.dropout = config.hidden_dropout_prob
self.embed_dim = config.hidden_size
self.padding_idx = config.pad_token_id
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, self.embed_dim, self.padding_idx)
self.embed_tokens = BioGptScaledWordEmbedding(
config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
)
self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
self.layers = nn.ModuleList([BioGptDecoderLayer(config) for _ in range(config.num_hidden_layers)])
......@@ -482,7 +498,7 @@ class BioGptModel(BioGptPreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale
inputs_embeds = self.embed_tokens(input)
if attention_mask is None:
attention_mask = torch.ones(
......
......@@ -90,6 +90,20 @@ class BlenderbotLearnedPositionalEmbedding(nn.Embedding):
return super().forward(positions)
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->Blenderbot
class BlenderbotScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot
class BlenderbotAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
......@@ -632,12 +646,14 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
embed_dim = config.d_model
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens = BlenderbotScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
config.max_position_embeddings,
......@@ -715,7 +731,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input_shape)
......@@ -799,12 +815,14 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_tokens = BlenderbotScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
self.embed_positions = BlenderbotLearnedPositionalEmbedding(
config.max_position_embeddings,
......@@ -926,7 +944,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
......
......@@ -1325,6 +1325,11 @@ class BridgeTowerModel(BridgeTowerPreTrainedModel):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
if inputs_embeds is not None and input_ids is None:
raise NotImplementedError(
"BridgeTowerModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
image_token_type_idx = image_token_type_idx if image_token_type_idx else 1
input_shape = input_ids.size()
......
......@@ -972,8 +972,7 @@ class FunnelBaseModel(FunnelPreTrainedModel):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# TODO: deal with head_mask
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)
encoder_outputs = self.encoder(
inputs_embeds,
......@@ -1048,8 +1047,7 @@ class FunnelModel(FunnelPreTrainedModel):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# TODO: deal with head_mask
if inputs_embeds is None:
inputs_embeds = self.embeddings(input_ids)
inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)
encoder_outputs = self.encoder(
inputs_embeds,
......
......@@ -920,6 +920,10 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
device = self.position_embeddings.weight.device
if input_ids is None:
input_ids = torch.zeros([1, 1]).int().to(device) # dummy for input_ids was None
if inputs_embeds is not None:
raise NotImplementedError(
"GPTSanJapaneseModel does not use `inputs_embeds`. Make sure to pass in `input_ids` instead."
)
num_pasts_contexts = 0
num_batch = input_ids.shape[0]
pasts_or_spout_value = None
......
......@@ -87,6 +87,20 @@ def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_l
return incremental_indices.long() + padding_idx
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->M2M100
class M2M100ScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
class M2M100SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
......@@ -886,9 +900,11 @@ class M2M100Encoder(M2M100PreTrainedModel):
embed_dim = config.d_model
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens = M2M100ScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -971,7 +987,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input_ids, inputs_embeds)
embed_pos = embed_pos.to(inputs_embeds.device)
......@@ -1061,9 +1077,11 @@ class M2M100Decoder(M2M100PreTrainedModel):
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_tokens = M2M100ScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -1183,7 +1201,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
......@@ -1321,7 +1339,8 @@ class M2M100Model(M2M100PreTrainedModel):
super().__init__(config)
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.shared = M2M100ScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
self.encoder = M2M100Encoder(config, self.shared)
self.decoder = M2M100Decoder(config, self.shared)
......
......@@ -118,6 +118,20 @@ class MBartLearnedPositionalEmbedding(nn.Embedding):
return super().forward(positions + self.offset)
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->MBart
class MBartScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart
class MBartAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
......@@ -919,9 +933,11 @@ class MBartEncoder(MBartPreTrainedModel):
embed_dim = config.d_model
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens = MBartScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -1009,7 +1025,7 @@ class MBartEncoder(MBartPreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input)
......@@ -1097,9 +1113,11 @@ class MBartDecoder(MBartPreTrainedModel):
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_tokens = MBartScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -1227,7 +1245,7 @@ class MBartDecoder(MBartPreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
......
......@@ -133,6 +133,20 @@ def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.T
return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->NllbMoe
class NllbMoeScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
class NllbMoeSinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
......@@ -992,9 +1006,11 @@ class NllbMoeEncoder(NllbMoePreTrainedModel):
embed_dim = config.d_model
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens = NllbMoeScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -1085,7 +1101,7 @@ class NllbMoeEncoder(NllbMoePreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input_ids, inputs_embeds)
embed_pos = embed_pos.to(inputs_embeds.device)
......@@ -1178,9 +1194,11 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_tokens = NllbMoeScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -1309,7 +1327,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
......@@ -1458,7 +1476,8 @@ class NllbMoeModel(NllbMoePreTrainedModel):
super().__init__(config)
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.shared = NllbMoeScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
self.encoder = NllbMoeEncoder(config, self.shared)
self.decoder = NllbMoeDecoder(config, self.shared)
......
......@@ -87,6 +87,20 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
return shifted_input_ids
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PegasusX
class PegasusXScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
class PegasusXSinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
......@@ -880,13 +894,16 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
self.layerdrop = config.encoder_layerdrop
embed_dim = config.d_model
padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim)
self.embed_tokens = PegasusXScaledWordEmbedding(
config.vocab_size, embed_dim, padding_idx, embed_scale=embed_scale
)
self.embed_global = nn.Embedding(config.num_global_tokens, embed_dim)
self.embed_positions = PegasusXSinusoidalPositionalEmbedding(embed_dim)
......@@ -988,7 +1005,7 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(inputs_embeds)
......@@ -1086,12 +1103,15 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
padding_idx = config.pad_token_id
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
self.embed_tokens = PegasusXScaledWordEmbedding(
config.vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale
)
self.embed_positions = PegasusXSinusoidalPositionalEmbedding(config.d_model)
self.layers = nn.ModuleList([PegasusXDecoderLayer(config) for _ in range(config.decoder_layers)])
......@@ -1196,7 +1216,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
......@@ -1307,7 +1327,11 @@ class PegasusXModel(PegasusXPreTrainedModel):
super().__init__(config)
vocab_size = config.vocab_size
self.shared = nn.Embedding(vocab_size, config.d_model)
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
padding_idx = config.pad_token_id
self.shared = PegasusXScaledWordEmbedding(
vocab_size, config.d_model, padding_idx=padding_idx, embed_scale=embed_scale
)
self.encoder = PegasusXEncoder(config, self.shared)
self.decoder = PegasusXDecoder(config, self.shared)
......
......@@ -102,6 +102,20 @@ class PLBartLearnedPositionalEmbedding(nn.Embedding):
return super().forward(positions + self.offset)
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PLBart
class PLBartScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PLBart
class PLBartAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
......@@ -658,9 +672,11 @@ class PLBartEncoder(PLBartPreTrainedModel):
embed_dim = config.d_model
self.padding_idx = config.pad_token_id
self.max_source_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens = PLBartScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -748,7 +764,7 @@ class PLBartEncoder(PLBartPreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
embed_pos = self.embed_positions(input)
embed_pos = embed_pos.to(inputs_embeds.device)
......@@ -841,9 +857,11 @@ class PLBartDecoder(PLBartPreTrainedModel):
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_tokens = PLBartScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -972,7 +990,7 @@ class PLBartDecoder(PLBartPreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input) * self.embed_scale
inputs_embeds = self.embed_tokens(input)
if self._use_flash_attention_2:
# 2d mask is passed through the layers
......@@ -1122,7 +1140,8 @@ class PLBartModel(PLBartPreTrainedModel):
super().__init__(config)
padding_idx, vocab_size = config.pad_token_id, config.vocab_size
self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
self.encoder = PLBartEncoder(config, self.shared)
self.decoder = PLBartDecoder(config, self.shared)
......
......@@ -989,6 +989,20 @@ class SeamlessM4TConformerAdapter(nn.Module):
############ TEXT / UNITS related code ################
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4T
class SeamlessM4TScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
class SeamlessM4TSinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
......@@ -1631,9 +1645,11 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
self.max_source_positions = config.max_position_embeddings
if not self.is_t2u_encoder:
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens = SeamlessM4TScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -1726,7 +1742,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
if not self.is_t2u_encoder:
embed_pos = self.embed_positions(input)
......@@ -1809,14 +1825,18 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
if embed_tokens is not None:
# if embed_tokens defined, use its shape instead
self.embed_tokens = nn.Embedding(embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx)
self.embed_tokens = SeamlessM4TScaledWordEmbedding(
embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale
)
self.embed_tokens.weight = embed_tokens.weight
else:
self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
self.embed_tokens = SeamlessM4TScaledWordEmbedding(
self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
)
self.embed_positions = SeamlessM4TSinusoidalPositionalEmbedding(
self.max_target_positions,
......@@ -1935,7 +1955,7 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
......
......@@ -946,6 +946,20 @@ class SeamlessM4Tv2ConformerAdapter(nn.Module):
############ TEXT / UNITS related code ################
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100ScaledWordEmbedding with M2M100->SeamlessM4Tv2
class SeamlessM4Tv2ScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
# Copied from transformers.models.m2m_100.modeling_m2m_100.M2M100SinusoidalPositionalEmbedding
class SeamlessM4Tv2SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
......@@ -1753,9 +1767,11 @@ class SeamlessM4Tv2Encoder(SeamlessM4Tv2PreTrainedModel):
self.max_source_positions = config.max_position_embeddings
if not self.is_t2u_encoder:
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, embed_dim, self.padding_idx)
self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding(
config.vocab_size, embed_dim, self.padding_idx, embed_scale=embed_scale
)
if embed_tokens is not None:
self.embed_tokens.weight = embed_tokens.weight
......@@ -1848,7 +1864,7 @@ class SeamlessM4Tv2Encoder(SeamlessM4Tv2PreTrainedModel):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
if not self.is_t2u_encoder:
embed_pos = self.embed_positions(input)
......@@ -1932,14 +1948,18 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
if embed_tokens is not None:
# if embed_tokens defined, use its shape instead
self.embed_tokens = nn.Embedding(embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx)
self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding(
embed_tokens.num_embeddings, embed_tokens.embedding_dim, self.padding_idx, embed_scale=embed_scale
)
self.embed_tokens.weight = embed_tokens.weight
else:
self.embed_tokens = nn.Embedding(self.vocab_size, config.hidden_size, self.padding_idx)
self.embed_tokens = SeamlessM4Tv2ScaledWordEmbedding(
self.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
)
self.embed_positions = SeamlessM4Tv2SinusoidalPositionalEmbedding(
self.max_target_positions,
......@@ -2058,7 +2078,7 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
......
......@@ -63,6 +63,20 @@ class TrOCRLearnedPositionalEmbedding(nn.Embedding):
return super().forward(positions + self.offset)
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->TrOCR
class TrOCRScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
class TrOCRSinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
......@@ -451,9 +465,11 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.padding_idx = config.pad_token_id
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.embed_tokens = TrOCRScaledWordEmbedding(
config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=embed_scale
)
if config.use_learned_position_embeddings:
self.embed_positions = TrOCRLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
......@@ -584,7 +600,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
if self.config.use_learned_position_embeddings:
embed_pos = self.embed_positions(input, past_key_values_length=past_key_values_length)
......
......@@ -127,6 +127,20 @@ XGLM_INPUTS_DOCSTRING = r"""
"""
# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->XGLM
class XGLMScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
super().__init__(num_embeddings, embedding_dim, padding_idx)
self.embed_scale = embed_scale
def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
class XGLMSinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length."""
......@@ -490,12 +504,14 @@ class XGLMModel(XGLMPreTrainedModel):
self.layerdrop = config.layerdrop
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_position_embeddings
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
if embed_tokens is not None:
self.embed_tokens = embed_tokens
else:
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
self.embed_tokens = XGLMScaledWordEmbedding(
config.vocab_size, config.d_model, self.padding_idx, embed_scale=embed_scale
)
self.embed_positions = XGLMSinusoidalPositionalEmbedding(
config.max_position_embeddings,
......@@ -568,7 +584,7 @@ class XGLMModel(XGLMPreTrainedModel):
position_ids = position_ids.unsqueeze(0)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
inputs_embeds = self.embed_tokens(input_ids)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
......
......@@ -167,6 +167,10 @@ class AlignVisionModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
@unittest.skip(reason="AlignVisionModel does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="AlignVisionModel does not support input and output embeddings")
def test_model_common_attributes(self):
pass
......@@ -379,6 +383,10 @@ class AlignTextModelTest(ModelTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Align does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="AlignTextModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
......@@ -473,6 +481,10 @@ class AlignModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Align does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
def test_retain_grad_hidden_states_attentions(self):
pass
......
......@@ -579,6 +579,29 @@ class BarkSemanticModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Te
with torch.no_grad():
model(**inputs)[0]
# override as the input arg is called "input_embeds", not "inputs_embeds"
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
with torch.no_grad():
out_ids = model(**inputs)[0]
input_ids = inputs["input_ids"]
del inputs["input_ids"]
wte = model.get_input_embeddings()
inputs["input_embeds"] = wte(input_ids)
with torch.no_grad():
out_embeds = model(**inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
......@@ -645,6 +668,29 @@ class BarkCoarseModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
with torch.no_grad():
model(**inputs)[0]
# override as the input arg is called "input_embeds", not "inputs_embeds"
def test_inputs_embeds_matches_input_ids(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
with torch.no_grad():
out_ids = model(**inputs)[0]
input_ids = inputs["input_ids"]
del inputs["input_ids"]
wte = model.get_input_embeddings()
inputs["input_embeds"] = wte(input_ids)
with torch.no_grad():
out_embeds = model(**inputs)[0]
self.assertTrue(torch.allclose(out_embeds, out_ids))
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
......@@ -709,6 +755,10 @@ class BarkFineModelTest(ModelTesterMixin, unittest.TestCase):
with torch.no_grad():
model(**inputs)[0]
@unittest.skip("FineModel relies on codebook idx and does not return same logits")
def test_inputs_embeds_matches_input_ids(self):
pass
@require_torch_fp16
def test_generate_fp16(self):
config, input_dict = self.model_tester.prepare_config_and_inputs()
......
......@@ -506,6 +506,10 @@ class BridgeTowerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestC
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Bridge Tower does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
# We will verify our results on an image of cute cats
def prepare_img():
......
......@@ -502,6 +502,10 @@ class CanineModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
# ViT does not use inputs_embeds
pass
@unittest.skip(reason="Canine Tower does not use inputs_embeds")
def test_inputs_embeds_matches_input_ids(self):
pass
@unittest.skip("CANINE does not have a get_input_embeddings() method.")
def test_model_common_attributes(self):
pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment