Unverified commit dfbd209c authored by Susnato Dhar, committed by GitHub

CLVP Fixes (#27547)

* fixes

* more fixes

* style fix

* more fix

* comments
parent 30e92ea3
@@ -81,8 +81,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, v, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -107,7 +106,51 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+    v_embed = (v * cos) + (rotate_half(v) * sin)
+    return q_embed, k_embed, v_embed
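Editor's note: a minimal, self-contained sketch of the extended rotary helper (not part of the diff). The cos/sin tables below are random placeholders purely for shape-checking; the model derives them from rotary frequencies. It shows the new third input/output: values are now rotated alongside queries and keys.

```python
import torch

def rotate_half(x):
    # rotate half of the last dimension, as in the modeling file above
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, v, cos, sin, position_ids, unsqueeze_dim=1):
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)  # -> [1, 1, seq, head_dim]
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    v_embed = (v * cos) + (rotate_half(v) * sin)  # new: values are rotated too
    return q_embed, k_embed, v_embed

# toy tensors in the [batch, heads, seq, head_dim] layout used by ClvpSelfAttention
q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
cos_table, sin_table = torch.randn(4, 8), torch.randn(4, 8)  # placeholder [seq, head_dim] tables
position_ids = torch.arange(4).unsqueeze(0)  # [1, seq]
q_rot, k_rot, v_rot = apply_rotary_pos_emb(q, k, v, cos_table, sin_table, position_ids)
```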
+
+
+def _pad_extra_bos_eos_tokens(
+    input_ids,
+    attention_mask=None,
+    pad_token_id=0,
+    bos_token_id=255,
+    eos_token_id=0,
+    add_bos_token=True,
+    add_eos_token=True,
+):
+    """
+    This method adds extra bos and eos tokens to input_ids and accordingly modifies the attention_mask which is used in
+    `ClvpConditioningEncoder` and the generation loop of the `ClvpModelForConditionalGeneration`.
+    """
+    # add the bos token at the beginning
+    if add_bos_token:
+        input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=bos_token_id)
+        attention_mask = (
+            torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
+        )
+
+    modified_input_ids = input_ids
+    if add_eos_token:
+        modified_input_ids = torch.zeros(
+            (input_ids.shape[0], input_ids.shape[1] + 1), dtype=input_ids.dtype, device=input_ids.device
+        )
+        for i, each_input_id in enumerate(input_ids):
+            # locate where the valid tokens end and then add the eos token
+            if torch.isin(each_input_id, pad_token_id).sum():
+                pos = torch.where(each_input_id == pad_token_id)[0].min()
+                modified_input_ids[i] = torch.concatenate(
+                    [each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]]
+                )
+            else:
+                # if there are no pad tokens present, then add eos to the end
+                modified_input_ids[i] = torch.nn.functional.pad(each_input_id, (0, 1), value=eos_token_id)
+        attention_mask = (
+            torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
+        )
+
+    return modified_input_ids, attention_mask
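Editor's note: a toy walk-through of the new helper (not part of the diff; assumes the function above is in scope). With the defaults, bos is prepended to every row, while eos is inserted just before the first pad token so padding stays at the end; the mask is front-padded with 1, which gives the right result because valid tokens are left-contiguous.

```python
import torch

input_ids = torch.tensor([[10, 11, 12], [10, 11, 0]])  # 0 doubles as the pad token here
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
ids, mask = _pad_extra_bos_eos_tokens(
    input_ids, attention_mask, pad_token_id=0, bos_token_id=255, eos_token_id=0
)
# ids  -> [[255, 10, 11, 12, 0], [255, 10, 11, 0, 0]]  (eos lands before any padding)
# mask -> [[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]
```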
 @dataclass
@@ -312,13 +355,18 @@ class ClvpSelfAttention(nn.Module):
                 key_states[..., :rotary_emb_dim],
                 key_states[..., rotary_emb_dim:],
             )
+            value_rot, value_pass = (
+                value_states[..., :rotary_emb_dim],
+                value_states[..., rotary_emb_dim:],
+            )
 
             cos, sin = rotary_pos_emb.cos().squeeze(0), rotary_pos_emb.sin().squeeze(0)
-            query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+            query_rot, key_rot, value_rot = apply_rotary_pos_emb(query_rot, key_rot, value_rot, cos, sin, position_ids)
 
             # [batch_size, num_heads, seq_length, head_dim]
             query_states = torch.cat((query_rot, query_pass), dim=-1)
             key_states = torch.cat((key_rot, key_pass), dim=-1)
+            value_states = torch.cat((value_rot, value_pass), dim=-1)
 
         tgt_len = query_states.shape[2]
         src_len = key_states.shape[2]
@@ -599,16 +647,7 @@ class ClvpConditioningEncoder(nn.Module):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            # We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
-            # This logic is specific to ClvpConditioningEncoder and not used by other modules.
-            input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=self.text_config.bos_token_id)
-            input_ids = torch.nn.functional.pad(input_ids, (0, 1), value=self.text_config.eos_token_id)
             batch_size, seq_length = input_ids.size()
-            inputs_embeds = self.text_token_embedding(input_ids)
-            # check if we need to update attention mask, if yes then pad it too
-            if attention_mask is not None and attention_mask.shape[1] != seq_length:
-                attention_mask = torch.nn.functional.pad(attention_mask, (1, 0), value=1)
-                attention_mask = torch.nn.functional.pad(attention_mask, (0, 1), value=1)
         elif inputs_embeds is not None:
             batch_size, seq_length = inputs_embeds.size()[:-1]
         else:
@@ -616,8 +655,18 @@ class ClvpConditioningEncoder(nn.Module):
         # construct attention mask if not given
         if attention_mask is None:
-            attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=inputs_embeds.device)
+            attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=input_ids.device)
+
+        # We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
+        # This logic is specific to ClvpConditioningEncoder and not used by other modules.
+        input_ids, attention_mask = _pad_extra_bos_eos_tokens(
+            input_ids,
+            attention_mask,
+            bos_token_id=self.text_config.bos_token_id,
+            eos_token_id=self.text_config.eos_token_id,
+        )
+
+        inputs_embeds = self.text_token_embedding(input_ids)
 
         position_ids = attention_mask.cumsum(-1) - 1
         position_embeds = self.text_position_embedding(position_ids)
         text_embeds = inputs_embeds + position_embeds
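Editor's note: since `position_ids` are now computed after re-padding, here is a quick sketch (not part of the diff) of what `attention_mask.cumsum(-1) - 1` yields: positions advance only over attended tokens, so trailing padding repeats the last valid position.

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 1, 0]])
position_ids = attention_mask.cumsum(-1) - 1  # -> [[0, 1, 2, 3, 3]]
```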
@@ -1512,10 +1561,6 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
         """
         decoder_fixing_codes = self.config.decoder_config.decoder_fixing_codes
         speech_ids = speech_ids[:, 1:]
-        if torch.isin(self.speech_decoder_model.config.eos_token_id, speech_ids):
-            speech_ids = torch.nn.functional.pad(
-                speech_ids, pad=(0, 1), value=self.speech_decoder_model.config.eos_token_id
-            )
 
         stop_token_indices = torch.where(speech_ids == self.speech_decoder_model.config.eos_token_id, 1, 0)
         speech_ids = torch.masked_fill(speech_ids, mask=stop_token_indices.bool(), value=decoder_fixing_codes[0])
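Editor's note: with the extra eos padding removed, `fix_speech_decoder_output` now simply rewrites every decoder eos in the generated ids to the first decoder fixing code. A sketch (not part of the diff; the ids below are hypothetical, not CLVP's real vocabulary):

```python
import torch

eos_token_id = 8193          # assumed decoder eos id, for illustration only
decoder_fixing_codes = [83]  # assumed value, for illustration only
speech_ids = torch.tensor([[17, 42, eos_token_id, eos_token_id]])
stop_token_indices = torch.where(speech_ids == eos_token_id, 1, 0)
speech_ids = torch.masked_fill(speech_ids, mask=stop_token_indices.bool(), value=decoder_fixing_codes[0])
# speech_ids -> [[17, 42, 83, 83]]
```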
@@ -1828,6 +1873,7 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
         input_features: torch.FloatTensor = None,
         attention_mask: Optional[torch.LongTensor] = None,
         generation_config: Optional[GenerationConfig] = None,
+        pad_to_max_mel_tokens: Optional[int] = None,
         output_hidden_states: Optional[bool] = None,
         **kwargs,
     ):
@@ -1855,6 +1901,11 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
                 priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                 configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                 default values, whose documentation should be checked to parameterize generation.
+            pad_to_max_mel_tokens (`int`, *optional*):
+                Pads the generated speech_ids to the specified length. This replicates the logic of the official
+                repo (link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430)
+                and makes sure the logits are the same. It does not affect generation quality, so avoid using it
+                unless you need to reproduce the original logits; it is less efficient.
             output_hidden_states (`bool`, *optional*):
                 Whether or not to return the hidden states of decoder model, text encoder and speech encoder models.
""" """
# If the input sequences are larger than (self.config.decoder_config.max_text_tokens - 3) then raise error,
# because we need to add 3 tokens ( 1 bos tokens and 2 eos tokens) to the input_ids in ClvpConditioningEncoder to
# properly sample
sequence_length = input_ids.shape[-1]
if sequence_length > (self.config.decoder_config.max_text_tokens - 3):
raise ValueError(
f"Maximum sequence length reached! Found input_ids of length {sequence_length}."
f"Please make sure that the maximum length of input_ids is {self.config.decoder_config.max_text_tokens - 3}"
)
         if generation_config is None:
             generation_config = self.generation_config
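Editor's note: the 3-token budget in the check above matches the two padding passes: `generate` inserts one eos via `_pad_extra_bos_eos_tokens(..., add_bos_token=False)` below, and `ClvpConditioningEncoder` then adds a bos plus another eos. A trivial check of the arithmetic (not part of the diff), assuming a hypothetical `max_text_tokens` of 256:

```python
max_text_tokens = 256        # hypothetical config value, for illustration only
bos_added, eos_added = 1, 2  # one eos in generate(), bos + eos in the conditioning encoder
max_input_length = max_text_tokens - (bos_added + eos_added)  # -> 253
```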
@@ -1870,6 +1932,16 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
         generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())
 
+        # pad input_ids as specified in the original repo
+        # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L380
+        input_ids, attention_mask = _pad_extra_bos_eos_tokens(
+            input_ids,
+            attention_mask,
+            add_bos_token=False,
+            bos_token_id=self.config.text_config.bos_token_id,
+            eos_token_id=self.config.text_config.eos_token_id,
+        )
+
         conditioning_embeds = self.conditioning_encoder(
             input_features=input_features,
             input_ids=input_ids,
@@ -1884,6 +1956,15 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
         )
         if isinstance(decoder_outputs, ModelOutput):
             speech_ids = decoder_outputs.sequences
+
+        # pad to pad_to_max_mel_tokens if given, to replicate the original repo logic
+        # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
+        if pad_to_max_mel_tokens is not None:
+            padding_needed = pad_to_max_mel_tokens - speech_ids.shape[-1]
+            speech_ids = torch.nn.functional.pad(
+                speech_ids, (0, padding_needed), value=self.generation_config.eos_token_id
+            )
+
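Editor's note: a sketch of the optional padding step (not part of the diff; the eos id is hypothetical). Generated ids shorter than `pad_to_max_mel_tokens` are right-padded with the generation eos id before `fix_speech_decoder_output` runs:

```python
import torch

eos_token_id = 8193                      # assumed generation eos id, for illustration only
speech_ids = torch.tensor([[5, 9, 11]])  # a generated sequence of length 3
pad_to_max_mel_tokens = 6
padding_needed = pad_to_max_mel_tokens - speech_ids.shape[-1]
speech_ids = torch.nn.functional.pad(speech_ids, (0, padding_needed), value=eos_token_id)
# speech_ids -> [[5, 9, 11, 8193, 8193, 8193]]
```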
         speech_ids = self.fix_speech_decoder_output(speech_ids)
 
         speech_outputs = self.speech_encoder_model(
...
@@ -604,12 +604,7 @@ class ClvpIntegrationTest(unittest.TestCase):
         text_embeds = self.model.text_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu()
 
         # fmt: off
-        EXPECTED_TEXT_EMBEDS = torch.tensor(
-            [ 1.8060e+00, -2.7928e+00, 3.2021e+00, -1.5673e+00, 2.3284e+00, -3.2065e+00, -1.3368e+00, 2.2322e+00,
-              -1.7667e+00, 4.1505e-01, 2.4119e+00, -5.8133e-03, -4.6367e+00, 1.6450e-01, 6.7459e+00, 6.6292e+00,
-              1.1046e+00, 3.6196e+00, -1.0496e+01, 5.4924e+00
-            ]
-        )
+        EXPECTED_TEXT_EMBEDS = torch.tensor([1.4798, -2.0005, 2.3902, -0.5042, 1.6401, -2.4135, -1.4800, 3.0118, -2.4422, 1.3266, 2.2339, 1.4761, -4.8983, -1.3592, 6.0251, 6.7364, 2.2576, 3.7229, -10.0436, 4.6676])
         # fmt: on
         self.assertTrue(torch.allclose(text_embeds[0, :20], EXPECTED_TEXT_EMBEDS, atol=1e-4))
@@ -618,11 +613,7 @@ class ClvpIntegrationTest(unittest.TestCase):
         speech_embeds = self.model.speech_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu()
 
         # fmt: off
-        EXPECTED_SPEECH_EMBEDS = torch.tensor(
-            [ 4.6143, -5.5784, 0.8983, -3.9665, -0.6714, -1.0665, -1.1277, 1.5619, 2.6322, -7.2008, -2.4932, 0.3265,
-              -1.4738, 0.1425, 5.0825, 4.1760, -5.4708, 2.1935, -6.0044, 3.9540
-            ]
-        )
+        EXPECTED_SPEECH_EMBEDS = torch.tensor([3.1202, -3.1183, -1.4264, -6.1339, 1.8885, -0.1983, 0.9461, -1.7414, 0.3320, -3.8400, -1.5715, 1.5096, -1.7576, 0.2387, 4.9758, 5.8450, -6.2534, 2.8587, -5.5816, 4.7821])
         # fmt: on
         self.assertTrue(torch.allclose(speech_embeds[0, :20], EXPECTED_SPEECH_EMBEDS, atol=1e-4))
@@ -635,8 +626,10 @@ class ClvpIntegrationTest(unittest.TestCase):
             num_beams=4,
             num_return_sequences=4,
             max_new_tokens=10,
-        ).speech_ids.cpu()
+        )
 
-        EXPECTED_OUTPUTS = torch.tensor([[1953, 1080, 612], [1953, 1953, 612], [1953, 612, 716]])
+        EXPECTED_SPEECH_IDS = torch.tensor([[1953, 1080, 612], [1953, 612, 493], [1953, 612, 716]])
+        EXPECTED_SIMILARITY_SCORES = torch.tensor([[14.7660, 14.4569, 13.6472, 13.5683]])
 
-        self.assertTrue(torch.allclose(full_model_output[-3:, -3:], EXPECTED_OUTPUTS))
+        self.assertTrue(torch.allclose(full_model_output.speech_ids.cpu()[-3:, -3:], EXPECTED_SPEECH_IDS))
+        self.assertTrue(torch.allclose(full_model_output.logits_per_text.cpu(), EXPECTED_SIMILARITY_SCORES))