gaoqiong / flash-attention · Commits

Commit 6777336a, authored Jan 14, 2024 by Tri Dao
Move masking to a separate file (mask.h)
Parent: 9448264d

Showing 7 changed files with 122 additions and 101 deletions (+122 -101)
csrc/flash_attn/src/dropout.h                    +4    -0
csrc/flash_attn/src/flash_bwd_kernel.h           +2    -1
csrc/flash_attn/src/flash_bwd_launch_template.h  +3    -1
csrc/flash_attn/src/flash_fwd_kernel.h           +2    -1
csrc/flash_attn/src/kernel_traits.h              +1    -1
csrc/flash_attn/src/mask.h                       +110  -0
csrc/flash_attn/src/softmax.h                    +0    -97
csrc/flash_attn/src/dropout.h

+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
 #pragma once
 
 #include "philox.cuh"
 ...
csrc/flash_attn/src/flash_bwd_kernel.h

 /***************************************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
  ******************************************************************************/
 
 #pragma once
 ...
@@ -14,6 +14,7 @@
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
+#include "mask.h"
 #include "dropout.h"
 #include "alibi.h"
 ...
csrc/flash_attn/src/flash_bwd_launch_template.h

-// Copyright (c) 2023, Tri Dao.
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
 
 #pragma once
 ...
csrc/flash_attn/src/flash_fwd_kernel.h

 /******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
  ******************************************************************************/
 
 #pragma once
 ...
@@ -14,6 +14,7 @@
 #include "kernel_traits.h"
 #include "utils.h"
 #include "softmax.h"
+#include "mask.h"
 #include "dropout.h"
 #include "alibi.h"
 ...
csrc/flash_attn/src/kernel_traits.h

 /******************************************************************************
- * Copyright (c) 2023, Tri Dao.
+ * Copyright (c) 2024, Tri Dao.
  ******************************************************************************/
 
 #pragma once
 ...
csrc/flash_attn/src/mask.h (new file, 0 → 100644)

/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

namespace flash {

using namespace cute;

template <typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                  const int col_idx_offset_ = 0) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
        const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
            const int col_idx = col_idx_base + j;
            if (col_idx >= max_seqlen_k) {
                // Without the "make_coord" we get wrong results
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                }
            }
        }
    }
}

template <bool HasWSLeft=true, typename Engine, typename Layout>
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                        const int max_seqlen_k, const int row_idx_offset,
                                        const int max_seqlen_q, const int warp_row_stride,
                                        const int window_size_left, const int window_size_right) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
            const int row_idx = row_idx_base + i * 8;
            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
            // if (cute::thread0()) {
            //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
            //     print(tensor(make_coord(i, mi), _));
            //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
            // }
        }
    }
}

template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                         const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride) {
    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                          max_seqlen_q, warp_row_stride, -1, 0);
}

template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 2, "Only support 2D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
        #pragma unroll
        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
                tensor(mi, ni) = -INFINITY;
            }
        }
        // if (cute::thread0()) {
        //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
        //     print(tensor(_, make_coord(j, ni)));
        //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
        // }
    }
}

}  // namespace flash
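For reference (not part of the commit): the window arithmetic in apply_mask_local is compact, so below is a minimal host-side C++ sketch of the same masking rule. The function and parameter names here are hypothetical; the two limit formulas mirror col_idx_limit_left / col_idx_limit_right above, and main() illustrates the comment in apply_mask_causal that causal masking is local masking with an unbounded left window and window_size_right = 0.

#include <algorithm>
#include <cstdio>

// Hypothetical reference predicate: true if score(row_idx, col_idx) would be
// set to -INFINITY by apply_mask_local. Rows of Q are aligned to the *end* of
// K, hence the max_seqlen_k - max_seqlen_q shift in both limits.
bool is_masked(int row_idx, int col_idx, int max_seqlen_q, int max_seqlen_k,
               int window_size_left, int window_size_right, bool has_ws_left) {
    const int limit_left  = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
    const int limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
    return col_idx >= limit_right || (has_ws_left && col_idx < limit_left);
}

int main() {
    // Causal case: has_ws_left = false (left limit ignored) and
    // window_size_right = 0, matching the
    // apply_mask_local</*HasWSLeft=*/false>(..., -1, 0) call in apply_mask_causal.
    const int seqlen = 4;
    for (int r = 0; r < seqlen; ++r) {
        for (int c = 0; c < seqlen; ++c) {
            std::printf("%c", is_masked(r, c, seqlen, seqlen, -1, 0, false) ? '.' : 'x');
        }
        std::printf("\n");  // prints a lower-triangular pattern of 'x'
    }
    return 0;
}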
csrc/flash_attn/src/softmax.h

@@ -115,101 +115,4 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
     }
 }
 
 [The 97 deleted lines are the definitions of apply_mask, apply_mask_local, apply_mask_causal, and apply_mask_causal_w_idx, identical to the code added to mask.h above.]
 
 }  // namespace flash
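One more reading aid (again not part of the commit): the column indexing in apply_mask — (lane_id % 4) * 2, plus nj * 8, plus j — matches, to our reading, the SM80 m16n8 MMA fragment layout, in which each lane in a quad of four holds two adjacent columns and each nj step advances eight columns. The host-side sketch below only prints which global columns each lane would visit; all names and the loop extents are hypothetical stand-ins.

#include <cstdio>

int main() {
    const int col_idx_offset_ = 0;  // block-level column offset (first K tile)
    const int n_nj = 2, n_j = 2;    // stand-ins for size<1, 1>(tensor) and size<1, 0>(tensor)
    for (int lane_id = 0; lane_id < 4; ++lane_id) {  // one quad of a warp
        const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
        std::printf("lane %d columns:", lane_id);
        for (int nj = 0; nj < n_nj; ++nj) {
            const int col_idx_base = col_idx_offset + nj * 8;
            for (int j = 0; j < n_j; ++j) {
                std::printf(" %2d", col_idx_base + j);  // global column this lane covers
            }
        }
        std::printf("\n");
    }
    // Output: lane 0 -> 0 1 8 9, lane 1 -> 2 3 10 11,
    //         lane 2 -> 4 5 12 13, lane 3 -> 6 7 14 15
    return 0;
}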