Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
5d5ecb45
Unverified
Commit
5d5ecb45
authored
Jan 06, 2022
by
Yiwen Song
Committed by
GitHub
Jan 06, 2022
Browse files
[ViT] Refactor forward function (#5172)
* refactor forward function * reemove n from return
parent
058f4bd7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
2 deletions
+9
-2
torchvision/prototype/models/vision_transformer.py
torchvision/prototype/models/vision_transformer.py
+9
-2
No files found.
torchvision/prototype/models/vision_transformer.py
View file @
5d5ecb45
...
@@ -202,7 +202,7 @@ class VisionTransformer(nn.Module):
...
@@ -202,7 +202,7 @@ class VisionTransformer(nn.Module):
nn
.
init
.
zeros_
(
self
.
heads
.
head
.
weight
)
nn
.
init
.
zeros_
(
self
.
heads
.
head
.
weight
)
nn
.
init
.
zeros_
(
self
.
heads
.
head
.
bias
)
nn
.
init
.
zeros_
(
self
.
heads
.
head
.
bias
)
def
forward
(
self
,
x
:
torch
.
Tensor
):
def
_process_input
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
n
,
c
,
h
,
w
=
x
.
shape
n
,
c
,
h
,
w
=
x
.
shape
p
=
self
.
patch_size
p
=
self
.
patch_size
torch
.
_assert
(
h
==
self
.
image_size
,
"Wrong image height!"
)
torch
.
_assert
(
h
==
self
.
image_size
,
"Wrong image height!"
)
...
@@ -221,7 +221,14 @@ class VisionTransformer(nn.Module):
...
@@ -221,7 +221,14 @@ class VisionTransformer(nn.Module):
# embedding dimension
# embedding dimension
x
=
x
.
permute
(
0
,
2
,
1
)
x
=
x
.
permute
(
0
,
2
,
1
)
# Expand the class token to the full batch.
return
x
def
forward
(
self
,
x
:
torch
.
Tensor
):
# Reshaping and permuting the input tensor
x
=
self
.
_process_input
(
x
)
n
=
x
.
shape
[
0
]
# Expand the class token to the full batch
batch_class_token
=
self
.
class_token
.
expand
(
n
,
-
1
,
-
1
)
batch_class_token
=
self
.
class_token
.
expand
(
n
,
-
1
,
-
1
)
x
=
torch
.
cat
([
batch_class_token
,
x
],
dim
=
1
)
x
=
torch
.
cat
([
batch_class_token
,
x
],
dim
=
1
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment