gaoqiong / flash-attention · Commit 780e8eea

[ViT] Support timm checkpoint, add tests
Authored Jan 16, 2023 by Tri Dao · parent 2ec7d3f7
Showing 5 changed files with 96 additions and 11 deletions:

  flash_attn/models/vit.py     +29  -4
  flash_attn/modules/block.py  +10  -3
  flash_attn/modules/mha.py     +7  -3
  tests/models/test_opt.py      +1  -1
  tests/models/test_vit.py     +49  -0
flash_attn/models/vit.py (view file @ 780e8eea)

 # Copyright (c) 2022, Tri Dao.
 # Inspired by / adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

 import math
+import re
 from functools import partial
 from copy import deepcopy
+from collections import OrderedDict

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
...
@@ -218,6 +221,7 @@ class VisionTransformer(nn.Module):
         hidden_states = self._pos_embed(x)
         residual = None
         if self.global_pool != 'token' or all_tokens:
+            # if True:
             for block in self.blocks:
                 hidden_states, residual = block(hidden_states, residual)
         else:
...
@@ -225,10 +229,8 @@ class VisionTransformer(nn.Module):
                 hidden_states, residual = block(hidden_states, residual)
             # For the last layer, we only want the 1st token of the output. So we do cross-attention
             # where the query is the 1st token and the key/value is the whole sequence.
-            hidden_states_1st = rearrange(hidden_states[:, 0], 'b d -> b 1 d')
-            residual_1st = rearrange(residual[:, 0], 'b d -> b 1 d')
-            hidden_states, residual = self.blocks[-1](hidden_states_1st, residual_1st,
-                                                      mixer_kwargs={'x_kv': hidden_states})
+            hidden_states, residual = self.blocks[-1](hidden_states, residual,
+                                                      mixer_subset=slice(0, 1))
         if not self.fused_dropout_add_ln:
             residual = self.drop_path(self.dropout(hidden_states)) + residual
             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
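The mixer_subset=slice(0, 1) argument works because softmax attention is computed independently per query row: attending with only the CLS query over the full key/value sequence yields exactly the CLS row of full self-attention, while skipping the work for the remaining patch tokens. A standalone sketch (not code from this repo; the toy attn helper and the ViT-B/16-like shapes are illustrative):

import torch

torch.manual_seed(0)
b, s, h, d = 2, 197, 12, 64                    # batch, seqlen, heads, head_dim
q = torch.randn(b, s, h, d)
k = torch.randn(b, s, h, d)
v = torch.randn(b, s, h, d)

def attn(q, k, v):
    # plain scaled-dot-product attention; softmax is over the key dimension, per query row
    scores = torch.einsum('bqhd,bkhd->bhqk', q, k) / d ** 0.5
    return torch.einsum('bhqk,bkhd->bqhd', scores.softmax(dim=-1), v)

full_then_slice = attn(q, k, v)[:, :1]         # all queries, then keep token 0 (the CLS token)
cls_query_only = attn(q[:, :1], k, v)          # only the CLS query, as slice(0, 1) selects
assert torch.allclose(full_then_slice, cls_query_only, atol=1e-5)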
...
@@ -258,6 +260,29 @@ class VisionTransformer(nn.Module):
         x = self.forward_head(x)
         return x

+    def load_state_dict(self, state_dict, strict=True):
+        patch_embed_weight = state_dict['patch_embed.proj.weight']
+        if patch_embed_weight.dim() == 4:
+            # convert from Conv2d to Linear
+            state_dict['patch_embed.proj.weight'] = rearrange(patch_embed_weight, 'o c h w -> o (c h w)')
+
+        def key_mapping_attn(key):
+            key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
+            key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
+            return key
+        state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+
+        n_layer = len(self.blocks)
+        # Convert from Wqkv to Wq and Wkv for cross attention (last layer)
+        if (self.blocks[-1].mixer.cross_attn
+                and f'blocks.{n_layer - 1}.mixer.Wqkv.weight' in state_dict):
+            Wqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.weight')
+            bqkv = state_dict.pop(f'blocks.{n_layer - 1}.mixer.Wqkv.bias')
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.weight'] = Wqkv[:self.embed_dim]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.weight'] = Wqkv[self.embed_dim:]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wq.bias'] = bqkv[:self.embed_dim]
+            state_dict[f'blocks.{n_layer - 1}.mixer.Wkv.bias'] = bqkv[self.embed_dim:]
+        return super().load_state_dict(state_dict, strict=strict)

 def init_weights_vit_timm(module: nn.Module, name: str = ''):
     """ ViT weight initialization, original timm impl (for reproducibility) """
...
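The load_state_dict override above is what makes timm checkpoints loadable: the Conv2d patch-embedding weight is flattened into a Linear weight, timm's attn.qkv / attn.proj keys are renamed to the mixer's Wqkv / out_proj, and, when the last block uses cross-attention, its fused Wqkv is split row-wise into Wq (the first embed_dim rows) and Wkv (the rest). A toy sketch of the same remapping (not the repo's code; the single-block state dict, embed_dim, and shapes are made up for illustration, and einops is assumed to be available):

import re
import torch
from collections import OrderedDict
from einops import rearrange

embed_dim = 8
toy = OrderedDict([
    ('patch_embed.proj.weight', torch.randn(embed_dim, 3, 4, 4)),        # timm Conv2d weight
    ('blocks.0.attn.qkv.weight', torch.randn(3 * embed_dim, embed_dim)),
    ('blocks.0.attn.proj.weight', torch.randn(embed_dim, embed_dim)),
])

def key_mapping_attn(key):
    # same renaming as the diff: timm attention keys -> flash_attn mixer keys
    key = re.sub(r'^blocks.(\d+).attn.qkv.', r'blocks.\1.mixer.Wqkv.', key)
    key = re.sub(r'^blocks.(\d+).attn.proj.', r'blocks.\1.mixer.out_proj.', key)
    return key

sd = OrderedDict((key_mapping_attn(k), v) for k, v in toy.items())
# Conv2d patch embedding becomes a Linear over flattened patches
sd['patch_embed.proj.weight'] = rearrange(sd['patch_embed.proj.weight'], 'o c h w -> o (c h w)')
# a fused Wqkv splits into Wq (queries) and Wkv (keys and values) for a cross-attention block
Wqkv = sd.pop('blocks.0.mixer.Wqkv.weight')
sd['blocks.0.mixer.Wq.weight'] = Wqkv[:embed_dim]                        # (embed_dim, embed_dim)
sd['blocks.0.mixer.Wkv.weight'] = Wqkv[embed_dim:]                       # (2 * embed_dim, embed_dim)
print(list(sd.keys()))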
flash_attn/modules/block.py (view file @ 780e8eea)

...
@@ -89,12 +89,15 @@ class Block(nn.Module):
                 p._shared_params = True

     def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None,
-                mixer_kwargs=None):
+                mixer_subset=None, mixer_kwargs=None):
         r"""Pass the input through the encoder layer.

         Args:
             hidden_states: the sequence to the encoder layer (required).
             residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual))
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
         """
         if self.prenorm:
             if not self.fused_dropout_add_ln:
...
@@ -116,8 +119,12 @@ class Block(nn.Module):
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
                     rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
                 )
-            hidden_states = self.mixer(hidden_states,
-                                       **(mixer_kwargs if mixer_kwargs is not None else {}))
+            if mixer_kwargs is None:
+                mixer_kwargs = {}
+            mixer_kwargs['mixer_subset'] = mixer_subset
+            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
+            if mixer_subset is not None:
+                residual = residual[:, mixer_subset]
             if not isinstance(self.mlp, nn.Identity):
                 if not self.fused_dropout_add_ln:
                     dropped = self.drop_path2(self.dropout2(hidden_states))
...
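Because the mixer may now return outputs only for the tokens selected by mixer_subset, the block also slices the running prenorm residual to the same tokens before the residual add; otherwise the two tensors would no longer have matching sequence lengths. A minimal standalone toy of that bookkeeping (not the repo's Block; dropout, drop_path, and LayerNorm are omitted and the mixer output is faked):

import torch

b, s, d = 2, 197, 16
residual = torch.randn(b, s, d)                  # running prenorm residual over all tokens
mixer_subset = slice(0, 1)

mixer_out = torch.randn(b, 1, d)                 # stand-in for self.mixer(hidden_states, mixer_subset=...)
if mixer_subset is not None:
    residual = residual[:, mixer_subset]         # keep residual rows for the selected tokens only
residual = mixer_out + residual                  # shapes agree: (b, 1, d)
print(residual.shape)                            # torch.Size([2, 1, 16])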
flash_attn/modules/mha.py (view file @ 780e8eea)

...
@@ -420,7 +420,7 @@ class MHA(nn.Module):
         return _update_kv_cache(kv, inference_params, self.layer_idx)

     def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
-                inference_params=None, **kwargs):
+                mixer_subset=None, inference_params=None, **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
...
@@ -433,6 +433,9 @@ class MHA(nn.Module):
             max_seqlen: int. Maximum sequence length in the batch.
             key_padding_mask: boolean mask, True means to keep, False means to mask out.
                 (batch, seqlen). Only applicable when not using FlashAttention.
+            mixer_subset: for cross-attention only. If not None, will take a subset of x
+                before applying the query projection. Useful for e.g., ViT where we only care
+                about the CLS token in the last layer.
             inference_params: for generation. Adapted from Megatron-LM (and Apex)
             https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470
         """
...
@@ -454,6 +457,7 @@ class MHA(nn.Module):
         kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
                   if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
         if not self.cross_attn:
+            assert x_kv is None and mixer_subset is None
             if not self.return_residual:
                 qkv = self.Wqkv(x)
             else:
...
@@ -491,14 +495,14 @@ class MHA(nn.Module):
                 context = rearrange(context, 'b h d -> b 1 h d')
         else:
             if not self.return_residual:
-                q = self.Wq(x)
+                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
                 kv = self.Wkv(x_kv if x_kv is not None else x)
             else:
                 if x_kv is not None:
                     kv, x_kv = self.Wkv(x_kv)
                 else:
                     kv, x = self.Wkv(x)
-                q = self.Wq(x)
+                q = self.Wq(x if mixer_subset is None else x[:, mixer_subset])
             q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
             kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
         if self.dwconv:
...
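Inside MHA, mixer_subset only affects the query path: Wq is applied to the selected tokens while Wkv still projects the full sequence, so the attention output has one row per selected query. A self-contained sketch of those projections and shapes (not the repo's MHA; the dimensions and bare nn.Linear layers are illustrative, and einops is assumed):

import torch
import torch.nn as nn
from einops import rearrange

b, s, d, nheads = 2, 197, 768, 12
head_dim = d // nheads
x = torch.randn(b, s, d)
mixer_subset = slice(0, 1)

Wq = nn.Linear(d, d)
Wkv = nn.Linear(d, 2 * d)

q = Wq(x if mixer_subset is None else x[:, mixer_subset])               # (b, 1, d): queries for the subset only
kv = Wkv(x)                                                             # (b, s, 2 * d): keys/values for every token
q = rearrange(q, '... (h d) -> ... h d', d=head_dim)                    # (b, 1, nheads, head_dim)
kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=head_dim)   # (b, s, 2, nheads, head_dim)
print(q.shape, kv.shape)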
tests/models/test_opt.py (view file @ 780e8eea)

...
@@ -26,7 +26,7 @@ def test_opt_state_dict(model_name):
 @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
 # @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
 def test_opt_optimized(model_name):
-    """Check that our implementation of OPT (without any optimizations enabled) matches the
+    """Check that our implementation of OPT (without all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
...
tests/models/test_vit.py (new file, mode 100644; view file @ 780e8eea)

import re

import torch
import pytest

from timm.models.vision_transformer import vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224


@pytest.mark.parametrize('fused_dense_gelu_dense', [False, True])
# @pytest.mark.parametrize('fused_dense_gelu_dense', [False])
@pytest.mark.parametrize('optimized', [False, True])
# @pytest.mark.parametrize('optimized', [True])
def test_vit(optimized, fused_dense_gelu_dense):
    """Check that our implementation of ViT matches the timm's implementation:
    the output of our forward pass in fp16 should be around the same as
    timm' forward pass in fp16, when compared to timm's forward pass in fp32.
    """
    dtype = torch.float16
    device = 'cuda'

    kwargs = {}
    if optimized:
        kwargs = dict(use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True)
    kwargs['fused_dense_gelu_dense'] = fused_dense_gelu_dense
    model = flash_vit_base_patch16_224(**kwargs).to(device=device, dtype=dtype)

    model_ref = vit_base_patch16_224(pretrained=True).to(device=device)
    model_timm = vit_base_patch16_224(pretrained=True).to(device=device, dtype=dtype)

    model.load_state_dict(model_ref.state_dict())
    model.eval()
    model_ref.eval()
    model_timm.eval()

    torch.manual_seed(0)
    batch_size = 2
    x = torch.randn(batch_size, 3, 224, 224, device=device, dtype=dtype)
    out = model(x)
    out_timm = model_timm(x)
    out_ref = model_ref(x.float())

    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'timm fp16 max diff: {(out_timm - out_ref).abs().max().item()}')
    print(f'timm fp16 mean diff: {(out_timm - out_ref).abs().mean().item()}')

    assert (out - out_ref).abs().max().item() < 3 * (out_timm - out_ref).abs().max().item()
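The test doubles as a usage recipe: construct the FlashAttention ViT, load a pretrained timm state dict directly (the new load_state_dict handles the key remapping), and run it in fp16. A condensed sketch of that flow, assuming a CUDA device and timm installed; the optimized kwargs from the test (use_flash_attn=True, fused_bias_fc=True, fused_dropout_add_ln=True) can also be passed to the constructor:

import torch
from timm.models.vision_transformer import vit_base_patch16_224
from flash_attn.models.vit import vit_base_patch16_224 as flash_vit_base_patch16_224

model = flash_vit_base_patch16_224().to(device='cuda', dtype=torch.float16)
model.load_state_dict(vit_base_patch16_224(pretrained=True).state_dict())  # timm checkpoint loads directly
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224, device='cuda', dtype=torch.float16))
print(out.shape)  # expected: (1, 1000) ImageNet logits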