Unverified Commit 494c2a8c authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Clean up semantic segmentation tests (#16801)


Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 989a15d1
...@@ -244,13 +244,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -244,13 +244,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
# we don't test BeitForMaskedImageModeling # we don't test BeitForMaskedImageModeling
if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]: if model_class in [*get_values(MODEL_MAPPING), BeitForMaskedImageModeling]:
continue continue
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
elif model_class.__name__ == "BeitForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros(
[self.model_tester.batch_size, height, width], device=torch_device
).long()
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.train() model.train()
......
...@@ -316,13 +316,7 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -316,13 +316,7 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING): if model_class in get_values(MODEL_MAPPING):
continue continue
# TODO: remove the following 3 lines once we have a MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING
# this can then be incorporated into _prepare_for_class in test_modeling_common.py
if model_class.__name__ == "SegformerForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros(
[self.model_tester.batch_size, height, width], device=torch_device
).long()
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.train() model.train()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment