"sgl-kernel/git@developer.sourcefind.cn:change/sglang.git" did not exist on "616a3e20df5faf73a7d1581c7016413f64d583e4"
Unverified Commit ee2f2775 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[tests] use parent class for monkey patching to not break other tests (#4088)

* [tests] use parent class for monkey patching to not break other tests

* fix
parent 692b7a90
...@@ -230,10 +230,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest ...@@ -230,10 +230,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
pipe_2 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device) pipe_2 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
pipe_2.unet.set_default_attn_processor() pipe_2.unet.set_default_attn_processor()
def assert_run_mixture(num_steps, split, scheduler_cls): def assert_run_mixture(num_steps, split, scheduler_cls_orig):
inputs = self.get_dummy_inputs(torch_device) inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = num_steps inputs["num_inference_steps"] = num_steps
class scheduler_cls(scheduler_cls_orig):
pass
pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config) pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config) pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
...@@ -287,10 +290,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest ...@@ -287,10 +290,13 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
pipe_3 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device) pipe_3 = StableDiffusionXLImg2ImgPipeline(**components).to(torch_device)
pipe_3.unet.set_default_attn_processor() pipe_3.unet.set_default_attn_processor()
def assert_run_mixture(num_steps, split_1, split_2, scheduler_cls): def assert_run_mixture(num_steps, split_1, split_2, scheduler_cls_orig):
inputs = self.get_dummy_inputs(torch_device) inputs = self.get_dummy_inputs(torch_device)
inputs["num_inference_steps"] = num_steps inputs["num_inference_steps"] = num_steps
class scheduler_cls(scheduler_cls_orig):
pass
pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config) pipe_1.scheduler = scheduler_cls.from_config(pipe_1.scheduler.config)
pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config) pipe_2.scheduler = scheduler_cls.from_config(pipe_2.scheduler.config)
pipe_3.scheduler = scheduler_cls.from_config(pipe_3.scheduler.config) pipe_3.scheduler = scheduler_cls.from_config(pipe_3.scheduler.config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment