Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3e4c9e5c
Unverified
Commit
3e4c9e5c
authored
Dec 07, 2022
by
Younes Belkada
Committed by
GitHub
Dec 07, 2022
Browse files
[`ViTHybrid`] + [`BiT`] cleaner `__init__` (#20649)
* cleaner `__init__` * add docstring for `backbone_config`
parent
aac7b0d2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
4 additions
and
10 deletions
+4
-10
src/transformers/models/bit/modeling_bit.py
src/transformers/models/bit/modeling_bit.py
+0
-9
src/transformers/models/vit_hybrid/configuration_vit_hybrid.py
...ransformers/models/vit_hybrid/configuration_vit_hybrid.py
+2
-0
src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
+2
-1
No files found.
src/transformers/models/bit/modeling_bit.py
View file @
3e4c9e5c
...
...
@@ -671,15 +671,6 @@ class BitPreTrainedModel(PreTrainedModel):
if
isinstance
(
module
,
BitModel
):
module
.
gradient_checkpointing
=
value
@
torch
.
no_grad
()
def
_get_feature_map
(
self
,
dummy_image
):
training
=
self
.
training
if
training
:
self
.
eval
()
feature_map
=
self
(
dummy_image
).
feature_maps
[
-
1
]
self
.
train
(
training
)
return
feature_map
BIT_START_DOCSTRING
=
r
"""
This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
...
...
src/transformers/models/vit_hybrid/configuration_vit_hybrid.py
View file @
3e4c9e5c
...
...
@@ -69,6 +69,8 @@ class ViTHybridConfig(PretrainedConfig):
The number of input channels.
qkv_bias (`bool`, *optional*, defaults to `True`):
Whether to add a bias to the queries, keys and values.
backbone_config (`Union[Dict[str, Any], PretrainedConfig]`, *optional*, defaults to `None`):
The configuration of the backbone in a dictionary or the config object of the backbone.
Example:
...
...
src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
View file @
3e4c9e5c
...
...
@@ -167,7 +167,8 @@ class ViTHybridPatchEmbeddings(nn.Module):
if
feature_size
is
None
:
dummy_image
=
torch
.
zeros
(
1
,
num_channels
,
image_size
[
0
],
image_size
[
1
])
feature_map
=
self
.
backbone
.
_get_feature_map
(
dummy_image
)
with
torch
.
no_grad
():
feature_map
=
self
.
backbone
(
dummy_image
).
feature_maps
[
-
1
]
feature_size
=
feature_map
.
shape
[
-
2
:]
feature_dim
=
feature_map
.
shape
[
1
]
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment