gaoqiong / flash-attention · Commits

Commit 780e8eea, authored Jan 16, 2023 by Tri Dao
Parent: 2ec7d3f7

[ViT] Support timm checkpoint, add tests

Showing 5 changed files with 96 additions and 11 deletions.
flash_attn/models/vit.py     +29  -4
flash_attn/modules/block.py  +10  -3
flash_attn/modules/mha.py     +7  -3
tests/models/test_opt.py      +1  -1
tests/models/test_vit.py     +49  -0
flash_attn/models/vit.py

```diff
 # Copyright (c) 2022, Tri Dao.
 # Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

 import math
+import re
 from functools import partial
 from copy import deepcopy
+from collections import OrderedDict

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
```

```diff
@@ -218,6 +221,7 @@ class VisionTransformer(nn.Module):
         hidden_states = self._pos_embed(x)
         residual = None
         if self.global_pool != 'token' or all_tokens:
             # if True:
             for block in self.blocks:
                 hidden_states, residual = block(hidden_states, residual)
         else:
```

```diff
@@ -225,10 +229,8 @@ class VisionTransformer(nn.Module):
                 hidden_states, residual = block(hidden_states, residual)
             # For the last layer, we only want the 1st token of the output. So we do cross-attention
             # where the query is the 1st token and the key/value is the whole sequence.
-            hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
-            residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
-            hidden_states, residual = self.blocks[-1](hidden_states_1st, residual_1st,
-                                                      mixer_kwargs={'x_kv': hidden_states})
+            hidden_states, residual = self.blocks[-1](hidden_states, residual,
+                                                      mixer_subset=slice(0, 1))
         if not self.fused_dropout_add_ln:
             residual = self.drop_path(self.dropout(hidden_states)) + residual
             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
```
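The slice-based call above removes the old rearrange of the CLS token into a length-1 sequence. A quick illustration with a toy tensor (hypothetical ViT-B/16-like sizes, not taken from the diff) of why `slice(0, 1)` is used instead of plain indexing:

```python
import torch

# (batch, 1 CLS token + 196 patches, embed dim) -- hypothetical ViT-B/16-like shape
hidden_states = torch.randn(2, 197, 768)

print(hidden_states[:, 0].shape)            # torch.Size([2, 768])    -> previously needed rearrange('b d -> b 1 d')
print(hidden_states[:, slice(0, 1)].shape)  # torch.Size([2, 1, 768]) -> keeps the sequence dim, no rearrange
```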
```diff
@@ -258,6 +260,29 @@ class VisionTransformer(nn.Module):
         x = self.forward_head(x)
         return x

+    def load_state_dict(self, state_dict, strict=True):
+        patch_embed_weight = state_dict['patch_embed.proj.weight']
+        if patch_embed_weight.dim() == 4:
+            # convert from Conv2d to Linear
+            state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight, 'o c h w -> o (c h w)')
+
+        def key_mapping_attn(key):
+            key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
+            key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
+            return key
+        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+
+        n_layer = len(self.blocks)
+        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
+        if (self.blocks[-1].mixer.cross_attn
+                and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
+            Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
+            bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
+        return super().load_state_dict(state_dict, strict=strict)


 def init_weights_vit_timm(module: nn.Module, name: str = ''):
     """ ViT weight initialization, original timm impl (for reproducibility) """
```
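To show what the new `load_state_dict` does to a timm checkpoint, here is a small self-contained sketch on a toy state dict (the toy sizes and 2-layer model are assumptions for illustration, not part of the commit): timm's `blocks.N.attn.qkv` / `attn.proj` keys are renamed to `mixer.Wqkv` / `mixer.out_proj`, and the last block's fused QKV is split into separate Wq / Wkv tensors for the cross-attention layer.

```python
import re
from collections import OrderedDict

import torch

embed_dim, n_layer = 8, 2  # toy sizes for illustration

# A toy slice of a timm-style checkpoint for the last block
state_dict = OrderedDict({
    'blocks.1.attn.qkv.weight': torch.randn(3 * embed_dim, embed_dim),
    'blocks.1.attn.proj.weight': torch.randn(embed_dim, embed_dim),
})

def key_mapping_attn(key):
    # Same regexes as in the diff: attn.qkv -> mixer.Wqkv, attn.proj -> mixer.out_proj
    key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
    key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
    return key

state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

# Split the last block's fused QKV: the first embed_dim rows project the query,
# the remaining 2*embed_dim rows project the packed key/value.
Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:embed_dim]
state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[embed_dim:]

print(list(state_dict.keys()))
# ['blocks.1.mixer.out_proj.weight', 'blocks.1.mixer.Wq.weight', 'blocks.1.mixer.Wkv.weight']
```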
flash_attn/modules/block.py

```diff
@@ -89,12 +89,15 @@ class Block(nn.Module):
                     p._shared_params = True

     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
-                mixer_kwargs=None):
+                mixer_subset=None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.

         Args:
             hidden_states: the sequence to the encoder layer (required).
             residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
         """
         if self.prenorm:
             if not self.fused_dropout_add_ln:
```

```diff
@@ -116,8 +119,12 @@ class Block(nn.Module):
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
                     rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
                 )
-            hidden_states = self.mixer(hidden_states,
-                                       **(mixer_kwargs if mixer_kwargs is not None else {}))
+            if mixer_kwargs is None:
+                mixer_kwargs = {}
+            mixer_kwargs['mixer_subset'] = mixer_subset
+            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
+            if mixer_subset is not None:
+                residual = residual[:, mixer_subset]
             if not isinstance(self.mlp, nn.Identity):
                 if not self.fused_dropout_add_ln:
                     dropped = self.drop_path2(self.dropout2(hidden_states))
```
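One detail worth noting in the Block change: once the mixer returns hidden states only for the selected tokens, the residual stream has to be cut to the same positions before the next add. A toy sketch (plain tensors standing in for the mixer call; shapes are assumptions):

```python
import torch

B, S, D = 2, 197, 8  # toy batch, sequence length, hidden dim
hidden_states = torch.randn(B, S, D)
residual = torch.randn(B, S, D)

mixer_subset = slice(0, 1)
# Stand-in for self.mixer(hidden_states, mixer_subset=mixer_subset), which returns
# outputs only for the selected positions.
hidden_states = hidden_states[:, mixer_subset]
# The residual must be sliced the same way; otherwise the (B, 1, D) mixer output
# could not be added to a (B, S, D) residual.
residual = residual[:, mixer_subset]

assert (hidden_states + residual).shape == (B, 1, D)
```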
flash_attn/modules/mha.py

```diff
@@ -420,7 +420,7 @@ class MHA(nn.Module):
         return _update_kv_cache(kv, inference_params, self.layer_idx)

     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
-                inference_params=None, **kwargs):
+                mixer_subset=None, inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
```

```diff
@@ -433,6 +433,9 @@ class MHA(nn.Module):
             max_seqlen: int. Maximum sequence length in the batch.
             key_padding_mask: boolean mask, True means to keep, False means to mask out.
                 (batch, seqlen). Only applicable when not using FlashAttention.
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
             inference_params: for generation. Adapted from Megatron-LM (and Apex)
             https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         """
```

```diff
@@ -454,6 +457,7 @@ class MHA(nn.Module):
         kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                   if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
         if not self.cross_attn:
+            assert x_kv is None and mixer_subset is None
             if not self.return_residual:
                 qkv = self.Wqkv(x)
             else:
```

```diff
@@ -491,14 +495,14 @@ class MHA(nn.Module):
                     context = rearrange(context, 'b h d -> b 1 h d')
         else:
             if not self.return_residual:
-                q = self.Wq(x)
+                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                 kv = self.Wkv(x_kv if x_kv is not None else x)
             else:
                 if x_kv is not None:
                     kv, x_kv = self.Wkv(x_kv)
                 else:
                     kv, x = self.Wkv(x)
-                q = self.Wq(x)
+                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
             q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
             kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
             if self.dwconv:
```
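In plain PyTorch terms, projecting the query from `x[:, mixer_subset]` while keys and values still come from the full sequence gives a one-query cross-attention. The sketch below uses toy sizes and a vanilla einsum attention (an illustration of the shapes only, not the MHA module's actual kernel path):

```python
import torch
import torch.nn as nn

B, S, H, D = 2, 197, 12, 64          # toy batch, seqlen, heads, head_dim
x = torch.randn(B, S, H * D)

Wq = nn.Linear(H * D, H * D)
Wkv = nn.Linear(H * D, 2 * H * D)
mixer_subset = slice(0, 1)

q = Wq(x[:, mixer_subset]).reshape(B, 1, H, D)   # query from the CLS token only
kv = Wkv(x).reshape(B, S, 2, H, D)               # keys/values from all S tokens
k, v = kv.unbind(dim=2)

attn = torch.einsum('bqhd,bkhd->bhqk', q, k) / D ** 0.5
out = torch.einsum('bhqk,bkhd->bqhd', attn.softmax(dim=-1), v)
print(out.shape)  # torch.Size([2, 1, 12, 64]) -- only the CLS position is computed
```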
tests/models/test_opt.py

```diff
@@ -26,7 +26,7 @@ def test_opt_state_dict(model_name):

 @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
 # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
 def test_opt_optimized(model_name):
-    """Check that our implementation of OPT (without any optimizations enabled) matches the
+    """Check that our implementation of OPT (without all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
```
tests/models/test_vit.py  (new file, 0 → 100644)

```python
import re

import torch
import pytest

from timm.models.vision_transformer import vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224


@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_dense_gelu_dense):
    """Check that our implementation of ViT matches timm's implementation:
    the output of our forward pass in fp16 should be around the same as
    timm's forward pass in fp16, when compared to timm's forward pass in fp32.
    """
    dtype = torch.float16
    device = 'cuda'

    kwargs = {}
    if optimized:
        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
    kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)

    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)

    model.load_state_dict(model_ref.state_dict())
    model.eval()
    model_ref.eval()
    model_timm.eval()

    torch.manual_seed(0)
    batch_size = 2
    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
    out = model(x)
    out_timm = model_timm(x)
    out_ref = model_ref(x.float())

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
    print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')

    assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()
```
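The final assert uses timm's own fp16-vs-fp32 error as the yardstick rather than a fixed threshold. Hypothetical numbers (not measured values) to illustrate the check:

```python
# If timm's fp16 forward pass deviates from its fp32 reference by at most 2e-3,
# the FlashAttention ViT is allowed up to 3x that deviation, i.e. 6e-3.
timm_fp16_max_diff = 2e-3     # hypothetical measurement
flash_fp16_max_diff = 4e-3    # hypothetical measurement
assert flash_fp16_max_diff < 3 * timm_fp16_max_diff
```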