Unverified Commit fbddf028 authored by Sayak Paul, committed by GitHub

[tests] properly skip tests instead of `return` (#11771)

model test updates
parent f20b83a0
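Replacing a bare `return` with `pytest.skip()` makes unsupported configurations show up as skipped tests (with a reason) instead of silently passing. As a minimal illustrative sketch, not taken from this diff, the contrast looks like this; the class name and the `_supports_feature` flag below are hypothetical stand-ins for attributes such as `_supports_gradient_checkpointing`:

```python
import pytest


class TestDummyModel:
    # Hypothetical capability flag, standing in for attributes like
    # `_supports_gradient_checkpointing` used in the real test suite.
    _supports_feature = False

    def test_with_return(self):
        if not self._supports_feature:
            return  # test exits early but is still reported as "passed"
        raise AssertionError("never reached")

    def test_with_skip(self):
        if not self._supports_feature:
            pytest.skip("Feature is not supported.")  # reported as "skipped"
        raise AssertionError("never reached")
```

Running `pytest` on such a file reports one passed and one skipped test, so the skip and its reason are visible in the test summary rather than hidden.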
@@ -30,6 +30,7 @@ from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
+import pytest
 import requests_mock
 import safetensors.torch
 import torch
@@ -938,8 +939,9 @@ class ModelTesterMixin:
     @require_torch_accelerator_with_training
     def test_enable_disable_gradient_checkpointing(self):
+        # Skip test if model does not support gradient checkpointing
         if not self.model_class._supports_gradient_checkpointing:
-            return  # Skip test if model does not support gradient checkpointing
+            pytest.skip("Gradient checkpointing is not supported.")
         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
@@ -957,8 +959,9 @@ class ModelTesterMixin:
     @require_torch_accelerator_with_training
     def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
+        # Skip test if model does not support gradient checkpointing
         if not self.model_class._supports_gradient_checkpointing:
-            return  # Skip test if model does not support gradient checkpointing
+            pytest.skip("Gradient checkpointing is not supported.")
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1015,8 +1018,9 @@ class ModelTesterMixin:
     def test_gradient_checkpointing_is_applied(
         self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
     ):
+        # Skip test if model does not support gradient checkpointing
         if not self.model_class._supports_gradient_checkpointing:
-            return  # Skip test if model does not support gradient checkpointing
+            pytest.skip("Gradient checkpointing is not supported.")
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1073,7 +1077,7 @@ class ModelTesterMixin:
         model = self.model_class(**init_dict).to(torch_device)
         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
         torch.manual_seed(0)
         output_no_lora = model(**inputs_dict, return_dict=False)[0]
@@ -1128,7 +1132,7 @@ class ModelTesterMixin:
         model = self.model_class(**init_dict).to(torch_device)
         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
         denoiser_lora_config = LoraConfig(
             r=4,
@@ -1159,7 +1163,7 @@ class ModelTesterMixin:
         model = self.model_class(**init_dict).to(torch_device)
         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
         denoiser_lora_config = LoraConfig(
             r=rank,
@@ -1196,7 +1200,7 @@ class ModelTesterMixin:
         model = self.model_class(**init_dict).to(torch_device)
         if not issubclass(model.__class__, PeftAdapterMixin):
-            return
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
         denoiser_lora_config = LoraConfig(
             r=4,
@@ -1233,10 +1237,10 @@ class ModelTesterMixin:
     @require_torch_accelerator
     def test_cpu_offload(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return
         model = model.to(torch_device)
@@ -1263,10 +1267,10 @@ class ModelTesterMixin:
     @require_torch_accelerator
     def test_disk_offload_without_safetensors(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return
         model = model.to(torch_device)
@@ -1296,10 +1300,10 @@ class ModelTesterMixin:
     @require_torch_accelerator
     def test_disk_offload_with_safetensors(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return
         model = model.to(torch_device)
@@ -1324,10 +1328,10 @@ class ModelTesterMixin:
     @require_torch_multi_accelerator
     def test_model_parallelism(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return
         model = model.to(torch_device)
@@ -1426,10 +1430,10 @@ class ModelTesterMixin:
     @require_torch_accelerator
     def test_sharded_checkpoints_device_map(self):
+        if self.model_class._no_split_modules is None:
+            pytest.skip("Test not supported for this model as `_no_split_modules` is not set.")
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
-        if model._no_split_modules is None:
-            return
         model = model.to(torch_device)
         torch.manual_seed(0)
@@ -1497,7 +1501,7 @@ class ModelTesterMixin:
     def test_layerwise_casting_training(self):
         def test_fn(storage_dtype, compute_dtype):
             if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
-                return
+                pytest.skip("Skipping test because CPU doesn't go well with bfloat16.")
             init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
             model = self.model_class(**init_dict)
@@ -1617,6 +1621,9 @@ class ModelTesterMixin:
     @parameterized.expand([False, True])
     @require_torch_accelerator
     def test_group_offloading(self, record_stream):
+        if not self.model_class._supports_group_offloading:
+            pytest.skip("Model does not support group offloading.")
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         torch.manual_seed(0)
@@ -1633,8 +1640,6 @@ class ModelTesterMixin:
             return model(**inputs_dict)[0]
         model = self.model_class(**init_dict)
-        if not getattr(model, "_supports_group_offloading", True):
-            return
         model.to(torch_device)
         output_without_group_offloading = run_forward(model)
@@ -1670,13 +1675,13 @@ class ModelTesterMixin:
     @require_torch_accelerator
     @torch.no_grad()
     def test_group_offloading_with_layerwise_casting(self, record_stream, offload_type):
+        if not self.model_class._supports_group_offloading:
+            pytest.skip("Model does not support group offloading.")
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
-        if not getattr(model, "_supports_group_offloading", True):
-            return
         model.to(torch_device)
         model.eval()
         _ = model(**inputs_dict)[0]
@@ -1698,13 +1703,13 @@ class ModelTesterMixin:
     @require_torch_accelerator
     @torch.no_grad()
     def test_group_offloading_with_disk(self, record_stream, offload_type):
+        if not self.model_class._supports_group_offloading:
+            pytest.skip("Model does not support group offloading.")
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
-        if not getattr(model, "_supports_group_offloading", True):
-            return
         torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict)
...
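One detail worth noting for the nested cases above (for example `test_fn` inside `test_layerwise_casting_training`): `pytest.skip()` works by raising pytest's skip exception, so it ends the enclosing test even when called from a helper function, whereas a bare `return` only exits the helper. A minimal sketch with hypothetical names:

```python
import pytest


def test_skip_from_helper():
    def helper(supported):
        if not supported:
            # pytest.skip() raises pytest's Skipped exception, which propagates
            # out of this helper and marks the whole test as skipped.
            pytest.skip("Configuration is not supported.")
        return 42

    helper(supported=False)
    raise AssertionError("never reached")  # the test is skipped before this line
```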