Unverified Commit c80eda9d authored by Sayak Paul, committed by GitHub

[Tests] Test layerwise casting with training (#10765)

* add a test to check if we can train with layerwise casting.

* updates

* updates

* style
parent 7fb481f8
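
For context, layerwise casting stores a model's parameters in a low-precision dtype and upcasts them to the compute dtype around each layer's forward pass. A minimal usage sketch of the API the new test exercises (the checkpoint path is a placeholder, not from this commit):

    import torch
    from diffusers import UNet2DConditionModel

    # Placeholder checkpoint; any diffusers ModelMixin subclass exposes the same API.
    model = UNet2DConditionModel.from_pretrained("org/model", subfolder="unet")
    model = model.to("cuda", dtype=torch.bfloat16)

    # Keep weights in float8 storage; upcast to bfloat16 for each forward pass.
    model.enable_layerwise_casting(
        storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
    )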
@@ -114,6 +114,12 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_set_attn_processor_for_determinism(self):
         return
 
+    @unittest.skip(
+        "Test not supported because of 'weight_norm_fwd_first_dim_kernel' not implemented for 'Float8_e4m3fn'"
+    )
+    def test_layerwise_casting_training(self):
+        return super().test_layerwise_casting_training()
+
     @unittest.skip(
         "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
         "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"
......
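
The skip above stems from torch.nn.utils.weight_norm, which reparameterizes a layer's weight into weight_g/weight_v and recomputes weight in its own pre-forward hook, so the norm kernel runs on the float8 storage dtype before the casting hook can upcast it. A minimal sketch of the failure mode (an illustration inferred from the skip message, not part of this commit):

    import torch
    import torch.nn as nn

    conv = torch.nn.utils.weight_norm(nn.Conv1d(4, 4, kernel_size=3))
    # Emulate layerwise casting storing the reparameterized weights in float8.
    conv.weight_g.data = conv.weight_g.data.to(torch.float8_e4m3fn)
    conv.weight_v.data = conv.weight_v.data.to(torch.float8_e4m3fn)
    # conv(torch.randn(1, 4, 8))  # expected to fail: 'weight_norm_fwd_first_dim_kernel'
    #                             # not implemented for 'Float8_e4m3fn'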
@@ -1338,6 +1338,36 @@ class ModelTesterMixin:
             # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
             assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)
 
+    def test_layerwise_casting_training(self):
+        def test_fn(storage_dtype, compute_dtype):
+            if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
+                return
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+            model = self.model_class(**init_dict)
+            model = model.to(torch_device, dtype=compute_dtype)
+            model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
+            model.train()
+
+            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
+            with torch.amp.autocast(device_type=torch.device(torch_device).type):
+                output = model(**inputs_dict)
+
+                if isinstance(output, dict):
+                    output = output.to_tuple()[0]
+
+                input_tensor = inputs_dict[self.main_input_name]
+                noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
+                noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
+                loss = torch.nn.functional.mse_loss(output, noise)
+
+            loss.backward()
+
+        test_fn(torch.float16, torch.float32)
+        test_fn(torch.float8_e4m3fn, torch.float32)
+        test_fn(torch.float8_e5m2, torch.float32)
+        test_fn(torch.float8_e4m3fn, torch.bfloat16)
+
     def test_layerwise_casting_inference(self):
         from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
......
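
The test works because the casting hooks leave autograd intact: weights rest in storage_dtype between calls and are upcast to compute_dtype just before each layer runs, so the loss.backward() at the end can still populate gradients. A simplified sketch of the hook pair (an assumption about the mechanism; the actual implementation lives in diffusers.hooks.layerwise_casting):

    import torch
    import torch.nn as nn

    def attach_casting_hooks(layer: nn.Module, storage_dtype, compute_dtype):
        layer.to(storage_dtype)  # parameters rest in the low-precision dtype

        def pre_forward(module, args):
            module.to(compute_dtype)  # upcast just in time for the forward pass
            return args

        def post_forward(module, args, output):
            module.to(storage_dtype)  # downcast back to the storage dtype
            return output

        layer.register_forward_pre_hook(pre_forward)
        layer.register_forward_hook(post_forward)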
@@ -60,6 +60,10 @@ class UNet1DModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_training(self):
         pass
 
+    @unittest.skip("Test not supported.")
+    def test_layerwise_casting_training(self):
+        pass
+
     def test_determinism(self):
         super().test_determinism()
@@ -239,6 +243,10 @@ class UNetRLModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     def test_training(self):
         pass
 
+    @unittest.skip("Test not supported.")
+    def test_layerwise_casting_training(self):
+        pass
+
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
             "in_channels": 14,
......