chenpangpang / transformers · commit dfbd209c (unverified)

Authored Nov 28, 2023 by Susnato Dhar, committed by GitHub on Nov 28, 2023.
Parent: 30e92ea3

CLVP Fixes (#27547)

* fixes
* more fixes
* style fix
* more fix
* comments

Showing 2 changed files with 106 additions and 32 deletions (+106 / -32):

  src/transformers/models/clvp/modeling_clvp.py   +99 / -18
  tests/models/clvp/test_modeling_clvp.py          +7 / -14
src/transformers/models/clvp/modeling_clvp.py
@@ -81,8 +81,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+def apply_rotary_pos_emb(q, k, v, cos, sin, position_ids, unsqueeze_dim=1):
     """Applies Rotary Position Embedding to the query and key tensors.
 
     Args:
@@ -107,7 +106,51 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
+    v_embed = (v * cos) + (rotate_half(v) * sin)
+    return q_embed, k_embed, v_embed
+
+
+def _pad_extra_bos_eos_tokens(
+    input_ids,
+    attention_mask=None,
+    pad_token_id=0,
+    bos_token_id=255,
+    eos_token_id=0,
+    add_bos_token=True,
+    add_eos_token=True,
+):
+    """
+    This method adds extra bos and eos tokens to input_ids and accordingly modifies the attention_mask which is used in
+    `ClvpConditioningEncoder` and the generation loop of the `ClvpModelForConditionalGeneration`.
+    """
+
+    # add the bos token at the beginning
+    if add_bos_token:
+        input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=bos_token_id)
+        attention_mask = (
+            torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
+        )
+
+    modified_input_ids = input_ids
+    if add_eos_token:
+        modified_input_ids = torch.zeros(
+            (input_ids.shape[0], input_ids.shape[1] + 1), dtype=input_ids.dtype, device=input_ids.device
+        )
+        for i, each_input_id in enumerate(input_ids):
+            # locate where the valid tokens end and then add the eos token
+            if torch.isin(each_input_id, pad_token_id).sum():
+                pos = torch.where(each_input_id == pad_token_id)[0].min()
+                modified_input_ids[i] = torch.concatenate(
+                    [each_input_id[:pos], torch.tensor([eos_token_id], device=input_ids.device), each_input_id[pos:]]
+                )
+            else:
+                # if there are no pad tokens present, then add eos to the end
+                modified_input_ids[i] = torch.nn.functional.pad(each_input_id, (0, 1), value=eos_token_id)
+        attention_mask = (
+            torch.nn.functional.pad(attention_mask, (1, 0), value=1) if attention_mask is not None else attention_mask
+        )
+
+    return modified_input_ids, attention_mask
 
 
 @dataclass
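The new `_pad_extra_bos_eos_tokens` helper centralizes the bos/eos handling that `ClvpConditioningEncoder` previously did inline and that `generate` needs as well (see the later hunks). A quick usage sketch, assuming this commit's version of the file is importable; the helper is a private module-level function, and the token ids below are just its default values, not values read from a real checkpoint:

import torch
from transformers.models.clvp.modeling_clvp import _pad_extra_bos_eos_tokens  # private helper added above

input_ids = torch.tensor([[12, 34, 56], [78, 0, 0]])   # second row is right-padded with pad_token_id=0
attention_mask = torch.tensor([[1, 1, 1], [1, 0, 0]])

padded_ids, padded_mask = _pad_extra_bos_eos_tokens(
    input_ids,
    attention_mask,
    pad_token_id=0,
    bos_token_id=255,
    eos_token_id=0,
)
# bos (255) is prepended to every row; eos (0) lands right after the last valid token of each row
print(padded_ids)   # tensor([[255,  12,  34,  56,   0], [255,  78,   0,   0,   0]])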
@@ -312,13 +355,18 @@ class ClvpSelfAttention(nn.Module):
                 key_states[..., :rotary_emb_dim],
                 key_states[..., rotary_emb_dim:],
             )
+            value_rot, value_pass = (
+                value_states[..., :rotary_emb_dim],
+                value_states[..., rotary_emb_dim:],
+            )
 
             cos, sin = rotary_pos_emb.cos().squeeze(0), rotary_pos_emb.sin().squeeze(0)
-            query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+            query_rot, key_rot, value_rot = apply_rotary_pos_emb(query_rot, key_rot, value_rot, cos, sin, position_ids)
 
             # [batch_size, num_heads, seq_length, head_dim]
             query_states = torch.cat((query_rot, query_pass), dim=-1)
             key_states = torch.cat((key_rot, key_pass), dim=-1)
+            value_states = torch.cat((value_rot, value_pass), dim=-1)
 
         tgt_len = query_states.shape[2]
         src_len = key_states.shape[2]
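With this change the value states receive the same partial rotary treatment as the queries and keys. A minimal, self-contained sketch of the mechanics, using toy shapes and random cos/sin tables as stand-ins for the caches produced by the model's rotary embedding module:

import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

batch, heads, seq, head_dim, rotary_dim = 1, 2, 5, 8, 4
q, k, v = (torch.randn(batch, heads, seq, head_dim) for _ in range(3))

cos = torch.cos(torch.randn(seq, rotary_dim))   # toy stand-ins for the real cos/sin caches
sin = torch.sin(torch.randn(seq, rotary_dim))
position_ids = torch.arange(seq).unsqueeze(0)

def apply(t):
    # only the first rotary_dim channels are rotated; the rest pass through untouched
    t_rot, t_pass = t[..., :rotary_dim], t[..., rotary_dim:]
    c = cos[position_ids].unsqueeze(1)          # [1, 1, seq, rotary_dim], broadcasts over heads
    s = sin[position_ids].unsqueeze(1)
    t_rot = (t_rot * c) + (rotate_half(t_rot) * s)
    return torch.cat((t_rot, t_pass), dim=-1)

q, k, v = apply(q), apply(k), apply(v)          # after this commit, v is rotated just like q and k
print(q.shape, k.shape, v.shape)                # torch.Size([1, 2, 5, 8]) three times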
@@ -599,16 +647,7 @@ class ClvpConditioningEncoder(nn.Module):
         if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
         elif input_ids is not None:
-            # We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
-            # This logic is specific to ClvpConditioningEncoder and not used by other modules.
-            input_ids = torch.nn.functional.pad(input_ids, (1, 0), value=self.text_config.bos_token_id)
-            input_ids = torch.nn.functional.pad(input_ids, (0, 1), value=self.text_config.eos_token_id)
             batch_size, seq_length = input_ids.size()
-            inputs_embeds = self.text_token_embedding(input_ids)
-            # check if we need to update attention mask, if yes then pad it too
-            if attention_mask is not None and attention_mask.shape[1] != seq_length:
-                attention_mask = torch.nn.functional.pad(attention_mask, (1, 0), value=1)
-                attention_mask = torch.nn.functional.pad(attention_mask, (0, 1), value=1)
         elif inputs_embeds is not None:
             batch_size, seq_length = inputs_embeds.size()[:-1]
         else:
@@ -616,8 +655,18 @@ class ClvpConditioningEncoder(nn.Module):
         # construct attention mask if not given
         if attention_mask is None:
-            attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=inputs_embeds.device)
+            attention_mask = torch.ones([batch_size, seq_length], dtype=torch.long, device=input_ids.device)
+
+        # We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
+        # This logic is specific to ClvpConditioningEncoder and not used by other modules.
+        input_ids, attention_mask = _pad_extra_bos_eos_tokens(
+            input_ids,
+            attention_mask,
+            bos_token_id=self.text_config.bos_token_id,
+            eos_token_id=self.text_config.eos_token_id,
+        )
+
+        inputs_embeds = self.text_token_embedding(input_ids)
 
         position_ids = attention_mask.cumsum(-1) - 1
         position_embeds = self.text_position_embedding(position_ids)
         text_embeds = inputs_embeds + position_embeds
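The unchanged context lines show how the encoder turns the (now longer) attention mask into position ids with a cumulative sum, so padded slots simply repeat the last valid position. A tiny illustration of that behaviour on a right-padded mask:

import torch

attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
position_ids = attention_mask.cumsum(-1) - 1
print(position_ids)   # tensor([[0, 1, 2, 3, 3, 3]])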
@@ -1512,10 +1561,6 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
         """
         decoder_fixing_codes = self.config.decoder_config.decoder_fixing_codes
         speech_ids = speech_ids[:, 1:]
 
-        if torch.isin(self.speech_decoder_model.config.eos_token_id, speech_ids):
-            speech_ids = torch.nn.functional.pad(
-                speech_ids, pad=(0, 1), value=self.speech_decoder_model.config.eos_token_id
-            )
         stop_token_indices = torch.where(speech_ids == self.speech_decoder_model.config.eos_token_id, 1, 0)
         speech_ids = torch.masked_fill(speech_ids, mask=stop_token_indices.bool(), value=decoder_fixing_codes[0])
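`fix_speech_decoder_output` no longer appends an extra eos token before masking; it simply rewrites every eos in the generated codes with the first decoder fixing code. A toy illustration of that masking step, with hypothetical values standing in for `speech_decoder_model.config.eos_token_id` and `decoder_config.decoder_fixing_codes`:

import torch

eos_token_id = 8193           # hypothetical stand-in for speech_decoder_model.config.eos_token_id
decoder_fixing_codes = [83]   # hypothetical stand-in for decoder_config.decoder_fixing_codes

speech_ids = torch.tensor([[17, 42, eos_token_id, eos_token_id]])
stop_token_indices = torch.where(speech_ids == eos_token_id, 1, 0)
speech_ids = torch.masked_fill(speech_ids, mask=stop_token_indices.bool(), value=decoder_fixing_codes[0])
print(speech_ids)             # tensor([[17, 42, 83, 83]])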
@@ -1828,6 +1873,7 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
         input_features: torch.FloatTensor = None,
         attention_mask: Optional[torch.LongTensor] = None,
         generation_config: Optional[GenerationConfig] = None,
+        pad_to_max_mel_tokens: Optional[int] = None,
         output_hidden_states: Optional[bool] = None,
         **kwargs,
     ):
@@ -1855,6 +1901,11 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
                 priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                 configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                 default values, whose documentation should be checked to parameterize generation.
+            pad_to_max_mel_tokens (`int`, *optional*):
+                Pads generated speech_ids to the specified value. This is to implement the same logic from the official
+                repo, link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
+                and to make sure the logits are same.
+                This does not affect generation quality so please don't consider using it since it is less efficient.
             output_hidden_states (`bool`, *optional*):
                 Whether or not to return the hidden states of decoder model, text encoder and speech encoder models.
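The new `pad_to_max_mel_tokens` argument just right-pads the generated speech ids with eos up to a fixed length, mirroring the original tortoise-tts behaviour so logits can be compared one-to-one. A small sketch of the padding itself, with a hypothetical eos id and target length:

import torch

pad_to_max_mel_tokens = 8
eos_token_id = 8193                                  # hypothetical generation_config.eos_token_id
speech_ids = torch.tensor([[17, 42, 391, eos_token_id, eos_token_id]])

padding_needed = pad_to_max_mel_tokens - speech_ids.shape[-1]
speech_ids = torch.nn.functional.pad(speech_ids, (0, padding_needed), value=eos_token_id)
print(speech_ids)   # tensor([[  17,   42,  391, 8193, 8193, 8193, 8193, 8193]])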
@@ -1862,6 +1913,17 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
             `ClvpOutput` or tuple: A `ClvpOutput` (if `return_dict_in_generate=True` or when
             `config.return_dict_in_generate=True`) or a tuple.
         """
+        # If the input sequences are larger than (self.config.decoder_config.max_text_tokens - 3) then raise error,
+        # because we need to add 3 tokens ( 1 bos tokens and 2 eos tokens) to the input_ids in ClvpConditioningEncoder to
+        # properly sample
+        sequence_length = input_ids.shape[-1]
+        if sequence_length > (self.config.decoder_config.max_text_tokens - 3):
+            raise ValueError(
+                f"Maximum sequence length reached! Found input_ids of length {sequence_length}."
+                f"Please make sure that the maximum length of input_ids is {self.config.decoder_config.max_text_tokens - 3}"
+            )
+
         if generation_config is None:
             generation_config = self.generation_config
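Concretely, the guard leaves room for the one bos id and two eos ids that are appended further down the pipeline; a trivial sketch with a hypothetical config value:

max_text_tokens = 350                          # hypothetical; the real value comes from self.config.decoder_config.max_text_tokens
longest_allowed_prompt = max_text_tokens - 3   # room for 1 bos id and 2 eos ids added downstream
print(longest_allowed_prompt)                  # 347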
@@ -1870,6 +1932,16 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
         generation_config.validate()
         self._validate_model_kwargs(model_kwargs.copy())
 
+        # pad input_ids as specified in the original repo
+        # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L380
+        input_ids, attention_mask = _pad_extra_bos_eos_tokens(
+            input_ids,
+            attention_mask,
+            add_bos_token=False,
+            bos_token_id=self.config.text_config.bos_token_id,
+            eos_token_id=self.config.text_config.eos_token_id,
+        )
+
         conditioning_embeds = self.conditioning_encoder(
             input_features=input_features,
             input_ids=input_ids,
@@ -1884,6 +1956,15 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
         )
 
         if isinstance(decoder_outputs, ModelOutput):
             speech_ids = decoder_outputs.sequences
 
+        # pad to pad_to_max_mel_tokens if given, to replicate the original repo logic
+        # link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
+        if pad_to_max_mel_tokens is not None:
+            padding_needed = pad_to_max_mel_tokens - speech_ids.shape[-1]
+            speech_ids = torch.nn.functional.pad(
+                speech_ids, (0, padding_needed), value=self.generation_config.eos_token_id
+            )
+
         speech_ids = self.fix_speech_decoder_output(speech_ids)
 
         speech_outputs = self.speech_encoder_model(
tests/models/clvp/test_modeling_clvp.py
@@ -604,12 +604,7 @@ class ClvpIntegrationTest(unittest.TestCase):
         text_embeds = self.model.text_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu()
 
         # fmt: off
-        EXPECTED_TEXT_EMBEDS = torch.tensor(
-            [1.4798, -2.0005, 2.3902, -0.5042, 1.6401, -2.4135, -1.4800, 3.0118, -2.4422, 1.3266, 2.2339, 1.4761, -4.8983, -1.3592, 6.0251, 6.7364, 2.2576, 3.7229, -10.0436, 4.6676]
-        )
+        EXPECTED_TEXT_EMBEDS = torch.tensor([1.8060e+00, -2.7928e+00, 3.2021e+00, -1.5673e+00, 2.3284e+00, -3.2065e+00, -1.3368e+00, 2.2322e+00, -1.7667e+00, 4.1505e-01, 2.4119e+00, -5.8133e-03, -4.6367e+00, 1.6450e-01, 6.7459e+00, 6.6292e+00, 1.1046e+00, 3.6196e+00, -1.0496e+01, 5.4924e+00])
         # fmt: on
 
         self.assertTrue(torch.allclose(text_embeds[0, :20], EXPECTED_TEXT_EMBEDS, atol=1e-4))
@@ -618,11 +613,7 @@ class ClvpIntegrationTest(unittest.TestCase):
         speech_embeds = self.model.speech_encoder_model(input_ids=self.text_tokens, return_dict=True)[0].cpu()
 
         # fmt: off
-        EXPECTED_SPEECH_EMBEDS = torch.tensor(
-            [3.1202, -3.1183, -1.4264, -6.1339, 1.8885, -0.1983, 0.9461, -1.7414, 0.3320, -3.8400, -1.5715, 1.5096, -1.7576, 0.2387, 4.9758, 5.8450, -6.2534, 2.8587, -5.5816, 4.7821]
-        )
+        EXPECTED_SPEECH_EMBEDS = torch.tensor([4.6143, -5.5784, 0.8983, -3.9665, -0.6714, -1.0665, -1.1277, 1.5619, 2.6322, -7.2008, -2.4932, 0.3265, -1.4738, 0.1425, 5.0825, 4.1760, -5.4708, 2.1935, -6.0044, 3.9540])
         # fmt: on
 
         self.assertTrue(torch.allclose(speech_embeds[0, :20], EXPECTED_SPEECH_EMBEDS, atol=1e-4))
@@ -635,8 +626,10 @@ class ClvpIntegrationTest(unittest.TestCase):
             num_beams=4,
             num_return_sequences=4,
             max_new_tokens=10,
-        ).speech_ids.cpu()
+        )
 
-        EXPECTED_OUTPUTS = torch.tensor([[1953, 1080, 612], [1953, 1953, 612], [1953, 612, 716]])
+        EXPECTED_SPEECH_IDS = torch.tensor([[1953, 1080, 612], [1953, 612, 493], [1953, 612, 716]])
+        EXPECTED_SIMILARITY_SCORES = torch.tensor([[14.7660, 14.4569, 13.6472, 13.5683]])
 
-        self.assertTrue(torch.allclose(full_model_output[-3:, -3:], EXPECTED_OUTPUTS))
+        self.assertTrue(torch.allclose(full_model_output.speech_ids.cpu()[-3:, -3:], EXPECTED_SPEECH_IDS))
+        self.assertTrue(torch.allclose(full_model_output.logits_per_text.cpu(), EXPECTED_SIMILARITY_SCORES))