gaoqiong / flash-attention · Commits

Commit 0fa5c0d7, authored Nov 17, 2022 by Tri Dao
Add PatchEmbed
Parent: ece539ab

Showing 1 changed file with 56 additions and 0 deletions (+56 −0).
flash_attn/layers/patch_embed.py  (new file, 0 → 100644)  +56 −0
# We use the same API as https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/timm/models/layers/patch_embed.py
# But we use nn.Linear instead of Conv2d and it's about 8x faster.

from functools import partial

import torch.nn as nn
from torch import _assert
from torch.nn.modules.utils import _pair

from einops import rearrange

try:
    from flash_attn.ops.fused_dense import FusedDenseTD
except ImportError:
    FusedDenseTD = None


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            norm_layer=None,
            flatten=True,
            bias=True,
            fused_bias_fc=False,
    ):
        super().__init__()
        img_size = _pair(img_size)
        patch_size = _pair(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        if fused_bias_fc and FusedDenseTD is None:
            raise ImportError('fused_dense is not installed')

        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDenseTD
        self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        _, _, H, W = x.shape
        _assert(H == self.img_size[0],
                f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1],
                f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
        x = self.proj(rearrange(x, 'b c (h p1) (w p2) -> b h w (c p1 p2)',
                                p1=self.patch_size[0], p2=self.patch_size[1]))
        if self.flatten:
            x = rearrange(x, 'b h w c -> b (h w) c')
        x = self.norm(x)
        return x
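
For reference, a minimal usage sketch (not part of the commit) showing the expected tensor shapes. It assumes torch and einops are installed and uses the default nn.Linear path (fused_bias_fc=False), so FusedDenseTD is not required; the example sizes are illustrative, and any img_size divisible by patch_size works.

    import torch

    from flash_attn.layers.patch_embed import PatchEmbed

    patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    x = torch.randn(2, 3, 224, 224)   # (batch, channels, height, width)
    out = patch_embed(x)              # flatten=True, so output is (batch, num_patches, embed_dim)
    print(out.shape)                  # torch.Size([2, 196, 768]), since (224 // 16) ** 2 == 196

The rearrange-plus-Linear formulation matches timm's Conv2d-based PatchEmbed because a Conv2d with kernel size and stride both equal to the patch size is exactly a linear map applied to each non-overlapping patch.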