Unverified Commit de2d793e authored by Yih-Dar, committed by GitHub

Fix `EfficientFormer` (#21294)



* fix

* fix checkpoint

* fix style

* tiny update
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 8788fd0c
@@ -42,7 +42,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "EfficientFormerConfig"
 
 # Base docstring
-_CHECKPOINT_FOR_DOC = "efficientformer-l1-300"
+_CHECKPOINT_FOR_DOC = "snap-research/efficientformer-l1-300"
 _EXPECTED_OUTPUT_SHAPE = [1, 197, 768]
 
 # Image classification docstring
@@ -51,7 +51,7 @@ _IMAGE_CLASS_EXPECTED_OUTPUT = "Egyptian cat"
 EFFICIENTFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "huggingface/efficientformer-l1-300",
+    "snap-research/efficientformer-l1-300",
     # See all EfficientFormer models at https://huggingface.co/models?filter=efficientformer
 ]
@@ -133,6 +133,10 @@ class EfficientFormerSelfAttention(nn.Module):
         key_layer = key_layer.permute(0, 2, 1, 3)
         value_layer = value_layer.permute(0, 2, 1, 3)
 
+        # set `model.to(torch_device)` won't change `self.ab.device`, if there is no follow-up `train` or `eval` call.
+        # Let's do it manually here, so users won't have to do this everytime.
+        if not self.training:
+            self.ab = self.ab.to(self.attention_biases.device)
         attention_probs = (torch.matmul(query_layer, key_layer.transpose(-2, -1))) * self.scale + (
             self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab
         )
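For readers wondering why the last hunk is needed: `nn.Module.to(device)` relocates registered parameters and buffers, but a plain tensor attribute such as the cached `self.ab` keeps its old device unless a later `train()`/`eval()` call rebuilds it or it is moved explicitly. The sketch below reproduces that mismatch with a made-up `TinyAttention` module (the class name, shapes, and `forward` signature are illustrative, not the real EfficientFormer code); only the `if not self.training:` re-sync mirrors what this commit adds.

```python
import torch
from torch import nn


class TinyAttention(nn.Module):
    """Toy module, not the real EfficientFormerSelfAttention."""

    def __init__(self):
        super().__init__()
        # Registered parameter: moved automatically by `model.to(device)`.
        self.attention_biases = nn.Parameter(torch.zeros(8, 49))
        # Plain tensor attribute (eval-time cache): NOT moved by `model.to(device)`.
        self.ab = self.attention_biases.detach().clone()

    def forward(self, scores):
        if not self.training:
            # The fix from this commit: re-sync the cache's device before using it.
            self.ab = self.ab.to(self.attention_biases.device)
        bias = self.attention_biases if self.training else self.ab
        return scores + bias


if torch.cuda.is_available():
    model = TinyAttention().eval()
    model.to("cuda")
    print(model.attention_biases.device)  # cuda:0 -- parameter was moved
    print(model.ab.device)                # cpu    -- plain attribute was left behind
    scores = torch.randn(1, 8, 49, device="cuda")
    out = model(scores)                   # succeeds because forward() re-syncs self.ab
    print(out.device)                     # cuda:0
```

Without that re-sync, the addition in `forward()` would raise a device-mismatch RuntimeError whenever the model is moved to GPU after `self.ab` has already been cached on CPU, which is the failure mode the commit's comment describes.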