Unverified Commit 3e4c9e5c authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`ViTHybrid`] + [`BiT`] cleaner `__init__` (#20649)

* cleaner `__init__`

* add docstring for `backbone_config`
parent aac7b0d2
......@@ -671,15 +671,6 @@ class BitPreTrainedModel(PreTrainedModel):
if isinstance(module, BitModel):
module.gradient_checkpointing = value
@torch.no_grad()
def _get_feature_map(self, dummy_image):
    """Forward `dummy_image` through the model in eval mode and return the last feature map.

    Gradient tracking is disabled via the `torch.no_grad` decorator, and the
    module's original training/eval state is restored before returning.
    """
    # Remember the current mode so we can restore it afterwards.
    was_training = self.training
    if was_training:
        self.eval()
    feature_map = self(dummy_image).feature_maps[-1]
    # `train(False)` is equivalent to `eval()`, so this restores either mode.
    self.train(was_training)
    return feature_map
BIT_START_DOCSTRING = r"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
......
......@@ -69,6 +69,8 @@ class ViTHybridConfig(PretrainedConfig):
The number of input channels.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*, defaults to `None`):
The configuration of the backbone in a dictionary or the config object of the backbone.
Example:
......
......@@ -167,7 +167,8 @@ class ViTHybridPatchEmbeddings(nn.Module):
if feature_size is None:
dummy_image = torch.zeros(1, num_channels, image_size[0], image_size[1])
feature_map = self.backbone._get_feature_map(dummy_image)
with torch.no_grad():
feature_map = self.backbone(dummy_image).feature_maps[-1]
feature_size = feature_map.shape[-2:]
feature_dim = feature_map.shape[1]
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment