Unverified Commit 66fd3ec7 authored by Anton Lozhkov, committed by GitHub

[CI] try to fix GPU OOMs between tests and excessive tqdm logging (#323)

* Fix tqdm and OOM

* tqdm auto

* tqdm is still spamming, try to disable it altogether

* rather just set the pipe config, to keep the global tqdm clean

* style
parent 3a536ac8
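
The change is two-fold: every pipeline built in the tests now configures its own tqdm bar via `set_progress_bar_config(disable=None)`, and both test classes gain a `tearDown` that frees VRAM between tests. A minimal sketch of the combined pattern, for illustration only (the test class name and the smoke test itself are hypothetical; the checkpoint id and call signature mirror the tests in this diff):

```python
import gc
import unittest

import torch
from diffusers import DDIMPipeline


class PipelineSmokeTest(unittest.TestCase):
    # Hypothetical test class illustrating the pattern this commit applies.
    def tearDown(self):
        # Drop Python-side references first, then ask the CUDA caching
        # allocator to release unused blocks so the next test starts clean.
        super().tearDown()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def test_ddim_smoke(self):
        pipe = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")
        # tqdm's "auto" mode: draw the bar on a TTY, stay quiet in CI logs.
        pipe.set_progress_bar_config(disable=None)
        image = pipe(num_inference_steps=2, output_type="numpy")["sample"]
```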
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import random
 import tempfile
 import unittest
@@ -77,6 +78,12 @@ def test_progress_bar(capsys):
 class PipelineFastTests(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
     @property
     def dummy_image(self):
         batch_size = 1
@@ -186,6 +193,7 @@ class PipelineFastTests(unittest.TestCase):
         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
         ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
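
For reference, tqdm treats `disable=None` as auto-detection: the bar is drawn when stderr is an interactive terminal and suppressed when it is not (the CI case), which is why the tests pass `disable=None` rather than `True`. A standalone check, assuming only tqdm is installed:

```python
import sys

from tqdm.auto import tqdm

# disable=None: render the bar on an interactive terminal, suppress it
# when stderr is redirected (e.g. captured CI logs).
for _ in tqdm(range(100), disable=None):
    pass

print("stderr is a TTY:", sys.stderr.isatty())
```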
@@ -204,6 +212,7 @@ class PipelineFastTests(unittest.TestCase):
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
         pndm.to(torch_device)
+        pndm.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         image = pndm(generator=generator, num_inference_steps=20, output_type="numpy")["sample"]
@@ -222,6 +231,7 @@ class PipelineFastTests(unittest.TestCase):
         ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
         ldm.to(torch_device)
+        ldm.set_progress_bar_config(disable=None)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -261,6 +271,7 @@ class PipelineFastTests(unittest.TestCase):
             feature_extractor=self.dummy_extractor,
         )
         sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.Generator(device=device).manual_seed(0)
@@ -293,6 +304,7 @@ class PipelineFastTests(unittest.TestCase):
             feature_extractor=self.dummy_extractor,
         )
         sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.Generator(device=device).manual_seed(0)
@@ -325,6 +337,7 @@ class PipelineFastTests(unittest.TestCase):
             feature_extractor=self.dummy_extractor,
         )
         sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.Generator(device=device).manual_seed(0)
@@ -344,6 +357,7 @@ class PipelineFastTests(unittest.TestCase):
         sde_ve = ScoreSdeVePipeline(unet=unet, scheduler=scheduler)
         sde_ve.to(torch_device)
+        sde_ve.set_progress_bar_config(disable=None)
         torch.manual_seed(0)
         image = sde_ve(num_inference_steps=2, output_type="numpy")["sample"]
@@ -362,6 +376,7 @@ class PipelineFastTests(unittest.TestCase):
         ldm = LDMPipeline(unet=unet, vqvae=vae, scheduler=scheduler)
         ldm.to(torch_device)
+        ldm.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
@@ -378,6 +393,7 @@ class PipelineFastTests(unittest.TestCase):
         pipe = KarrasVePipeline(unet=unet, scheduler=scheduler)
         pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         image = pipe(num_inference_steps=2, generator=generator, output_type="numpy")["sample"]
@@ -408,6 +424,7 @@ class PipelineFastTests(unittest.TestCase):
             feature_extractor=self.dummy_extractor,
         )
         sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.Generator(device=device).manual_seed(0)
@@ -451,6 +468,7 @@ class PipelineFastTests(unittest.TestCase):
             feature_extractor=self.dummy_extractor,
         )
         sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.Generator(device=device).manual_seed(0)
@@ -474,6 +492,12 @@ class PipelineFastTests(unittest.TestCase):
 class PipelineTesterMixin(unittest.TestCase):
+    def tearDown(self):
+        # clean up the VRAM after each test
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
         model = UNet2DModel(
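
Why both calls in the new `tearDown` matter: `torch.cuda.empty_cache()` can only return cached blocks that are no longer referenced, so running `gc.collect()` first, destroying the pipelines a finished test left behind, maximizes what is actually handed back to the driver. A hypothetical snippet for eyeballing the effect (`report_vram` is not part of the diff):

```python
import gc

import torch


def report_vram(tag: str) -> None:
    # Hypothetical helper: print caching-allocator stats so the effect of
    # the gc.collect() + empty_cache() pair used in tearDown is visible.
    if not torch.cuda.is_available():
        return
    allocated = torch.cuda.memory_allocated() / 2**20
    reserved = torch.cuda.memory_reserved() / 2**20
    print(f"{tag}: allocated={allocated:.1f} MiB, reserved={reserved:.1f} MiB")


report_vram("before cleanup")
gc.collect()
torch.cuda.empty_cache()
report_vram("after cleanup")
```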
@@ -489,6 +513,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDPMPipeline(model, schedular)
         ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
         with tempfile.TemporaryDirectory() as tmpdirname:
             ddpm.save_pretrained(tmpdirname)
@@ -511,8 +536,10 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDPMPipeline.from_pretrained(model_path, scheduler=scheduler)
         ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
         ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
         ddpm_from_hub.to(torch_device)
+        ddpm_from_hub.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
@@ -532,9 +559,11 @@ class PipelineTesterMixin(unittest.TestCase):
         unet = UNet2DModel.from_pretrained(model_path)
         ddpm_from_hub_custom_model = DiffusionPipeline.from_pretrained(model_path, unet=unet, scheduler=scheduler)
         ddpm_from_hub_custom_model.to(torch_device)
+        ddpm_from_hub_custom_model.set_progress_bar_config(disable=None)
         ddpm_from_hub = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler)
         ddpm_from_hub.to(torch_device)
+        ddpm_from_hub.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
@@ -550,6 +579,7 @@ class PipelineTesterMixin(unittest.TestCase):
         pipe = DDIMPipeline.from_pretrained(model_path)
         pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         images = pipe(generator=generator, output_type="numpy")["sample"]
@@ -576,6 +606,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDPMPipeline(unet=unet, scheduler=scheduler)
         ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -595,6 +626,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDIMPipeline(unet=unet, scheduler=scheduler)
         ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -614,6 +646,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ddim = DDIMPipeline(unet=unet, scheduler=scheduler)
         ddim.to(torch_device)
+        ddim.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
@@ -633,6 +666,7 @@ class PipelineTesterMixin(unittest.TestCase):
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
         pndm.to(torch_device)
+        pndm.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         image = pndm(generator=generator, output_type="numpy")["sample"]
@@ -646,6 +680,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_ldm_text2img(self):
         ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
         ldm.to(torch_device)
+        ldm.set_progress_bar_config(disable=None)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -663,6 +698,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_ldm_text2img_fast(self):
         ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
         ldm.to(torch_device)
+        ldm.set_progress_bar_config(disable=None)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -680,6 +716,7 @@ class PipelineTesterMixin(unittest.TestCase):
         # make sure here that pndm scheduler skips prk
         sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", use_auth_token=True)
         sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.Generator(device=torch_device).manual_seed(0)
@@ -701,6 +738,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_stable_diffusion_fast_ddim(self):
         sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", use_auth_token=True)
         sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
         scheduler = DDIMScheduler(
             beta_start=0.00085,
@@ -733,6 +771,7 @@ class PipelineTesterMixin(unittest.TestCase):
         sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
         sde_ve.to(torch_device)
+        sde_ve.set_progress_bar_config(disable=None)
         torch.manual_seed(0)
         image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
@@ -748,6 +787,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_ldm_uncond(self):
         ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
         ldm.to(torch_device)
+        ldm.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
@@ -768,8 +808,10 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
         ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
         ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
         ddim.to(torch_device)
+        ddim.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]
@@ -790,9 +832,11 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm = DDPMPipeline(unet=unet, scheduler=ddpm_scheduler)
         ddpm.to(torch_device)
+        ddpm.set_progress_bar_config(disable=None)
         ddim = DDIMPipeline(unet=unet, scheduler=ddim_scheduler)
         ddim.to(torch_device)
+        ddim.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]
@@ -813,6 +857,7 @@ class PipelineTesterMixin(unittest.TestCase):
         pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
         pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
         image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]
@@ -827,6 +872,7 @@ class PipelineTesterMixin(unittest.TestCase):
     def test_lms_stable_diffusion_pipeline(self):
        model_id = "CompVis/stable-diffusion-v1-1"
         pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
         scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
         pipe.scheduler = scheduler
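
The hunk above also shows the scheduler-swap pattern used by several slow tests: load the pipeline with its default scheduler, then assign a different one built from the scheduler config stored in the same repo. In isolation, mirroring the calls in the diff (`use_auth_token=True` assumes a logged-in Hugging Face token):

```python
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=True)
# Replace the default scheduler with LMS, built from the repo's own
# scheduler config; the pipeline treats the scheduler as a swappable part.
pipe.scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler", use_auth_token=True)
pipe.set_progress_bar_config(disable=None)
```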
@@ -852,6 +898,7 @@ class PipelineTesterMixin(unittest.TestCase):
         model_id = "CompVis/stable-diffusion-v1-4"
         pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id, use_auth_token=True)
         pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
         prompt = "A fantasy landscape, trending on artstation"
@@ -878,6 +925,7 @@ class PipelineTesterMixin(unittest.TestCase):
         model_id = "CompVis/stable-diffusion-v1-4"
         pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, use_auth_token=True)
         pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
         prompt = "A red cat sitting on a parking bench"
...