Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9af1b6a8
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4c6728460a4c3439c017d0e4aa36a156eb128a6f"
Unverified
Commit
9af1b6a8
authored
Jun 17, 2024
by
Raushan Turganbay
Committed by
GitHub
Jun 17, 2024
Browse files
Musicgen special tokens in tensors (#31420)
fix
parent
eed9ed67
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
0 deletions
+8
-0
src/transformers/models/musicgen/modeling_musicgen.py
src/transformers/models/musicgen/modeling_musicgen.py
+4
-0
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
...ormers/models/musicgen_melody/modeling_musicgen_melody.py
+4
-0
No files found.
src/transformers/models/musicgen/modeling_musicgen.py
View file @
9af1b6a8
...
@@ -1666,6 +1666,8 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
...
@@ -1666,6 +1666,8 @@ class MusicgenForCausalLM(MusicgenPreTrainedModel):
inputs
,
generation_config
.
bos_token_id
,
model_kwargs
inputs
,
generation_config
.
bos_token_id
,
model_kwargs
)
)
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
kwargs_has_attention_mask
=
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
self
.
_prepare_special_tokens
(
generation_config
,
kwargs_has_attention_mask
,
device
=
input_ids
.
device
)
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
...
@@ -2738,6 +2740,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
...
@@ -2738,6 +2740,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
inputs
,
generation_config
.
bos_token_id
,
model_kwargs
inputs
,
generation_config
.
bos_token_id
,
model_kwargs
)
)
batch_size
=
inputs_tensor
.
shape
[
0
]
batch_size
=
inputs_tensor
.
shape
[
0
]
kwargs_has_attention_mask
=
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
self
.
_prepare_special_tokens
(
generation_config
,
kwargs_has_attention_mask
,
device
=
inputs_tensor
.
device
)
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
...
...
src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
View file @
9af1b6a8
...
@@ -1587,6 +1587,8 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
...
@@ -1587,6 +1587,8 @@ class MusicgenMelodyForCausalLM(MusicgenMelodyPreTrainedModel):
inputs
,
generation_config
.
bos_token_id
,
model_kwargs
inputs
,
generation_config
.
bos_token_id
,
model_kwargs
)
)
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
batch_size
=
input_ids
.
shape
[
0
]
//
self
.
num_codebooks
kwargs_has_attention_mask
=
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
self
.
_prepare_special_tokens
(
generation_config
,
kwargs_has_attention_mask
,
device
=
input_ids
.
device
)
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
...
@@ -2588,6 +2590,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
...
@@ -2588,6 +2590,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
inputs
,
generation_config
.
bos_token_id
,
model_kwargs
inputs
,
generation_config
.
bos_token_id
,
model_kwargs
)
)
batch_size
=
inputs_tensor
.
shape
[
0
]
batch_size
=
inputs_tensor
.
shape
[
0
]
kwargs_has_attention_mask
=
model_kwargs
.
get
(
"attention_mask"
,
None
)
is
not
None
self
.
_prepare_special_tokens
(
generation_config
,
kwargs_has_attention_mask
,
device
=
inputs_tensor
.
device
)
# 4. Define other model kwargs
# 4. Define other model kwargs
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
model_kwargs
[
"use_cache"
]
=
generation_config
.
use_cache
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment