"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "312c3d3287aa9936e79a374819ff466e0517c5d1"
Unverified Commit adf1f911 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Tests] fix some fast gpu tests. (#9379)

fix some fast gpu tests.
parent f28a8c25
...@@ -1597,6 +1597,7 @@ def main(args): ...@@ -1597,6 +1597,7 @@ def main(args):
tokenizers=[None, None], tokenizers=[None, None],
text_input_ids_list=[tokens_one, tokens_two], text_input_ids_list=[tokens_one, tokens_two],
max_sequence_length=args.max_sequence_length, max_sequence_length=args.max_sequence_length,
device=accelerator.device,
prompt=prompts, prompt=prompts,
) )
else: else:
...@@ -1606,6 +1607,7 @@ def main(args): ...@@ -1606,6 +1607,7 @@ def main(args):
tokenizers=[None, None], tokenizers=[None, None],
text_input_ids_list=[tokens_one, tokens_two], text_input_ids_list=[tokens_one, tokens_two],
max_sequence_length=args.max_sequence_length, max_sequence_length=args.max_sequence_length,
device=accelerator.device,
prompt=args.instance_prompt, prompt=args.instance_prompt,
) )
......
...@@ -465,6 +465,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig ...@@ -465,6 +465,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
"Please remove the batch dimension and pass it as a 2d torch Tensor" "Please remove the batch dimension and pass it as a 2d torch Tensor"
) )
img_ids = img_ids[0] img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0) ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids) image_rotary_emb = self.pos_embed(ids)
......
...@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin ...@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism() enable_full_determinism()
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxImg2ImgPipeline pipeline_class = FluxImg2ImgPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"]) batch_params = frozenset(["prompt"])
test_xformers_attention = False
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
...@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin ...@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism() enable_full_determinism()
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
pipeline_class = FluxInpaintPipeline pipeline_class = FluxInpaintPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"]) batch_params = frozenset(["prompt"])
test_xformers_attention = False
def get_dummy_components(self): def get_dummy_components(self):
torch.manual_seed(0) torch.manual_seed(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment