"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ea118ae2e1ef62e909626f1b5a4487f5d1cb4a55"
Unverified commit c4129196, authored by amyeroberts and committed by GitHub
Browse files

🚨🚨🚨 Remove softmax for EfficientNetForImageClassification 🚨🚨🚨 (#25501)

* Remove softmax for EfficientNet

* Update integration test values

* Fix up
parent 06a1d75b
...@@ -588,7 +588,6 @@ class EfficientNetForImageClassification(EfficientNetPreTrainedModel): ...@@ -588,7 +588,6 @@ class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
# Classifier head # Classifier head
self.dropout = nn.Dropout(p=config.dropout_rate) self.dropout = nn.Dropout(p=config.dropout_rate)
self.classifier = nn.Linear(config.hidden_dim, self.num_labels) if self.num_labels > 0 else nn.Identity() self.classifier = nn.Linear(config.hidden_dim, self.num_labels) if self.num_labels > 0 else nn.Identity()
self.classifier_act = nn.Softmax(dim=1)
# Initialize weights and apply final processing # Initialize weights and apply final processing
self.post_init() self.post_init()
...@@ -620,7 +619,6 @@ class EfficientNetForImageClassification(EfficientNetPreTrainedModel): ...@@ -620,7 +619,6 @@ class EfficientNetForImageClassification(EfficientNetPreTrainedModel):
pooled_output = outputs.pooler_output if return_dict else outputs[1] pooled_output = outputs.pooler_output if return_dict else outputs[1]
pooled_output = self.dropout(pooled_output) pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output) logits = self.classifier(pooled_output)
logits = self.classifier_act(logits)
loss = None loss = None
if labels is not None: if labels is not None:
......
...@@ -265,5 +265,5 @@ class EfficientNetModelIntegrationTest(unittest.TestCase): ...@@ -265,5 +265,5 @@ class EfficientNetModelIntegrationTest(unittest.TestCase):
expected_shape = torch.Size((1, 1000)) expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape) self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([0.0001, 0.0002, 0.0002]).to(torch_device) expected_slice = torch.tensor([-0.2962, 0.4487, 0.4499]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.