"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a42844955f3134a2b2fd8075353c710dcfb7b362"
Unverified Commit 34307bb3 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Fix tests (#14289)

parent 24b30d4d
...@@ -232,7 +232,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -232,7 +232,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
# this can then be incorporated into _prepare_for_class in test_modeling_common.py # this can then be incorporated into _prepare_for_class in test_modeling_common.py
elif model_class.__name__ == "BeitForSemanticSegmentation": elif model_class.__name__ == "BeitForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long() inputs_dict["labels"] = torch.zeros(
[self.model_tester.batch_size, height, width], device=torch_device
).long()
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.train() model.train()
...@@ -259,7 +261,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -259,7 +261,9 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
# this can then be incorporated into _prepare_for_class in test_modeling_common.py # this can then be incorporated into _prepare_for_class in test_modeling_common.py
elif model_class.__name__ == "BeitForSemanticSegmentation": elif model_class.__name__ == "BeitForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long() inputs_dict["labels"] = torch.zeros(
[self.model_tester.batch_size, height, width], device=torch_device
).long()
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.train() model.train()
......
...@@ -318,7 +318,9 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -318,7 +318,9 @@ class SegformerModelTest(ModelTesterMixin, unittest.TestCase):
# this can then be incorporated into _prepare_for_class in test_modeling_common.py # this can then be incorporated into _prepare_for_class in test_modeling_common.py
if model_class.__name__ == "SegformerForSemanticSegmentation": if model_class.__name__ == "SegformerForSemanticSegmentation":
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
inputs_dict["labels"] = torch.zeros([self.model_tester.batch_size, height, width]).long() inputs_dict["labels"] = torch.zeros(
[self.model_tester.batch_size, height, width], device=torch_device
).long()
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
model.train() model.train()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment