Unverified commit 51843fd7, authored by Patrick von Platen, committed by GitHub
Browse files

Refactor full determinism (#3485)

* up

* fix more

* Apply suggestions from code review

* fix more

* fix more

* Check it

* Remove 16:8

* fix more

* fix more

* fix more

* up

* up

* Test only stable diffusion

* Test only two files

* up

* Try out spinning up processes that can be killed

* up

* Apply suggestions from code review

* up

* up
parent 49ad61c2
import contextlib
import copy
import os
# NOTE: `random` must stay bound to the module. A following `from random import random`
# would rebind the name to the function and make `import random` dead, so that any
# `random.seed(...)` call in this module would raise AttributeError.
import random
from typing import Any, Dict, Iterable, Optional, Union

import numpy as np
......@@ -14,26 +13,6 @@ if is_transformers_available():
import transformers
def enable_full_determinism(seed: int):
    """
    Helper function for reproducible behavior during distributed training.

    Calls `set_seed(seed)` first and then switches PyTorch to its deterministic code paths. See
    https://pytorch.org/docs/stable/notes/randomness.html for the PyTorch notes on reproducibility.

    Args:
        seed (`int`): The seed that is forwarded to `set_seed`.
    """
    # set seed first
    set_seed(seed)

    # Enable PyTorch deterministic mode. This potentially requires either the environment
    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
    # depending on the CUDA version, so we set them both here
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    torch.use_deterministic_algorithms(True)

    # Enable CUDNN deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def set_seed(seed: int):
"""
Args:
......
......@@ -514,3 +514,21 @@ class CaptureLogger:
def __repr__(self):
    """Return the text held in `self.out`, prefixed with `captured: ` and followed by a newline."""
    return f"captured: {self.out}\n"
def enable_full_determinism():
    """
    Helper function for reproducible behavior during distributed training.

    Sets the CUDA-related environment variables and the PyTorch / cuDNN / matmul flags below so that
    deterministic code paths are used. It does not seed any random number generator. See
    https://pytorch.org/docs/stable/notes/randomness.html for the PyTorch notes on reproducibility.
    """
    # Enable PyTorch deterministic mode. This potentially requires either the environment
    # variable 'CUDA_LAUNCH_BLOCKING' or 'CUBLAS_WORKSPACE_CONFIG' to be set,
    # depending on the CUDA version, so we set them both here
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    torch.use_deterministic_algorithms(True)

    # Enable CUDNN deterministic mode
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Disable TF32 matmuls; the test modules used to set this flag themselves at import time
    torch.backends.cuda.matmul.allow_tf32 = False
......@@ -27,9 +27,6 @@ from diffusers.models.transformer_2d import Transformer2DModel
from diffusers.utils import torch_device
torch.backends.cuda.matmul.allow_tf32 = False
class EmbeddingsTests(unittest.TestCase):
def test_timestep_embeddings(self):
embedding_dim = 256
......
......@@ -23,9 +23,6 @@ from diffusers.utils import floats_tensor, slow, torch_device
from .test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class UNet1DModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet1DModel
......
......@@ -21,13 +21,14 @@ import torch
from diffusers import UNet2DModel
from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device
from diffusers.utils.testing_utils import enable_full_determinism
from .test_modeling_common import ModelTesterMixin
logger = logging.get_logger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
class Unet2DModelTests(ModelTesterMixin, unittest.TestCase):
......
......@@ -33,13 +33,14 @@ from diffusers.utils import (
torch_device,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism
from .test_modeling_common import ModelTesterMixin
logger = logging.get_logger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
def create_lora_layers(model, mock_weights: bool = True):
......
......@@ -29,13 +29,14 @@ from diffusers.utils import (
torch_device,
)
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism
from .test_modeling_common import ModelTesterMixin
enable_full_determinism()
logger = logging.get_logger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
def create_lora_layers(model, mock_weights: bool = True):
......
......@@ -22,12 +22,12 @@ from parameterized import parameterized
from diffusers import AutoencoderKL
from diffusers.utils import floats_tensor, load_hf_numpy, require_torch_gpu, slow, torch_all_close, torch_device
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import enable_full_determinism
from .test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
......
......@@ -19,12 +19,12 @@ import torch
from diffusers import VQModel
from diffusers.utils import floats_tensor, torch_device
from diffusers.utils.testing_utils import enable_full_determinism
from .test_modeling_common import ModelTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
class VQModelTests(ModelTesterMixin, unittest.TestCase):
......
......@@ -20,11 +20,10 @@ import torch
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel
from diffusers.utils.testing_utils import skip_mps, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, skip_mps, torch_device
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
class EMAModelTests(unittest.TestCase):
......
......@@ -26,14 +26,13 @@ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
RobertaSeriesModelWithTransformation,
)
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
class AltDiffusionPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
......
......@@ -33,11 +33,10 @@ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
RobertaSeriesModelWithTransformation,
)
from diffusers.utils import floats_tensor, load_image, load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
class AltDiffusionImg2ImgPipelineFastTests(unittest.TestCase):
......
......@@ -30,11 +30,10 @@ from diffusers import (
UNet2DModel,
)
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
class PipelineFastTests(unittest.TestCase):
......
......@@ -37,13 +37,13 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
class AudioLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
......
......@@ -32,7 +32,7 @@ from diffusers import (
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.utils import load_image, load_numpy, randn_tensor, slow, torch_device
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import require_torch_gpu
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import (
TEXT_TO_IMAGE_BATCH_PARAMS,
......@@ -41,8 +41,7 @@ from ..pipeline_params import (
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
class ControlNetPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
......
......@@ -35,7 +35,7 @@ from diffusers import (
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.utils import floats_tensor, load_image, load_numpy, randn_tensor, slow, torch_device
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import require_torch_gpu
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
......@@ -44,8 +44,7 @@ from ..pipeline_params import (
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
class ControlNetImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
......
......@@ -35,7 +35,7 @@ from diffusers import (
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.utils import floats_tensor, load_image, load_numpy, randn_tensor, slow, torch_device
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import require_torch_gpu
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
from ..pipeline_params import (
TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS,
......@@ -44,8 +44,7 @@ from ..pipeline_params import (
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
torch.use_deterministic_algorithms(True)
enable_full_determinism()
class ControlNetInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
......
......@@ -21,13 +21,13 @@ import torch
from diffusers import DanceDiffusionPipeline, IPNDMScheduler, UNet1DModel
from diffusers.utils import slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu, skip_mps
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, skip_mps
from ..pipeline_params import UNCONDITIONAL_AUDIO_GENERATION_BATCH_PARAMS, UNCONDITIONAL_AUDIO_GENERATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
enable_full_determinism()
class DanceDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
......
......@@ -19,13 +19,13 @@ import numpy as np
import torch
from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
enable_full_determinism()
class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
......
......@@ -19,10 +19,10 @@ import numpy as np
import torch
from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
from diffusers.utils.testing_utils import require_torch_gpu, slow, torch_device
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
torch.backends.cuda.matmul.allow_tf32 = False
enable_full_determinism()
class DDPMPipelineFastTests(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment