Unverified Commit 9a349538 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

Additional Memory clean up for slow tests (#7436)

* update

* update

* update
parent e29f16cf
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import gc
import unittest import unittest
import torch import torch
...@@ -26,6 +27,18 @@ from ..test_pipelines_common import assert_mean_pixel_difference ...@@ -26,6 +27,18 @@ from ..test_pipelines_common import assert_mean_pixel_difference
@nightly @nightly
@require_torch_gpu @require_torch_gpu
class TextToVideoZeroPipelineSlowTests(unittest.TestCase): class TextToVideoZeroPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_full_model(self): def test_full_model(self):
model_id = "runwayml/stable-diffusion-v1-5" model_id = "runwayml/stable-diffusion-v1-5"
pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") pipe = TextToVideoZeroPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import contextlib import contextlib
import gc
import inspect import inspect
import io import io
import re import re
...@@ -381,6 +382,18 @@ class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, unittest.TestCas ...@@ -381,6 +382,18 @@ class TextToVideoZeroSDXLPipelineFastTests(PipelineTesterMixin, unittest.TestCas
@nightly @nightly
@require_torch_gpu @require_torch_gpu
class TextToVideoZeroSDXLPipelineSlowTests(unittest.TestCase): class TextToVideoZeroSDXLPipelineSlowTests(unittest.TestCase):
def setUp(self):
# clean up the VRAM before each test
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_full_model(self): def test_full_model(self):
model_id = "stabilityai/stable-diffusion-xl-base-1.0" model_id = "stabilityai/stable-diffusion-xl-base-1.0"
pipe = TextToVideoZeroSDXLPipeline.from_pretrained( pipe = TextToVideoZeroSDXLPipeline.from_pretrained(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment