Unverified Commit 7b100ce5 authored by Sayak Paul, committed by GitHub

[Tests] conditionally check `fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory` (#10669)

* conditionally check if compute capability is met.

* log info.

* fix condition.

* updates

* updates

* updates

* updates
parent c4d4ac21
@@ -149,3 +149,13 @@ def apply_freeu(
         res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"])
     return hidden_states, res_hidden_states
+
+
+def get_torch_cuda_device_capability():
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        compute_capability = torch.cuda.get_device_capability(device)
+        compute_capability = f"{compute_capability[0]}.{compute_capability[1]}"
+        return float(compute_capability)
+    else:
+        return None
@@ -68,6 +68,7 @@ from diffusers.utils.testing_utils import (
     torch_all_close,
     torch_device,
 )
+from diffusers.utils.torch_utils import get_torch_cuda_device_capability
 
 from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -1384,6 +1385,7 @@ class ModelTesterMixin:
     @require_torch_gpu
     def test_layerwise_casting_memory(self):
         MB_TOLERANCE = 0.2
+        LEAST_COMPUTE_CAPABILITY = 8.0
 
         def reset_memory_stats():
             gc.collect()
@@ -1412,9 +1414,11 @@ class ModelTesterMixin:
             torch.float8_e4m3fn, torch.bfloat16
         )
 
+        compute_capability = get_torch_cuda_device_capability()
         self.assertTrue(fp8_e4m3_bf16_memory_footprint < fp8_e4m3_fp32_memory_footprint < fp32_memory_footprint)
-        # NOTE: the following assertion will fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
-        # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes.
-        self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
+        # NOTE: the following assertion would fail on our CI (running Tesla T4) due to bf16 using more memory than fp32.
+        # On other devices, such as DGX (Ampere) and Audace (Ada), the test passes. So, we conditionally check it.
+        if compute_capability and compute_capability >= LEAST_COMPUTE_CAPABILITY:
+            self.assertTrue(fp8_e4m3_bf16_max_memory < fp8_e4m3_fp32_max_memory)
         # On this dummy test case with a small model, sometimes fp8_e4m3_fp32 max memory usage is higher than fp32 by a few
         # bytes. This only happens for some models, so we allow a small tolerance.
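
For reference, a minimal standalone sketch of the gating pattern this commit introduces. The helper mirrors the one added to `diffusers.utils.torch_utils`; the surrounding script (the `capability` variable and the prints) is purely illustrative and not part of the diff:

import torch


def get_torch_cuda_device_capability():
    # Returns the CUDA compute capability as a float (e.g. 7.5 on a Tesla T4, 8.0 on an A100),
    # or None when no CUDA device is available.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(torch.device("cuda"))
        return float(f"{major}.{minor}")
    return None


# Illustrative gate: only run checks that assume Ampere-or-newer memory behavior
# (e.g. bf16 peaking below fp32) when the device reports compute capability >= 8.0.
LEAST_COMPUTE_CAPABILITY = 8.0
capability = get_torch_cuda_device_capability()
if capability is not None and capability >= LEAST_COMPUTE_CAPABILITY:
    print(f"Compute capability {capability}: running the bf16-vs-fp32 peak-memory check.")
else:
    print(f"Compute capability {capability}: skipping the check on this device.")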