Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
adf1f911
Unverified
Commit
adf1f911
authored
Sep 11, 2024
by
Sayak Paul
Committed by
GitHub
Sep 11, 2024
Browse files
[Tests] fix some fast gpu tests. (#9379)
fix some fast gpu tests.
parent
f28a8c25
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
5 additions
and
2 deletions
+5
-2
examples/dreambooth/train_dreambooth_lora_flux.py
examples/dreambooth/train_dreambooth_lora_flux.py
+2
-0
src/diffusers/models/transformers/transformer_flux.py
src/diffusers/models/transformers/transformer_flux.py
+1
-0
tests/pipelines/flux/test_pipeline_flux_img2img.py
tests/pipelines/flux/test_pipeline_flux_img2img.py
+1
-1
tests/pipelines/flux/test_pipeline_flux_inpaint.py
tests/pipelines/flux/test_pipeline_flux_inpaint.py
+1
-1
No files found.
examples/dreambooth/train_dreambooth_lora_flux.py
View file @
adf1f911
...
@@ -1597,6 +1597,7 @@ def main(args):
...
@@ -1597,6 +1597,7 @@ def main(args):
tokenizers
=
[
None
,
None
],
tokenizers
=
[
None
,
None
],
text_input_ids_list
=
[
tokens_one
,
tokens_two
],
text_input_ids_list
=
[
tokens_one
,
tokens_two
],
max_sequence_length
=
args
.
max_sequence_length
,
max_sequence_length
=
args
.
max_sequence_length
,
device
=
accelerator
.
device
,
prompt
=
prompts
,
prompt
=
prompts
,
)
)
else
:
else
:
...
@@ -1606,6 +1607,7 @@ def main(args):
...
@@ -1606,6 +1607,7 @@ def main(args):
tokenizers
=
[
None
,
None
],
tokenizers
=
[
None
,
None
],
text_input_ids_list
=
[
tokens_one
,
tokens_two
],
text_input_ids_list
=
[
tokens_one
,
tokens_two
],
max_sequence_length
=
args
.
max_sequence_length
,
max_sequence_length
=
args
.
max_sequence_length
,
device
=
accelerator
.
device
,
prompt
=
args
.
instance_prompt
,
prompt
=
args
.
instance_prompt
,
)
)
...
...
src/diffusers/models/transformers/transformer_flux.py
View file @
adf1f911
...
@@ -465,6 +465,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
...
@@ -465,6 +465,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
"Please remove the batch dimension and pass it as a 2d torch Tensor"
"Please remove the batch dimension and pass it as a 2d torch Tensor"
)
)
img_ids
=
img_ids
[
0
]
img_ids
=
img_ids
[
0
]
ids
=
torch
.
cat
((
txt_ids
,
img_ids
),
dim
=
0
)
ids
=
torch
.
cat
((
txt_ids
,
img_ids
),
dim
=
0
)
image_rotary_emb
=
self
.
pos_embed
(
ids
)
image_rotary_emb
=
self
.
pos_embed
(
ids
)
...
...
tests/pipelines/flux/test_pipeline_flux_img2img.py
View file @
adf1f911
...
@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin
...
@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism
()
enable_full_determinism
()
@
unittest
.
skipIf
(
torch_device
==
"mps"
,
"Flux has a float64 operation which is not supported in MPS."
)
class
FluxImg2ImgPipelineFastTests
(
unittest
.
TestCase
,
PipelineTesterMixin
):
class
FluxImg2ImgPipelineFastTests
(
unittest
.
TestCase
,
PipelineTesterMixin
):
pipeline_class
=
FluxImg2ImgPipeline
pipeline_class
=
FluxImg2ImgPipeline
params
=
frozenset
([
"prompt"
,
"height"
,
"width"
,
"guidance_scale"
,
"prompt_embeds"
,
"pooled_prompt_embeds"
])
params
=
frozenset
([
"prompt"
,
"height"
,
"width"
,
"guidance_scale"
,
"prompt_embeds"
,
"pooled_prompt_embeds"
])
batch_params
=
frozenset
([
"prompt"
])
batch_params
=
frozenset
([
"prompt"
])
test_xformers_attention
=
False
def
get_dummy_components
(
self
):
def
get_dummy_components
(
self
):
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
...
...
tests/pipelines/flux/test_pipeline_flux_inpaint.py
View file @
adf1f911
...
@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin
...
@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism
()
enable_full_determinism
()
@
unittest
.
skipIf
(
torch_device
==
"mps"
,
"Flux has a float64 operation which is not supported in MPS."
)
class
FluxInpaintPipelineFastTests
(
unittest
.
TestCase
,
PipelineTesterMixin
):
class
FluxInpaintPipelineFastTests
(
unittest
.
TestCase
,
PipelineTesterMixin
):
pipeline_class
=
FluxInpaintPipeline
pipeline_class
=
FluxInpaintPipeline
params
=
frozenset
([
"prompt"
,
"height"
,
"width"
,
"guidance_scale"
,
"prompt_embeds"
,
"pooled_prompt_embeds"
])
params
=
frozenset
([
"prompt"
,
"height"
,
"width"
,
"guidance_scale"
,
"prompt_embeds"
,
"pooled_prompt_embeds"
])
batch_params
=
frozenset
([
"prompt"
])
batch_params
=
frozenset
([
"prompt"
])
test_xformers_attention
=
False
def
get_dummy_components
(
self
):
def
get_dummy_components
(
self
):
torch
.
manual_seed
(
0
)
torch
.
manual_seed
(
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment