Unverified Commit c4129196 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

🚨🚨🚨 Remove softmax for EfficientNetForImageClassification 🚨🚨🚨 (#25501)

* Remove softmax for EfficientNet

* Update integration test values

* Fix up
parent 06a1d75b
......@@ -588,7 +588,6 @@ class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
# Classifier head
self.dropout = nn.Dropout(p=config.dropout_rate)
self.classifier = nn.Linear(config.hidden_dim, self.num_labels) if self.num_labels > 0 else nn.Identity()
self.classifier_act = nn.Softmax(dim=1)
# Initialize weights and apply final processing
self.post_init()
......@@ -620,7 +619,6 @@ class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
pooled_output = outputs.pooler_output if return_dict else outputs[1]
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
logits = self.classifier_act(logits)
loss = None
if labels is not None:
......
......@@ -265,5 +265,5 @@ class EfficientNetModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([0.0001, 0.0002, 0.0002]).to(torch_device)
expected_slice = torch.tensor([-0.2962, 0.4487, 0.4499]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment