flash-attention · Commit 6cc73425
authored Aug 05, 2022 by Tri Dao
parent 713ea302

Support index_first_axis with more than 2 dimensions

Showing 1 changed file with 53 additions and 14 deletions:

flash_attn/bert_padding.py (+53, -14)
 # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
+import numpy as np
 import torch
 import torch.nn.functional as F

...
@@ -11,21 +13,26 @@ class IndexFirstAxis(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, indices):
         ctx.save_for_backward(indices)
-        ctx.first_axis_dim = input.shape[0]
-        assert input.ndim == 2
+        assert input.ndim >= 2
+        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        second_dim = np.prod(other_shape)
         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
         # return input[indices]
-        return torch.gather(input, 0, repeat(indices, 'z -> z d', d=input.shape[1]))
+        return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
+                            repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape)

     @staticmethod
     def backward(ctx, grad_output):
         indices, = ctx.saved_tensors
-        grad_input = torch.zeros([ctx.first_axis_dim, *grad_output.shape[1:]],
-                                 device=grad_output.device, dtype=grad_output.dtype)
+        assert grad_output.ndim >= 2
+        other_shape = grad_output.shape[1:]
+        grad_output = rearrange(grad_output, 'b ... -> b (...)')
+        grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
+                                 device=grad_output.device, dtype=grad_output.dtype)
         # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
         # grad_input[indices] = grad_output
         grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output)
-        return grad_input, None
+        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


 index_first_axis = IndexFirstAxis.apply
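
Not part of the diff: a minimal sketch, with made-up sizes, of the flatten-gather-reshape pattern the new forward uses so that index_first_axis can handle inputs with extra trailing dimensions. The result should match plain fancy indexing input[indices].

# Minimal sketch, not from the commit; sizes below are illustrative only.
import torch
from einops import rearrange, repeat

input = torch.randn(6, 4, 8)                 # (first_axis_dim, nheads, headdim)
indices = torch.tensor([0, 2, 5])
flat = rearrange(input, 'b ... -> b (...)')  # flatten trailing dims: (6, 32)
gathered = torch.gather(flat, 0, repeat(indices, 'z -> z d', d=flat.shape[1]))
out = gathered.reshape(-1, *input.shape[1:]) # restore trailing dims: (3, 4, 8)
assert torch.equal(out, input[indices])      # same result as fancy indexing
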
...
@@ -37,8 +44,8 @@ class IndexPutFirstAxis(torch.autograd.Function):
     def forward(ctx, values, indices, first_axis_dim):
         ctx.save_for_backward(indices)
         assert indices.ndim == 1
-        assert values.ndim == 2
-        output = torch.zeros(first_axis_dim, values.shape[1], device=values.device,
+        assert values.ndim >= 2
+        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device,
                              dtype=values.dtype)
         # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
         output[indices] = values
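
Not part of the diff: a minimal sketch, with hypothetical shapes, of the generalized zero-buffer assignment that IndexPutFirstAxis.forward now performs for values with ndim >= 2.

# Minimal sketch, not from the commit; shapes are illustrative only.
import torch

first_axis_dim = 6
values = torch.randn(3, 4, 8)                # (total_nnz, nheads, headdim)
indices = torch.tensor([0, 2, 5])
output = torch.zeros(first_axis_dim, *values.shape[1:],
                     device=values.device, dtype=values.dtype)
output[indices] = values                     # rows 0, 2, 5 filled, the rest stay zero
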
...
@@ -57,13 +64,45 @@ class IndexPutFirstAxis(torch.autograd.Function):
 index_put_first_axis = IndexPutFirstAxis.apply


+class IndexFirstAxisResidual(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, input, indices):
+        ctx.save_for_backward(indices)
+        assert input.ndim >= 2
+        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
+        second_dim = np.prod(other_shape)
+        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
+        output = input[indices]
+        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
+        # memory format to channel_first. In other words, input might not be contiguous.
+        # If we don't detach, Pytorch complains about output being a view and is being modified inplace
+        return output, input.detach()
+
+    @staticmethod
+    def backward(ctx, grad_output, grad_residual):
+        indices, = ctx.saved_tensors
+        assert grad_output.ndim >= 2
+        other_shape = grad_output.shape[1:]
+        assert grad_residual.shape[1:] == other_shape
+        grad_input = grad_residual
+        # grad_input[indices] += grad_output
+        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
+        indices = indices.expand_as(grad_output)
+        grad_input.scatter_add_(0, indices, grad_output)
+        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
+
+
+index_first_axis_residual = IndexFirstAxisResidual.apply
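
Not part of the diff: the backward of the new IndexFirstAxisResidual reuses grad_residual as the gradient buffer and scatter-adds grad_output into the gathered rows. A minimal sketch with illustrative shapes:

# Minimal sketch, not from the commit; shapes are illustrative only.
import torch

first_axis_dim, nheads, headdim = 6, 4, 8
grad_residual = torch.zeros(first_axis_dim, nheads, headdim)  # gradient w.r.t. the residual copy
grad_output = torch.ones(3, nheads, headdim)                  # gradient w.r.t. the gathered rows
indices = torch.tensor([0, 2, 5])

# Broadcast the 1-D indices up to grad_output's shape, then accumulate in place.
idx = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
idx = idx.expand_as(grad_output)                              # (3, nheads, headdim)
grad_residual.scatter_add_(0, idx, grad_output)               # rows 0, 2, 5 receive grad_output
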
 def unpad_input(hidden_states, attention_mask):
     """
     Arguments:
-        hidden_states: (batch, seqlen, dim)
+        hidden_states: (batch, seqlen, ...)
         attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
     Return:
-        hidden_states: (total_nnz, dim), where total_nnz = number of tokens in selected in attention_mask.
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
         cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
         max_seqlen_in_batch: int
     """
...
@@ -76,20 +115,20 @@ def unpad_input(hidden_states, attention_mask):
     # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
     # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
     # so we write custom forward and backward to make it a bit faster.
-    return (index_first_axis(rearrange(hidden_states, 'b s d -> (b s) d'), indices), indices,
+    return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices,
             cu_seqlens, max_seqlen_in_batch)
 def pad_input(hidden_states, indices, batch, seqlen):
     """
     Arguments:
-        hidden_states: (total_nnz, dim), where total_nnz = number of tokens in selected in attention_mask.
+        hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
         indices: (total_nnz)
     Return:
-        hidden_states: (batch, seqlen, dim)
+        hidden_states: (batch, seqlen, ...)
     """
     dim = hidden_states.shape[-1]
     # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
     # output[indices] = hidden_states
     output = index_put_first_axis(hidden_states, indices, batch * seqlen)
-    return rearrange(output, '(b s) d -> b s d', b=batch)
+    return rearrange(output, '(b s) ... -> b s ...', b=batch)
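
For context, a hedged usage sketch of the round trip this commit enables: hidden_states may now carry extra trailing dimensions such as (batch, seqlen, nheads, headdim). Sizes are illustrative, and the module is assumed importable as flash_attn.bert_padding.

# Minimal usage sketch, not from the commit; sizes are illustrative only.
import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seqlen, nheads, headdim = 2, 5, 4, 8
hidden_states = torch.randn(batch, seqlen, nheads, headdim)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 0]])

# Drop padded positions: (2, 5, 4, 8) -> (7, 4, 8), with cu_seqlens = [0, 3, 7].
unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)
# Scatter back into the padded layout: (7, 4, 8) -> (2, 5, 4, 8); padded rows are zero.
repadded = pad_input(unpadded, indices, batch, seqlen)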