Unverified Commit 9ef46659 authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Improve backbone (#20380)


Co-authored-by: default avatarNiels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 5efd074a
...@@ -440,13 +440,12 @@ class ResNetBackbone(ResNetPreTrainedModel): ...@@ -440,13 +440,12 @@ class ResNetBackbone(ResNetPreTrainedModel):
self.out_features = config.out_features self.out_features = config.out_features
self.out_feature_channels = { out_feature_channels = {}
"stem": config.embedding_size, out_feature_channels["stem"] = config.embedding_size
"stage1": config.hidden_sizes[0], for idx, stage in enumerate(self.stage_names[1:]):
"stage2": config.hidden_sizes[1], out_feature_channels[stage] = config.hidden_sizes[idx]
"stage3": config.hidden_sizes[2],
"stage4": config.hidden_sizes[3], self.out_feature_channels = out_feature_channels
}
# initialize weights and apply final processing # initialize weights and apply final processing
self.post_init() self.post_init()
......
...@@ -55,7 +55,7 @@ class ResNetModelTester: ...@@ -55,7 +55,7 @@ class ResNetModelTester:
hidden_act="relu", hidden_act="relu",
num_labels=3, num_labels=3,
scope=None, scope=None,
out_features=["stage1", "stage2", "stage3", "stage4"], out_features=["stage2", "stage3", "stage4"],
): ):
self.parent = parent self.parent = parent
self.batch_size = batch_size self.batch_size = batch_size
...@@ -121,10 +121,11 @@ class ResNetModelTester: ...@@ -121,10 +121,11 @@ class ResNetModelTester:
# verify hidden states # verify hidden states
self.parent.assertEqual(len(result.feature_maps), len(config.out_features)) self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [3, 10, 8, 8]) self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
# verify channels # verify channels
self.parent.assertListEqual(model.channels, config.hidden_sizes) self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
def prepare_config_and_inputs_for_common(self): def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs() config_and_inputs = self.prepare_config_and_inputs()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment