Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
9ee3dd38
Unverified
Commit
9ee3dd38
authored
Apr 09, 2025
by
hlky
Committed by
GitHub
Apr 09, 2025
Browse files
AudioLDM2 Fixes (#11244)
parent
fd02aad4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
7 deletions
+22
-7
src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
+8
-5
tests/pipelines/audioldm2/test_audioldm2.py
tests/pipelines/audioldm2/test_audioldm2.py
+14
-2
No files found.
src/diffusers/pipelines/audioldm2/pipeline_audioldm2.py
View file @
9ee3dd38
...
...
@@ -20,7 +20,7 @@ import torch
from
transformers
import
(
ClapFeatureExtractor
,
ClapModel
,
GPT2Model
,
GPT2
LMHead
Model
,
RobertaTokenizer
,
RobertaTokenizerFast
,
SpeechT5HifiGan
,
...
...
@@ -196,7 +196,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
text_encoder
:
ClapModel
,
text_encoder_2
:
Union
[
T5EncoderModel
,
VitsModel
],
projection_model
:
AudioLDM2ProjectionModel
,
language_model
:
GPT2Model
,
language_model
:
GPT2
LMHead
Model
,
tokenizer
:
Union
[
RobertaTokenizer
,
RobertaTokenizerFast
],
tokenizer_2
:
Union
[
T5Tokenizer
,
T5TokenizerFast
,
VitsTokenizer
],
feature_extractor
:
ClapFeatureExtractor
,
...
...
@@ -259,7 +259,10 @@ class AudioLDM2Pipeline(DiffusionPipeline):
)
device_type
=
torch_device
.
type
device
=
torch
.
device
(
f
"
{
device_type
}
:
{
gpu_id
or
torch_device
.
index
}
"
)
device_str
=
device_type
if
gpu_id
or
torch_device
.
index
:
device_str
=
f
"
{
device_str
}
:
{
gpu_id
or
torch_device
.
index
}
"
device
=
torch
.
device
(
device_str
)
if
self
.
device
.
type
!=
"cpu"
:
self
.
to
(
"cpu"
,
silence_dtype_warnings
=
True
)
...
...
@@ -316,9 +319,9 @@ class AudioLDM2Pipeline(DiffusionPipeline):
model_inputs
=
prepare_inputs_for_generation
(
inputs_embeds
,
**
model_kwargs
)
# forward pass to get next hidden states
output
=
self
.
language_model
(
**
model_inputs
,
return_dict
=
True
)
output
=
self
.
language_model
(
**
model_inputs
,
output_hidden_states
=
True
,
return_dict
=
True
)
next_hidden_states
=
output
.
last_
hidden_state
next_hidden_states
=
output
.
hidden_state
s
[
-
1
]
# Update the model input
inputs_embeds
=
torch
.
cat
([
inputs_embeds
,
next_hidden_states
[:,
-
1
:,
:]],
dim
=
1
)
...
...
tests/pipelines/audioldm2/test_audioldm2.py
View file @
9ee3dd38
...
...
@@ -26,7 +26,7 @@ from transformers import (
ClapModel
,
ClapTextConfig
,
GPT2Config
,
GPT2Model
,
GPT2
LMHead
Model
,
RobertaTokenizer
,
SpeechT5HifiGan
,
SpeechT5HifiGanConfig
,
...
...
@@ -162,7 +162,7 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
n_ctx
=
99
,
n_positions
=
99
,
)
language_model
=
GPT2Model
(
language_model_config
)
language_model
=
GPT2
LMHead
Model
(
language_model_config
)
language_model
.
config
.
max_new_tokens
=
8
torch
.
manual_seed
(
0
)
...
...
@@ -516,6 +516,18 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def
test_encode_prompt_works_in_isolation
(
self
):
pass
@
unittest
.
skip
(
"Not supported yet due to CLAPModel."
)
def
test_sequential_offload_forward_pass_twice
(
self
):
pass
@
unittest
.
skip
(
"Not supported yet, the second forward has mixed devices and `vocoder` is not offloaded."
)
def
test_cpu_offload_forward_pass_twice
(
self
):
pass
@
unittest
.
skip
(
"Not supported yet. `vocoder` is not offloaded."
)
def
test_model_cpu_offload_forward_pass
(
self
):
pass
@
nightly
class
AudioLDM2PipelineSlowTests
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment