Unverified Commit aa5f5d41 authored by Sayak Paul, committed by GitHub


[tests] add tests to check for graph breaks, recompilation, cuda syncs in pipelines during torch.compile() (#11085)

* add tests for better torch.compile coverage.

* fixes

* check for recompilation and graph breaks.

* clear compilation cache.

* change to a modeling-level test.

* allow running compilation tests during nightlies.
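
The pattern these tests rely on is worth sketching up front. Below is a minimal, hedged illustration (with a hypothetical `TinyModel` standing in for a real diffusers model; this is not the actual test code): compiling with `fullgraph=True` makes any graph break raise, and patching `error_on_recompile` makes any recompilation raise, so two clean forward passes prove the model neither breaks the graph nor recompiles.

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):  # hypothetical stand-in for a diffusers model
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, hidden_states):
        return self.proj(hidden_states)

torch._dynamo.reset()  # drop previously compiled artifacts, as the tests do

model = TinyModel().to("cuda")
model = torch.compile(model, fullgraph=True)  # any graph break now raises

inputs = {"hidden_states": torch.randn(1, 16, device="cuda")}
with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
    _ = model(**inputs)  # first call traces and compiles
    _ = model(**inputs)  # must hit the compile cache; a recompile raises
```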
parent bd96a084
@@ -180,6 +180,55 @@ jobs:
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_torch_compile_tests:
    name: PyTorch Compile CUDA tests
    runs-on:
      group: aws-g4dn-2xlarge
    container:
      image: diffusers/diffusers-pytorch-compile-cuda
      options: --gpus 0 --shm-size "16gb" --ipc host
    steps:
      - name: Checkout diffusers
        uses: actions/checkout@v3
        with:
          fetch-depth: 2
      - name: NVIDIA-SMI
        run: |
          nvidia-smi
      - name: Install dependencies
        run: |
          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
          python -m uv pip install -e .[quality,test,training]
      - name: Environment
        run: |
          python utils/print_env.py
      - name: Run torch compile tests on GPU
        env:
          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
          RUN_COMPILE: yes
        run: |
          python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile -s -v -k "compile" --make-reports=tests_torch_compile_cuda tests/
      - name: Failure short reports
        if: ${{ failure() }}
        run: cat reports/tests_torch_compile_cuda_failures_short.txt
      - name: Test suite reports artifacts
        if: ${{ always() }}
        uses: actions/upload-artifact@v4
        with:
          name: torch_compile_test_reports
          path: reports
      - name: Generate Report and Notify Channel
        if: always()
        run: |
          pip install slack_sdk tabulate
          python utils/log_reports.py >> $GITHUB_STEP_SUMMARY

  run_big_gpu_torch_tests:
    name: Torch tests on big GPU
    strategy:
...
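
The `RUN_COMPILE: yes` environment variable is what the `@is_torch_compile` decorator (used by the test mixin further down) keys off, keeping the compile suite opt-in outside this job. A minimal sketch of how such a gate is typically built; the real helper lives in `diffusers.utils.testing_utils` and may differ in detail:

```python
import os
import unittest

def is_torch_compile(test_case):
    # Skip unless RUN_COMPILE is set to a truthy value. This mirrors the
    # usual opt-in CI gate; the actual diffusers decorator may differ.
    run_compile = os.getenv("RUN_COMPILE", "0").lower() in ("1", "yes", "true")
    return unittest.skipUnless(run_compile, "test requires RUN_COMPILE=yes")(test_case)
```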
@@ -335,7 +335,7 @@ jobs:
      - name: Environment
        run: |
          python utils/print_env.py
-     - name: Run example tests on GPU
+     - name: Run torch compile tests on GPU
        env:
          HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
          RUN_COMPILE: yes
...
@@ -1714,6 +1714,37 @@ class ModelPushToHubTester(unittest.TestCase):
        delete_repo(self.repo_id, token=TOKEN)


class TorchCompileTesterMixin:
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
        torch._dynamo.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test in case of CUDA runtime errors
        super().tearDown()
        torch._dynamo.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    @require_torch_gpu
    @require_torch_2
    @is_torch_compile
    @slow
    def test_torch_compile_recompilation_and_graph_break(self):
        torch._dynamo.reset()
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict).to(torch_device)
        model = torch.compile(model, fullgraph=True)

        with torch._dynamo.config.patch(error_on_recompile=True), torch.no_grad():
            _ = model(**inputs_dict)
            _ = model(**inputs_dict)


@slow
@require_torch_2
@require_torch_accelerator
...
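
For context on what `error_on_recompile` catches in the test above: dynamo guards each compiled graph on properties of its inputs (shapes, dtypes, and so on), and a failed guard normally triggers a silent recompile. A small standalone illustration of that mechanism, assuming stock PyTorch 2.x behavior (this is not part of the diff):

```python
import torch
from torch._dynamo.exc import RecompileError

@torch.compile(fullgraph=True)
def scale(x):
    return x * 2

with torch._dynamo.config.patch(error_on_recompile=True):
    scale(torch.randn(4))  # compiles and caches a graph for this shape
    scale(torch.randn(4))  # same guards hold: cache hit, no error
    try:
        scale(torch.randn(8))  # shape guard fails, forcing a recompile
    except RecompileError:
        print("recompilation detected")
```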
@@ -22,7 +22,7 @@ from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

-from ..test_modeling_common import ModelTesterMixin
+from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin

enable_full_determinism()

@@ -78,7 +78,7 @@ def create_flux_ip_adapter_state_dict(model):
    return ip_state_dict


-class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
+class FluxTransformerTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase):
    model_class = FluxTransformer2DModel
    main_input_name = "hidden_states"

    # We override the items here because the transformer under consideration is small.
...
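
Adopting the mixin is the entire integration: any model test class that already supplies `prepare_init_args_and_inputs_for_common()` via `ModelTesterMixin` can add `TorchCompileTesterMixin` to its bases and inherit the recompilation/graph-break test unchanged, which is exactly what the Flux transformer tests do here.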
@@ -1111,12 +1111,14 @@ class PipelineTesterMixin:
    def setUp(self):
        # clean up the VRAM before each test
        super().setUp()
+       torch._dynamo.reset()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        # clean up the VRAM after each test in case of CUDA runtime errors
        super().tearDown()
+       torch._dynamo.reset()
        gc.collect()
        backend_empty_cache(torch_device)
...
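
The commit title also mentions checking for CUDA syncs; that test does not appear in the hunks shown here, but PyTorch ships a sync debug mode that such a check would plausibly build on. A hedged sketch of that mechanism (an assumption, not the PR's code):

```python
import torch

# Make any implicit host/device synchronization raise instead of
# silently stalling the GPU (accepted modes: "default", "warn", "error").
torch.cuda.set_sync_debug_mode("error")
try:
    x = torch.randn(8, device="cuda")
    y = x * 2  # pure GPU work: no sync, so this is fine
    # float(y.sum()) would force a device sync and raise here
finally:
    torch.cuda.set_sync_debug_mode("default")
```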