Unverified Commit dfbd209c authored by Susnato Dhar's avatar Susnato Dhar Committed by GitHub
Browse files

CLVP Fixes (#27547)

* fixes

* more fixes

* style fix

* more fix

* comments
parent 30e92ea3
......@@ -81,8 +81,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
def apply_rotary_pos_emb(q, k, v, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
......@@ -107,7 +106,51 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
v_embed = (v * cos) + (rotate_half(v) * sin)
return q_embed, k_embed, v_embed
def _pad_extra_bos_eos_tokens(
input_ids,
attention_mask=None,
pad_token_id=0,
bos_token_id=255,
eos_token_id=0,
add_bos_token=True,
add_eos_token=True,
):
"""
This method adds extra bos and eos tokens to input_ids and accordingly modifies the attention_mask which is used in
`ClvpConditioningEncoder` and the generation loop of the `ClvpModelForConditionalGeneration`.
"""
# add the bos token at the beginning
if add_bos_token:
input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=bos_token_id)
attention_mask = (
torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
)
modified_input_ids = input_ids
if add_eos_token:
modified_input_ids = torch.zeros(
(input_ids.shape[0], input_ids.shape[1] + 1), dtype=input_ids.dtype, device=input_ids.device
)
for i, each_input_id in enumerate(input_ids):
# locate where the valid tokens end and then add the eos token
if torch.isin(each_input_id, pad_token_id).sum():
pos = torch.where(each_input_id == pad_token_id)[0].min()
modified_input_ids[i] = torch.concatenate(
[each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]]
)
else:
# if there are no pad tokens present, then add eos to the end
modified_input_ids[i] = torch.nn.functional.pad(each_input_id, (0, 1), value=eos_token_id)
attention_mask = (
torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
)
return modified_input_ids, attention_mask
@dataclass
......@@ -312,13 +355,18 @@ class ClvpSelfAttention(nn.Module):
key_states[..., :rotary_emb_dim],
key_states[..., rotary_emb_dim:],
)
value_rot, value_pass = (
value_states[..., :rotary_emb_dim],
value_states[..., rotary_emb_dim:],
)
cos, sin = rotary_pos_emb.cos().squeeze(0), rotary_pos_emb.sin().squeeze(0)
query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query_rot, key_rot, value_rot = apply_rotary_pos_emb(query_rot, key_rot, value_rot, cos, sin, position_ids)
# [batch_size, num_heads, seq_length, head_dim]
query_states = torch.cat((query_rot, query_pass), dim=-1)
key_states = torch.cat((key_rot, key_pass), dim=-1)
value_states = torch.cat((value_rot, value_pass), dim=-1)
tgt_len = query_states.shape[2]
src_len = key_states.shape[2]
......@@ -599,16 +647,7 @@ class ClvpConditioningEncoder(nn.Module):
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
# We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
# This logic is specific to ClvpConditioningEncoder and not used by other modules.
input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=self.text_config.bos_token_id)
input_ids = torch.nn.functional.pad(input_ids, (0, 1), value=self.text_config.eos_token_id)
batch_size, seq_length = input_ids.size()
inputs_embeds = self.text_token_embedding(input_ids)
# check if we need to update attention mask, if yes then pad it too
if attention_mask is not None and attention_mask.shape[1] != seq_length:
attention_mask = torch.nn.functional.pad(attention_mask, (1, 0), value=1)
attention_mask = torch.nn.functional.pad(attention_mask, (0, 1), value=1)
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.size()[:-1]
else:
......@@ -616,8 +655,18 @@ class ClvpConditioningEncoder(nn.Module):
# construct attention mask if not given
if attention_mask is None:
attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=inputs_embeds.device)
attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=input_ids.device)
# We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
# This logic is specific to ClvpConditioningEncoder and not used by other modules.
input_ids, attention_mask = _pad_extra_bos_eos_tokens(
input_ids,
attention_mask,
bos_token_id=self.text_config.bos_token_id,
eos_token_id=self.text_config.eos_token_id,
)
inputs_embeds = self.text_token_embedding(input_ids)
position_ids = attention_mask.cumsum(-1) - 1
position_embeds = self.text_position_embedding(position_ids)
text_embeds = inputs_embeds + position_embeds
......@@ -1512,10 +1561,6 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
"""
decoder_fixing_codes = self.config.decoder_config.decoder_fixing_codes
speech_ids = speech_ids[:, 1:]
if torch.isin(self.speech_decoder_model.config.eos_token_id, speech_ids):
speech_ids = torch.nn.functional.pad(
speech_ids, pad=(0, 1), value=self.speech_decoder_model.config.eos_token_id
)
stop_token_indices = torch.where(speech_ids == self.speech_decoder_model.config.eos_token_id, 1, 0)
speech_ids = torch.masked_fill(speech_ids, mask=stop_token_indices.bool(), value=decoder_fixing_codes[0])
......@@ -1828,6 +1873,7 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
input_features: torch.FloatTensor = None,
attention_mask: Optional[torch.LongTensor] = None,
generation_config: Optional[GenerationConfig] = None,
pad_to_max_mel_tokens: Optional[int] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
):
......@@ -1855,6 +1901,11 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
pad_to_max_mel_tokens (`int`, *optional*):
Pads generated speech_ids to the specified value. This is to implement the same logic from the official
repo, link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
and to make sure the logits are same.
This does not affect generation quality so please don't consider using it since it is less efficient.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of decoder model, text encoder and speech encoder models.
......@@ -1862,6 +1913,17 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
`ClvpOutput` or tuple: A `ClvpOutput` (if `return_dict_in_generate=True` or when
`config.return_dict_in_generate=True`) or a tuple.
"""
# If the input sequences are larger than (self.config.decoder_config.max_text_tokens - 3) then raise error,
# because we need to add 3 tokens ( 1 bos tokens and 2 eos tokens) to the input_ids in ClvpConditioningEncoder to
# properly sample
sequence_length = input_ids.shape[-1]
if sequence_length > (self.config.decoder_config.max_text_tokens - 3):
raise ValueError(
f"Maximum sequence length reached! Found input_ids of length {sequence_length}."
f"Please make sure that the maximum length of input_ids is {self.config.decoder_config.max_text_tokens - 3}"
)
if generation_config is None:
generation_config = self.generation_config
......@@ -1870,6 +1932,16 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
generation_config.validate()
self._validate_model_kwargs(model_kwargs.copy())
# pad input_ids as specified in the original repo
# link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L380
input_ids, attention_mask = _pad_extra_bos_eos_tokens(
input_ids,
attention_mask,
add_bos_token=False,
bos_token_id=self.config.text_config.bos_token_id,
eos_token_id=self.config.text_config.eos_token_id,
)
conditioning_embeds = self.conditioning_encoder(
input_features=input_features,
input_ids=input_ids,
......@@ -1884,6 +1956,15 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
)
if isinstance(decoder_outputs, ModelOutput):
speech_ids = decoder_outputs.sequences
# pad to pad_to_max_mel_tokens if given, to replicate the original repo logic
# link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
if pad_to_max_mel_tokens is not None:
padding_needed = pad_to_max_mel_tokens - speech_ids.shape[-1]
speech_ids = torch.nn.functional.pad(
speech_ids, (0, padding_needed), value=self.generation_config.eos_token_id
)
speech_ids = self.fix_speech_decoder_output(speech_ids)
speech_outputs = self.speech_encoder_model(
......
......@@ -604,12 +604,7 @@ class ClvpIntegrationTest(unittest.TestCase):
text_embeds = self.model.text_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu()
# fmt: off
EXPECTED_TEXT_EMBEDS = torch.tensor(
[ 1.8060e+00, -2.7928e+00, 3.2021e+00, -1.5673e+00, 2.3284e+00, -3.2065e+00, -1.3368e+00, 2.2322e+00,
-1.7667e+00, 4.1505e-01, 2.4119e+00, -5.8133e-03, -4.6367e+00, 1.6450e-01, 6.7459e+00, 6.6292e+00,
1.1046e+00, 3.6196e+00, -1.0496e+01, 5.4924e+00
]
)
EXPECTED_TEXT_EMBEDS = torch.tensor([1.4798, -2.0005, 2.3902, -0.5042, 1.6401, -2.4135, -1.4800, 3.0118, -2.4422, 1.3266, 2.2339, 1.4761, -4.8983, -1.3592, 6.0251, 6.7364, 2.2576, 3.7229, -10.0436, 4.6676])
# fmt: on
self.assertTrue(torch.allclose(text_embeds[0, :20], EXPECTED_TEXT_EMBEDS, atol=1e-4))
......@@ -618,11 +613,7 @@ class ClvpIntegrationTest(unittest.TestCase):
speech_embeds = self.model.speech_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu()
# fmt: off
EXPECTED_SPEECH_EMBEDS = torch.tensor(
[ 4.6143, -5.5784, 0.8983, -3.9665, -0.6714, -1.0665, -1.1277, 1.5619, 2.6322, -7.2008, -2.4932, 0.3265,
-1.4738, 0.1425, 5.0825, 4.1760, -5.4708, 2.1935, -6.0044, 3.9540
]
)
EXPECTED_SPEECH_EMBEDS = torch.tensor([3.1202, -3.1183, -1.4264, -6.1339, 1.8885, -0.1983, 0.9461, -1.7414, 0.3320, -3.8400, -1.5715, 1.5096, -1.7576, 0.2387, 4.9758, 5.8450, -6.2534, 2.8587, -5.5816, 4.7821])
# fmt: on
self.assertTrue(torch.allclose(speech_embeds[0, :20], EXPECTED_SPEECH_EMBEDS, atol=1e-4))
......@@ -635,8 +626,10 @@ class ClvpIntegrationTest(unittest.TestCase):
num_beams=4,
num_return_sequences=4,
max_new_tokens=10,
).speech_ids.cpu()
)
EXPECTED_OUTPUTS = torch.tensor([[1953, 1080, 612], [1953, 1953, 612], [1953, 612, 716]])
EXPECTED_SPEECH_IDS = torch.tensor([[1953, 1080, 612], [1953, 612, 493], [1953, 612, 716]])
EXPECTED_SIMILARITY_SCORES = torch.tensor([[14.7660, 14.4569, 13.6472, 13.5683]])
self.assertTrue(torch.allclose(full_model_output[-3:, -3:], EXPECTED_OUTPUTS))
self.assertTrue(torch.allclose(full_model_output.speech_ids.cpu()[-3:, -3:], EXPECTED_SPEECH_IDS))
self.assertTrue(torch.allclose(full_model_output.logits_per_text.cpu(), EXPECTED_SIMILARITY_SCORES))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment