Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
dfbd209c
Unverified
Commit
dfbd209c
authored
Nov 28, 2023
by
Susnato Dhar
Committed by
GitHub
Nov 28, 2023
Browse files
CLVP Fixes (#27547)
* fixes * more fixes * style fix * more fix * comments
parent
30e92ea3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
106 additions
and
32 deletions
+106
-32
src/transformers/models/clvp/modeling_clvp.py
src/transformers/models/clvp/modeling_clvp.py
+99
-18
tests/models/clvp/test_modeling_clvp.py
tests/models/clvp/test_modeling_clvp.py
+7
-14
No files found.
src/transformers/models/clvp/modeling_clvp.py
View file @
dfbd209c
...
...
@@ -81,8 +81,7 @@ def rotate_half(x):
return
torch
.
cat
((
-
x2
,
x1
),
dim
=-
1
)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def
apply_rotary_pos_emb
(
q
,
k
,
cos
,
sin
,
position_ids
,
unsqueeze_dim
=
1
):
def
apply_rotary_pos_emb
(
q
,
k
,
v
,
cos
,
sin
,
position_ids
,
unsqueeze_dim
=
1
):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
...
...
@@ -107,7 +106,51 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
sin
=
sin
[
position_ids
].
unsqueeze
(
unsqueeze_dim
)
q_embed
=
(
q
*
cos
)
+
(
rotate_half
(
q
)
*
sin
)
k_embed
=
(
k
*
cos
)
+
(
rotate_half
(
k
)
*
sin
)
return
q_embed
,
k_embed
v_embed
=
(
v
*
cos
)
+
(
rotate_half
(
v
)
*
sin
)
return
q_embed
,
k_embed
,
v_embed
def
_pad_extra_bos_eos_tokens
(
input_ids
,
attention_mask
=
None
,
pad_token_id
=
0
,
bos_token_id
=
255
,
eos_token_id
=
0
,
add_bos_token
=
True
,
add_eos_token
=
True
,
):
"""
This method adds extra bos and eos tokens to input_ids and accordingly modifies the attention_mask which is used in
`ClvpConditioningEncoder` and the generation loop of the `ClvpModelForConditionalGeneration`.
"""
# add the bos token at the beginning
if
add_bos_token
:
input_ids
=
torch
.
nn
.
functional
.
pad
(
input_ids
,
(
1
,
0
),
value
=
bos_token_id
)
attention_mask
=
(
torch
.
nn
.
functional
.
pad
(
attention_mask
,
(
1
,
0
),
value
=
1
)
if
attention_mask
is
not
None
else
attention_mask
)
modified_input_ids
=
input_ids
if
add_eos_token
:
modified_input_ids
=
torch
.
zeros
(
(
input_ids
.
shape
[
0
],
input_ids
.
shape
[
1
]
+
1
),
dtype
=
input_ids
.
dtype
,
device
=
input_ids
.
device
)
for
i
,
each_input_id
in
enumerate
(
input_ids
):
# locate where the valid tokens end and then add the eos token
if
torch
.
isin
(
each_input_id
,
pad_token_id
).
sum
():
pos
=
torch
.
where
(
each_input_id
==
pad_token_id
)[
0
].
min
()
modified_input_ids
[
i
]
=
torch
.
concatenate
(
[
each_input_id
[:
pos
],
torch
.
tensor
([
eos_token_id
],
device
=
input_ids
.
device
),
each_input_id
[
pos
:]]
)
else
:
# if there are no pad tokens present, then add eos to the end
modified_input_ids
[
i
]
=
torch
.
nn
.
functional
.
pad
(
each_input_id
,
(
0
,
1
),
value
=
eos_token_id
)
attention_mask
=
(
torch
.
nn
.
functional
.
pad
(
attention_mask
,
(
1
,
0
),
value
=
1
)
if
attention_mask
is
not
None
else
attention_mask
)
return
modified_input_ids
,
attention_mask
@
dataclass
...
...
@@ -312,13 +355,18 @@ class ClvpSelfAttention(nn.Module):
key_states
[...,
:
rotary_emb_dim
],
key_states
[...,
rotary_emb_dim
:],
)
value_rot
,
value_pass
=
(
value_states
[...,
:
rotary_emb_dim
],
value_states
[...,
rotary_emb_dim
:],
)
cos
,
sin
=
rotary_pos_emb
.
cos
().
squeeze
(
0
),
rotary_pos_emb
.
sin
().
squeeze
(
0
)
query_rot
,
key_rot
=
apply_rotary_pos_emb
(
query_rot
,
key_rot
,
cos
,
sin
,
position_ids
)
query_rot
,
key_rot
,
value_rot
=
apply_rotary_pos_emb
(
query_rot
,
key_rot
,
value_rot
,
cos
,
sin
,
position_ids
)
# [batch_size, num_heads, seq_length, head_dim]
query_states
=
torch
.
cat
((
query_rot
,
query_pass
),
dim
=-
1
)
key_states
=
torch
.
cat
((
key_rot
,
key_pass
),
dim
=-
1
)
value_states
=
torch
.
cat
((
value_rot
,
value_pass
),
dim
=-
1
)
tgt_len
=
query_states
.
shape
[
2
]
src_len
=
key_states
.
shape
[
2
]
...
...
@@ -599,16 +647,7 @@ class ClvpConditioningEncoder(nn.Module):
if
input_ids
is
not
None
and
inputs_embeds
is
not
None
:
raise
ValueError
(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif
input_ids
is
not
None
:
# We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
# This logic is specific to ClvpConditioningEncoder and not used by other modules.
input_ids
=
torch
.
nn
.
functional
.
pad
(
input_ids
,
(
1
,
0
),
value
=
self
.
text_config
.
bos_token_id
)
input_ids
=
torch
.
nn
.
functional
.
pad
(
input_ids
,
(
0
,
1
),
value
=
self
.
text_config
.
eos_token_id
)
batch_size
,
seq_length
=
input_ids
.
size
()
inputs_embeds
=
self
.
text_token_embedding
(
input_ids
)
# check if we need to update attention mask, if yes then pad it too
if
attention_mask
is
not
None
and
attention_mask
.
shape
[
1
]
!=
seq_length
:
attention_mask
=
torch
.
nn
.
functional
.
pad
(
attention_mask
,
(
1
,
0
),
value
=
1
)
attention_mask
=
torch
.
nn
.
functional
.
pad
(
attention_mask
,
(
0
,
1
),
value
=
1
)
elif
inputs_embeds
is
not
None
:
batch_size
,
seq_length
=
inputs_embeds
.
size
()[:
-
1
]
else
:
...
...
@@ -616,8 +655,18 @@ class ClvpConditioningEncoder(nn.Module):
# construct attention mask if not given
if
attention_mask
is
None
:
attention_mask
=
torch
.
ones
([
batch_size
,
seq_length
],
dtype
=
torch
.
long
,
device
=
input
s_embe
ds
.
device
)
attention_mask
=
torch
.
ones
([
batch_size
,
seq_length
],
dtype
=
torch
.
long
,
device
=
input
_i
ds
.
device
)
# We add bos and eos input_ids in the modeling file instead of the tokenizer file to keep the logic simple
# This logic is specific to ClvpConditioningEncoder and not used by other modules.
input_ids
,
attention_mask
=
_pad_extra_bos_eos_tokens
(
input_ids
,
attention_mask
,
bos_token_id
=
self
.
text_config
.
bos_token_id
,
eos_token_id
=
self
.
text_config
.
eos_token_id
,
)
inputs_embeds
=
self
.
text_token_embedding
(
input_ids
)
position_ids
=
attention_mask
.
cumsum
(
-
1
)
-
1
position_embeds
=
self
.
text_position_embedding
(
position_ids
)
text_embeds
=
inputs_embeds
+
position_embeds
...
...
@@ -1512,10 +1561,6 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
"""
decoder_fixing_codes
=
self
.
config
.
decoder_config
.
decoder_fixing_codes
speech_ids
=
speech_ids
[:,
1
:]
if
torch
.
isin
(
self
.
speech_decoder_model
.
config
.
eos_token_id
,
speech_ids
):
speech_ids
=
torch
.
nn
.
functional
.
pad
(
speech_ids
,
pad
=
(
0
,
1
),
value
=
self
.
speech_decoder_model
.
config
.
eos_token_id
)
stop_token_indices
=
torch
.
where
(
speech_ids
==
self
.
speech_decoder_model
.
config
.
eos_token_id
,
1
,
0
)
speech_ids
=
torch
.
masked_fill
(
speech_ids
,
mask
=
stop_token_indices
.
bool
(),
value
=
decoder_fixing_codes
[
0
])
...
...
@@ -1828,6 +1873,7 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
input_features
:
torch
.
FloatTensor
=
None
,
attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
generation_config
:
Optional
[
GenerationConfig
]
=
None
,
pad_to_max_mel_tokens
:
Optional
[
int
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
**
kwargs
,
):
...
...
@@ -1855,6 +1901,11 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
pad_to_max_mel_tokens (`int`, *optional*):
Pads generated speech_ids to the specified value. This is to implement the same logic from the official
repo, link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
and to make sure the logits are same.
This does not affect generation quality so please don't consider using it since it is less efficient.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of decoder model, text encoder and speech encoder models.
...
...
@@ -1862,6 +1913,17 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
`ClvpOutput` or tuple: A `ClvpOutput` (if `return_dict_in_generate=True` or when
`config.return_dict_in_generate=True`) or a tuple.
"""
# If the input sequences are larger than (self.config.decoder_config.max_text_tokens - 3) then raise error,
# because we need to add 3 tokens ( 1 bos tokens and 2 eos tokens) to the input_ids in ClvpConditioningEncoder to
# properly sample
sequence_length
=
input_ids
.
shape
[
-
1
]
if
sequence_length
>
(
self
.
config
.
decoder_config
.
max_text_tokens
-
3
):
raise
ValueError
(
f
"Maximum sequence length reached! Found input_ids of length
{
sequence_length
}
."
f
"Please make sure that the maximum length of input_ids is
{
self
.
config
.
decoder_config
.
max_text_tokens
-
3
}
"
)
if
generation_config
is
None
:
generation_config
=
self
.
generation_config
...
...
@@ -1870,6 +1932,16 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
generation_config
.
validate
()
self
.
_validate_model_kwargs
(
model_kwargs
.
copy
())
# pad input_ids as specified in the original repo
# link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L380
input_ids
,
attention_mask
=
_pad_extra_bos_eos_tokens
(
input_ids
,
attention_mask
,
add_bos_token
=
False
,
bos_token_id
=
self
.
config
.
text_config
.
bos_token_id
,
eos_token_id
=
self
.
config
.
text_config
.
eos_token_id
,
)
conditioning_embeds
=
self
.
conditioning_encoder
(
input_features
=
input_features
,
input_ids
=
input_ids
,
...
...
@@ -1884,6 +1956,15 @@ class ClvpModelForConditionalGeneration(ClvpPreTrainedModel):
)
if
isinstance
(
decoder_outputs
,
ModelOutput
):
speech_ids
=
decoder_outputs
.
sequences
# pad to pad_to_max_mel_tokens if given, to replicate the original repo logic
# link: https://github.com/neonbjb/tortoise-tts/blob/80f89987a5abda5e2b082618cd74f9c7411141dc/tortoise/api.py#L430
if
pad_to_max_mel_tokens
is
not
None
:
padding_needed
=
pad_to_max_mel_tokens
-
speech_ids
.
shape
[
-
1
]
speech_ids
=
torch
.
nn
.
functional
.
pad
(
speech_ids
,
(
0
,
padding_needed
),
value
=
self
.
generation_config
.
eos_token_id
)
speech_ids
=
self
.
fix_speech_decoder_output
(
speech_ids
)
speech_outputs
=
self
.
speech_encoder_model
(
...
...
tests/models/clvp/test_modeling_clvp.py
View file @
dfbd209c
...
...
@@ -604,12 +604,7 @@ class ClvpIntegrationTest(unittest.TestCase):
text_embeds
=
self
.
model
.
text_encoder_model
(
input_ids
=
self
.
text_tokens
,
return_dict
=
True
)[
0
].
cpu
()
# fmt: off
EXPECTED_TEXT_EMBEDS
=
torch
.
tensor
(
[
1.8060e+00
,
-
2.7928e+00
,
3.2021e+00
,
-
1.5673e+00
,
2.3284e+00
,
-
3.2065e+00
,
-
1.3368e+00
,
2.2322e+00
,
-
1.7667e+00
,
4.1505e-01
,
2.4119e+00
,
-
5.8133e-03
,
-
4.6367e+00
,
1.6450e-01
,
6.7459e+00
,
6.6292e+00
,
1.1046e+00
,
3.6196e+00
,
-
1.0496e+01
,
5.4924e+00
]
)
EXPECTED_TEXT_EMBEDS
=
torch
.
tensor
([
1.4798
,
-
2.0005
,
2.3902
,
-
0.5042
,
1.6401
,
-
2.4135
,
-
1.4800
,
3.0118
,
-
2.4422
,
1.3266
,
2.2339
,
1.4761
,
-
4.8983
,
-
1.3592
,
6.0251
,
6.7364
,
2.2576
,
3.7229
,
-
10.0436
,
4.6676
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
text_embeds
[
0
,
:
20
],
EXPECTED_TEXT_EMBEDS
,
atol
=
1e-4
))
...
...
@@ -618,11 +613,7 @@ class ClvpIntegrationTest(unittest.TestCase):
speech_embeds
=
self
.
model
.
speech_encoder_model
(
input_ids
=
self
.
text_tokens
,
return_dict
=
True
)[
0
].
cpu
()
# fmt: off
EXPECTED_SPEECH_EMBEDS
=
torch
.
tensor
(
[
4.6143
,
-
5.5784
,
0.8983
,
-
3.9665
,
-
0.6714
,
-
1.0665
,
-
1.1277
,
1.5619
,
2.6322
,
-
7.2008
,
-
2.4932
,
0.3265
,
-
1.4738
,
0.1425
,
5.0825
,
4.1760
,
-
5.4708
,
2.1935
,
-
6.0044
,
3.9540
]
)
EXPECTED_SPEECH_EMBEDS
=
torch
.
tensor
([
3.1202
,
-
3.1183
,
-
1.4264
,
-
6.1339
,
1.8885
,
-
0.1983
,
0.9461
,
-
1.7414
,
0.3320
,
-
3.8400
,
-
1.5715
,
1.5096
,
-
1.7576
,
0.2387
,
4.9758
,
5.8450
,
-
6.2534
,
2.8587
,
-
5.5816
,
4.7821
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
speech_embeds
[
0
,
:
20
],
EXPECTED_SPEECH_EMBEDS
,
atol
=
1e-4
))
...
...
@@ -635,8 +626,10 @@ class ClvpIntegrationTest(unittest.TestCase):
num_beams
=
4
,
num_return_sequences
=
4
,
max_new_tokens
=
10
,
)
.
speech_ids
.
cpu
()
)
EXPECTED_OUTPUTS
=
torch
.
tensor
([[
1953
,
1080
,
612
],
[
1953
,
1953
,
612
],
[
1953
,
612
,
716
]])
EXPECTED_SPEECH_IDS
=
torch
.
tensor
([[
1953
,
1080
,
612
],
[
1953
,
612
,
493
],
[
1953
,
612
,
716
]])
EXPECTED_SIMILARITY_SCORES
=
torch
.
tensor
([[
14.7660
,
14.4569
,
13.6472
,
13.5683
]])
self
.
assertTrue
(
torch
.
allclose
(
full_model_output
[
-
3
:,
-
3
:],
EXPECTED_OUTPUTS
))
self
.
assertTrue
(
torch
.
allclose
(
full_model_output
.
speech_ids
.
cpu
()[
-
3
:,
-
3
:],
EXPECTED_SPEECH_IDS
))
self
.
assertTrue
(
torch
.
allclose
(
full_model_output
.
logits_per_text
.
cpu
(),
EXPECTED_SIMILARITY_SCORES
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment