Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
4ea866ca
Commit
4ea866ca
authored
Jan 14, 2024
by
Tri Dao
Browse files
Make Alibi an object
parent
5aca153d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
61 additions
and
81 deletions
+61
-81
csrc/flash_attn/src/alibi.h
csrc/flash_attn/src/alibi.h
+45
-33
csrc/flash_attn/src/flash_bwd_kernel.h
csrc/flash_attn/src/flash_bwd_kernel.h
+4
-10
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+12
-38
No files found.
csrc/flash_attn/src/alibi.h
View file @
4ea866ca
...
...
@@ -13,50 +13,62 @@ using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
bool
Is_causal
,
typename
Engine
,
typename
Layout
>
inline
__device__
void
apply_alibi
(
Tensor
<
Engine
,
Layout
>
&
tensor
,
const
int
col_idx_offset_
,
const
int
max_seqlen_k
,
const
int
row_idx_offset
,
const
int
max_seqlen_q
,
const
int
warp_row_stride
,
const
float
alibi_slope
)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert
(
Layout
::
rank
==
2
,
"Only support 2D Tensor"
);
const
int
lane_id
=
threadIdx
.
x
%
32
;
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
%
4
)
*
2
;
if
constexpr
(
Is_causal
)
{
// Simpler, we add the same bias vector to all rows
#pragma unroll
for
(
int
nj
=
0
;
nj
<
size
<
1
,
1
>
(
tensor
);
++
nj
)
{
const
int
col_idx_base
=
col_idx_offset
+
nj
*
8
;
template
<
bool
Is_causal
>
struct
Alibi
{
const
float
alibi_slope
;
const
int
max_seqlen_k
,
max_seqlen_q
;
inline
__device__
Alibi
(
const
float
alibi_slope
,
const
int
max_seqlen_k
,
const
int
max_seqlen_q
)
:
alibi_slope
(
alibi_slope
)
,
max_seqlen_k
(
max_seqlen_k
)
,
max_seqlen_q
(
max_seqlen_q
)
{
};
template
<
typename
Engine
,
typename
Layout
>
inline
__device__
void
apply_alibi
(
Tensor
<
Engine
,
Layout
>
&
tensor
,
const
int
col_idx_offset_
,
const
int
row_idx_offset
,
const
int
warp_row_stride
)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert
(
Layout
::
rank
==
2
,
"Only support 2D Tensor"
);
const
int
lane_id
=
threadIdx
.
x
%
32
;
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
%
4
)
*
2
;
if
constexpr
(
Is_causal
)
{
// Simpler, we add the same bias vector to all rows
#pragma unroll
for
(
int
j
=
0
;
j
<
size
<
1
,
0
>
(
tensor
);
++
j
)
{
const
int
col_idx
=
col_idx_
ba
se
+
j
;
for
(
int
n
j
=
0
;
n
j
<
size
<
1
,
1
>
(
tensor
);
++
n
j
)
{
const
int
col_idx
_base
=
col_idx_
off
se
t
+
nj
*
8
;
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
<
0
>
(
tensor
);
++
mi
)
{
tensor
(
mi
,
make_coord
(
j
,
nj
))
+=
alibi_slope
*
col_idx
;
for
(
int
j
=
0
;
j
<
size
<
1
,
0
>
(
tensor
);
++
j
)
{
const
int
col_idx
=
col_idx_base
+
j
;
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
<
0
>
(
tensor
);
++
mi
)
{
tensor
(
mi
,
make_coord
(
j
,
nj
))
+=
alibi_slope
*
col_idx
;
}
}
}
}
}
else
{
// Bias depends on both row_idx and col_idx
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
<
0
,
1
>
(
tensor
);
++
mi
)
{
const
int
row_idx_base
=
row_idx_offset
+
mi
*
warp_row_stride
;
}
else
{
// Bias depends on both row_idx and col_idx
#pragma unroll
for
(
int
i
=
0
;
i
<
size
<
0
,
0
>
(
tensor
);
++
i
)
{
const
int
row_idx
=
row_idx_
ba
se
+
i
*
8
;
for
(
int
m
i
=
0
;
m
i
<
size
<
0
,
1
>
(
tensor
);
++
m
i
)
{
const
int
row_idx
_base
=
row_idx_
off
se
t
+
m
i
*
warp_row_stride
;
#pragma unroll
for
(
int
nj
=
0
;
nj
<
size
<
1
,
1
>
(
tensor
);
++
nj
)
{
const
int
col
_idx
_base
=
col
_idx_
off
se
t
+
nj
*
8
;
for
(
int
i
=
0
;
i
<
size
<
0
,
0
>
(
tensor
);
++
i
)
{
const
int
row
_idx
=
row
_idx_
ba
se
+
i
*
8
;
#pragma unroll
for
(
int
j
=
0
;
j
<
size
<
1
,
0
>
(
tensor
);
++
j
)
{
const
int
col_idx
=
col_idx_base
+
j
;
tensor
(
make_coord
(
i
,
mi
),
make_coord
(
j
,
nj
))
-=
alibi_slope
*
abs
(
row_idx
+
max_seqlen_k
-
max_seqlen_q
-
col_idx
);
for
(
int
nj
=
0
;
nj
<
size
<
1
,
1
>
(
tensor
);
++
nj
)
{
const
int
col_idx_base
=
col_idx_offset
+
nj
*
8
;
#pragma unroll
for
(
int
j
=
0
;
j
<
size
<
1
,
0
>
(
tensor
);
++
j
)
{
const
int
col_idx
=
col_idx_base
+
j
;
tensor
(
make_coord
(
i
,
mi
),
make_coord
(
j
,
nj
))
-=
alibi_slope
*
abs
(
row_idx
+
max_seqlen_k
-
max_seqlen_q
-
col_idx
);
}
}
}
}
}
}
}
};
}
// namespace flash
csrc/flash_attn/src/flash_bwd_kernel.h
View file @
4ea866ca
...
...
@@ -448,7 +448,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
clear
(
acc_dv
);
clear
(
acc_dk
);
float
alibi_slope
=
!
Has_alibi
?
0.0
f
:
reinterpret_cast
<
float
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
const
float
alibi_slope
=
!
Has_alibi
?
0.0
f
:
reinterpret_cast
<
float
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
flash
::
Alibi
<
Is_causal
>
alibi
(
alibi_slope
,
binfo
.
actual_seqlen_k
,
binfo
.
actual_seqlen_q
);
for
(;
m_block
>=
m_block_min
;
--
m_block
)
{
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma_sdp
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
// (MMA=4, MMA_N, MMA_N)
...
...
@@ -475,15 +476,8 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in
// if (cute::thread(32, 0)) { print(scores); }
if
(
Has_alibi
)
{
flash
::
apply_alibi
<
Is_causal
>
(
scores
,
n_block
*
kBlockN
+
(
tidx
/
32
/
AtomLayoutMS
)
*
MMA_N_SdP
*
16
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
get
<
0
>
(
taccScS_row
(
0
)),
binfo
.
actual_seqlen_q
,
AtomLayoutMS
*
16
,
alibi_slope
);
alibi
.
apply_alibi
(
scores
,
n_block
*
kBlockN
+
(
tidx
/
32
/
AtomLayoutMS
)
*
MMA_N_SdP
*
16
,
m_block
*
kBlockM
+
get
<
0
>
(
taccScS_row
(
0
)),
AtomLayoutMS
*
16
);
}
// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
...
...
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
4ea866ca
...
...
@@ -267,7 +267,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
clear
(
acc_o
);
float
alibi_slope
=
!
Has_alibi
?
0.0
f
:
reinterpret_cast
<
float
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
const
float
alibi_slope
=
!
Has_alibi
?
0.0
f
:
reinterpret_cast
<
float
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
flash
::
Alibi
<
Is_causal
>
alibi
(
alibi_slope
,
binfo
.
actual_seqlen_k
,
binfo
.
actual_seqlen_q
);
// For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
...
...
@@ -313,15 +314,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
// can produce Inf / NaN.
if
(
Has_alibi
)
{
flash
::
apply_alibi
<
Is_causal
>
(
scores
,
n_block
*
kBlockN
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
(
tidx
/
32
)
*
16
+
(
tidx
%
32
)
/
4
,
binfo
.
actual_seqlen_q
,
kNWarps
*
16
,
alibi_slope
);
alibi
.
apply_alibi
(
scores
,
n_block
*
kBlockN
,
m_block
*
kBlockM
+
(
tidx
/
32
)
*
16
+
(
tidx
%
32
)
/
4
,
kNWarps
*
16
);
}
if
(
!
Is_causal
&&
!
Is_local
)
{
...
...
@@ -428,15 +422,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
Tensor
scores
=
make_tensor
(
acc_s
.
data
(),
flash
::
convert_layout_acc_rowcol
(
acc_s
.
layout
()));
if
(
Has_alibi
)
{
flash
::
apply_alibi
<
Is_causal
>
(
scores
,
n_block
*
kBlockN
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
(
tidx
/
32
)
*
16
+
(
tidx
%
32
)
/
4
,
binfo
.
actual_seqlen_q
,
kNWarps
*
16
,
alibi_slope
);
alibi
.
apply_alibi
(
scores
,
n_block
*
kBlockN
,
m_block
*
kBlockM
+
(
tidx
/
32
)
*
16
+
(
tidx
%
32
)
/
4
,
kNWarps
*
16
);
}
if
(
Is_local
&&
n_block
*
kBlockN
<
(
m_block
+
1
)
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
+
params
.
window_size_right
)
{
...
...
@@ -875,7 +862,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
clear
(
acc_o
);
float
alibi_slope
=
!
Has_alibi
?
0.0
f
:
reinterpret_cast
<
float
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
const
float
alibi_slope
=
!
Has_alibi
?
0.0
f
:
reinterpret_cast
<
float
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
flash
::
Alibi
<
Is_causal
>
alibi
(
alibi_slope
,
binfo
.
actual_seqlen_k
,
binfo
.
actual_seqlen_q
);
// For performance reason, we separate out two kinds of iterations:
// those that need masking on S, and those that don't.
...
...
@@ -917,15 +905,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
Tensor
scores
=
make_tensor
(
acc_s
.
data
(),
flash
::
convert_layout_acc_rowcol
(
acc_s
.
layout
()));
if
(
Has_alibi
)
{
flash
::
apply_alibi
<
Is_causal
>
(
scores
,
n_block
*
kBlockN
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
(
tidx
/
32
)
*
16
+
(
tidx
%
32
)
/
4
,
binfo
.
actual_seqlen_q
,
kNWarps
*
16
,
alibi_slope
);
alibi
.
apply_alibi
(
scores
,
n_block
*
kBlockN
,
m_block
*
kBlockM
+
(
tidx
/
32
)
*
16
+
(
tidx
%
32
)
/
4
,
kNWarps
*
16
);
}
// if (cute::thread0()) { print(scores); }
...
...
@@ -1009,15 +990,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
Tensor
scores
=
make_tensor
(
acc_s
.
data
(),
flash
::
convert_layout_acc_rowcol
(
acc_s
.
layout
()));
if
(
Has_alibi
)
{
flash
::
apply_alibi
<
Is_causal
>
(
scores
,
n_block
*
kBlockN
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
(
tidx
/
32
)
*
16
+
(
tidx
%
32
)
/
4
,
binfo
.
actual_seqlen_q
,
kNWarps
*
16
,
alibi_slope
);
alibi
.
apply_alibi
(
scores
,
n_block
*
kBlockN
,
m_block
*
kBlockM
+
(
tidx
/
32
)
*
16
+
(
tidx
%
32
)
/
4
,
kNWarps
*
16
);
}
if
(
Is_local
&&
n_block
*
kBlockN
<
(
m_block
+
1
)
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
+
params
.
window_size_right
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment