zhangdong1 / Block-Sparse-Attention · Commits

Commit 4f83cf8f, authored Oct 10, 2024 by Junxian
[release] v0.0.1

Changes: 106 · Showing 6 changed files with 1445 additions and 0 deletions (+1445, -0)
csrc/block_sparse_attn/src/philox.cuh         +165  -0
csrc/block_sparse_attn/src/softmax.h          +322  -0
csrc/block_sparse_attn/src/static_switch.h     +95  -0
csrc/block_sparse_attn/src/utils.h            +521  -0
csrc/cutlass                                    +1  -0
setup.py                                      +341  -0
csrc/block_sparse_attn/src/philox.cuh  (new file, 0 → 100644)
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
#pragma once
// Philox CUDA.
namespace flash {

struct ull2 {
    unsigned long long x;
    unsigned long long y;
};

inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
    uint2 *res;
    unsigned long long tmp;
    asm("mul.wide.u32 %0, %1, %2;\n\t"
        : "=l"(tmp)
        : "r"(a), "r"(b));
    res = (uint2*)(&tmp);
    return *res;
}

inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
    constexpr unsigned long kPhiloxSA = 0xD2511F53;
    constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
    uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
    uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
    uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
    return ret;
}

inline __device__ uint4 philox(unsigned long long seed,
                               unsigned long long subsequence,
                               unsigned long long offset) {
    constexpr unsigned long kPhilox10A = 0x9E3779B9;
    constexpr unsigned long kPhilox10B = 0xBB67AE85;
    uint2 key = reinterpret_cast<uint2&>(seed);
    uint4 counter;
    ull2 *tmp = reinterpret_cast<ull2*>(&counter);
    tmp->x = offset;
    tmp->y = subsequence;
    #pragma unroll
    for (int i = 0; i < 6; i++) {
        counter = philox_single_round(counter, key);
        key.x += (kPhilox10A);
        key.y += (kPhilox10B);
    }
    uint4 output = philox_single_round(counter, key);
    return output;
}

} // namespace flash
namespace {

class Philox {
public:
  __device__ inline Philox(unsigned long long seed,
                           unsigned long long subsequence,
                           unsigned long long offset)
      : STATE(0)
      , seed_(seed)
      , offset_(offset)
      , key(reinterpret_cast<const uint2&>(seed)) {
//key.x = (unsigned int)seed;
//key.y = (unsigned int)(seed >> 32);
//counter = make_uint4(0, 0, 0, 0);
//counter.z = (unsigned int)(subsequence);
//counter.w = (unsigned int)(subsequence >> 32);
//STATE = 0;
//incr_n(offset / 4);
// key = reinterpret_cast<const uint2&>(seed);
    ull2 *tmp = reinterpret_cast<ull2*>(&counter);
    tmp->x = offset / 4;
    tmp->y = subsequence;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
// }
  }

  __device__ inline uint4 operator()() {
// // if (STATE == 0) {
// uint4 counter_ = counter;
// uint2 key_ = key;
// // 7-round philox
// #pragma unroll
// for (int i = 0; i < 6; i++) {
// counter_ = flash::philox_single_round(counter_, key_);
// key_.x += (kPhilox10A);
// key_.y += (kPhilox10B);
// }
// // output = philox_single_round(counter_, key_);
// uint4 output = flash::philox_single_round(counter_, key_);
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
// // }
// incr();
// // }
// // return a float4 directly
// // unsigned long ret;
// // switch(STATE) {
// // case 0: ret = output.x; break;
// // case 1: ret = output.y; break;
// // case 2: ret = output.z; break;
// // case 3: ret = output.w; break;
// //}
// // STATE = (STATE + 1) % 4;
// return output;
    return flash::philox(seed_, offset_, offset_);
  }

private:
  unsigned long long offset_, seed_;
  struct ull2 {
      uint64_t x;
      uint64_t y;
  };
  uint4 counter;
  // uint4 output;
  const uint2 key;
  unsigned int STATE;

  __device__ inline void incr_n(unsigned long long n) {
    unsigned int nlo = (unsigned int)(n);
    unsigned int nhi = (unsigned int)(n >> 32);
    counter.x += nlo;
    if (counter.x < nlo)
      nhi++;
    counter.y += nhi;
    if (nhi <= counter.y)
      return;
    if (++counter.z)
      return;
    ++counter.w;
  }

  __device__ uint4 incr128(uint4 ctr) {
    uint4 res;
    asm("add.cc.u32 %0, %4, %8;\n\t"
        "addc.cc.u32 %1, %5, %9;\n\t"
        "addc.cc.u32 %2, %6, %10;\n\t"
        "addc.u32 %3, %7, %11;\n\t"
        : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
        : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
          "n"(1), "n"(0), "n"(0), "n"(0));
    return res;
  }

  __device__ inline void incr() {
    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
    //   printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
    // }
    counter = incr128(counter);
    // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
    //   printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
    // }
  }

  static const unsigned long kPhilox10A = 0x9E3779B9;
  static const unsigned long kPhilox10B = 0xBB67AE85;
  // static const unsigned long kPhiloxSA = 0xD2511F53;
  // static const unsigned long kPhiloxSB = 0xCD9E8D57;
};

} // namespace
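The generator above is counter-based: a given (seed, subsequence, offset) triple always maps to the same four 32-bit words, so dropout masks can be regenerated in the backward pass instead of being stored. The sketch below is a hypothetical smoke test (not part of this commit) that assumes philox.cuh is on the include path; it only exercises the deterministic mapping.

```cuda
// Hypothetical smoke test for the counter-based RNG in philox.cuh (illustration only).
#include <cstdio>
#include "philox.cuh"

__global__ void philox_probe(unsigned long long seed, unsigned long long subsequence,
                             unsigned long long offset) {
    // The same (seed, subsequence, offset) yields the same four words on every thread.
    uint4 r = flash::philox(seed, subsequence, offset);
    printf("thread %d: %u %u %u %u\n", threadIdx.x, r.x, r.y, r.z, r.w);
}

int main() {
    philox_probe<<<1, 2>>>(0x1234ULL, /*subsequence=*/0ULL, /*offset=*/0ULL);
    cudaDeviceSynchronize();   // both threads should print identical values
    return 0;
}
```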
csrc/block_sparse_attn/src/softmax.h  (new file, 0 → 100644)
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
/******************************************************************************
* Adapted by Junxian Guo from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
******************************************************************************/
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "philox.cuh"
#include "utils.h"
namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); mi++) {
        summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            summary(mi) = op(summary(mi), tensor(mi, ni));
        }
    }
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int i = 0; i < size(dst); i++){
        dst(i) = Allreduce<4>::run(src(i), op);
    }
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
    thread_reduce_<zero_init>(tensor, summary, op);
    quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
    MaxOp<float> max_op;
    reduce_<zero_init>(tensor, max, max_op);
}

template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
    SumOp<float> sum_op;
    reduce_(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't have float around M_LOG2E the multiplication is done in fp64.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
        }
    }
}
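The comment above relies on the identity exp(s·(x − m)) = 2^(x·s·log2(e) − m·s·log2(e)); once log2(e) is folded into `scale`, each element needs only one fused multiply-add feeding exp2. A host-side sanity check of the identity, independent of the kernel (example numbers are illustrative):

```cuda
// Host-side check of the exp -> exp2 rewrite used in scale_apply_exp2 (illustration only).
#include <cmath>
#include <cstdio>

int main() {
    const float softmax_scale = 0.125f;                       // e.g. 1/sqrt(64)
    const float scale_log2 = softmax_scale * (float)M_LOG2E;  // what the kernel receives as `scale`
    const float x = 3.7f, m = 5.2f;                           // one score and its row max
    float ref  = expf(softmax_scale * (x - m));
    float fast = exp2f(x * scale_log2 - m * scale_log2);      // one FFMA + exp2 on the GPU
    printf("%.7f vs %.7f\n", ref, fast);                      // agree to float precision
    return 0;
}
```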
// Apply the exp to all the elements.
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        MaxOp<float> max_op;
        max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
        #pragma unroll
        for (int ni = 1; ni < size<1>(tensor); ni++) {
            max(mi) = max_op(max(mi), tensor(mi, ni));
        }
        max(mi) = Allreduce<4>::run(max(mi), max_op);
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
        sum(mi) = 0;
        #pragma unroll
        for (int ni = 0; ni < size<1>(tensor); ++ni) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)) This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
            sum(mi) += tensor(mi, ni);
        }
        SumOp<float> sum_op;
        sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
    }
}
template <typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
                                  const int col_idx_offset_ = 0) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
        const int col_idx_base = col_idx_offset + nj * 8;
        #pragma unroll
        for (int j = 0; j < size<1, 0>(tensor); ++j) {
            const int col_idx = col_idx_base + j;
            if (col_idx >= max_seqlen_k) {
                // Without the "make_coord" we get wrong results
                #pragma unroll
                for (int mi = 0; mi < size<0>(tensor); ++mi) {
                    tensor(mi, make_coord(j, nj)) = -INFINITY;
                }
            }
        }
    }
}
template <bool HasWSLeft=true, typename Engine, typename Layout>
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                        const int max_seqlen_k, const int row_idx_offset,
                                        const int max_seqlen_q, const int warp_row_stride,
                                        const int window_size_left, const int window_size_right) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
            const int row_idx = row_idx_base + i * 8;
            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
            // if (cute::thread0()) {
            //     printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
            //     print(tensor(make_coord(i, mi), _));
            //     // print(tensor(_, j + nj * size<1, 0>(tensor)));
            // }
        }
    }
}
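For intuition: after shifting the query row into key coordinates (the max_seqlen_k - max_seqlen_q term), only columns inside [limit_left, limit_right) survive and everything else is pushed to -INFINITY, which becomes exactly zero after the exp2. A small host-side illustration of the two limits, with example sizes chosen for the sketch:

```cuda
// Host-side illustration of the sliding-window limits computed in apply_mask_local (example numbers).
#include <algorithm>
#include <cstdio>

int main() {
    const int max_seqlen_q = 8, max_seqlen_k = 16;          // query block shorter than key block
    const int window_size_left = 4, window_size_right = 0;  // causal-style local window
    for (int row_idx = 0; row_idx < max_seqlen_q; ++row_idx) {
        int left  = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
        int right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
        // Columns in [left, right) keep their scores; the kernel writes -INFINITY elsewhere.
        printf("row %d keeps cols [%d, %d)\n", row_idx, left, right);
    }
    return 0;
}
```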
template <bool HasWSLeft=true, typename Engine, typename Layout>
inline __device__ void apply_mask_streaming(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                            const int max_seqlen_k, const int row_idx_offset,
                                            const int max_seqlen_q, const int warp_row_stride,
                                            const int local_size, const int sink_size) {
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const int lane_id = threadIdx.x % 32;
    const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
    #pragma unroll
    for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
        const int row_idx_base = row_idx_offset + mi * warp_row_stride;
        #pragma unroll
        for (int i = 0; i < size<0, 0>(tensor); ++i) {
            const int row_idx = row_idx_base + i * 8;
            const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - (local_size - 1));
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q);
            #pragma unroll
            for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
                const int col_idx_base = col_idx_offset + nj * 8;
                #pragma unroll
                for (int j = 0; j < size<1, 0>(tensor); ++j) {
                    const int col_idx = col_idx_base + j;
                    if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left && col_idx >= sink_size)) {
                        tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
                    }
                }
            }
        }
    }
}
template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
                                         const int max_seqlen_k, const int row_idx_offset,
                                         const int max_seqlen_q, const int warp_row_stride) {
    // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
    apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
                                          max_seqlen_q, warp_row_stride, -1, 0);
}
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
    Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
    const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
{
    // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 2, "Only support 2D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
    CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
    #pragma unroll
    for (int mi = 0; mi < size<0>(tensor); ++mi) {
        const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
        #pragma unroll
        for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
            if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
                tensor(mi, ni) = -INFINITY;
            }
        }
        // if (cute::thread0()) {
        //     printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
        //     print(tensor(_, make_coord(j, ni)));
        //     // print(tensor(_, j + ni * size<1, 0>(tensor)));
        // }
    }
}
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
                                     unsigned long long seed, unsigned long long offset,
                                     int block_row_start, int block_col_start,
                                     int block_row_stride) {
    // tensor has shape (8, MMA_M, MMA_N / 2)
    using T = typename Engine::value_type;
    auto encode_dropout = [](bool keep, T val) {
        return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
    };
    static_assert(decltype(size<2>(tensor))::value % 2 == 0);
    const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
    const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
    // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
    #pragma unroll
    for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
        uint2 rowcol = make_uint2(block_row_start, block_col_start);
        #pragma unroll
        for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
            // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
            uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
            // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
            uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
            // Special implementation for 16-bit types: we duplicate the threshold to the
            // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
            // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
            // and the high 16 bits will be either 0xffff or 0x0000, depending on whether
            // the random value is less than the threshold.
            // We then do a bit-wise AND between the mask and the original value (in 32-bit).
            // We're exploiting the fact that floating point comparison is equivalent to integer
            // comparison, since we're comparing unsigned integers whose top 8-bits are zero.
            if (!encode_dropout_in_sign_bit
                && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
                uint16_t rnd_16[16];
                #pragma unroll
                for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
                uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
                #pragma unroll
                for (int j = 0; j < 2; j++) {
                    Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                    // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                    #pragma unroll
                    for (int i = 0; i < 4; i++) {
                        uint32_t mask;
                        asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
                        tensor_uint32(i) &= mask;
                    }
                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                }
            } else {
                #pragma unroll
                for (int j = 0; j < 2; j++) {
                    #pragma unroll
                    for (int i = 0; i < 8; i++) {
                        tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
                    }
                    Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
                    // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
                }
            }
            // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
            // //     printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
            // // }
        }
    }
}

}  // namespace flash
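Note that apply_dropout keeps an element exactly when its random byte satisfies rnd_8[i] <= p_dropout_in_uint8_t, so the 8-bit threshold encodes the keep probability at roughly 1/256 resolution (how the caller derives that byte is outside this file). A host-side sketch of the same comparison rule, with a hypothetical threshold value:

```cuda
// Host-side sketch of the byte-threshold rule used by apply_dropout (illustration only).
#include <cstdint>
#include <cstdio>
#include <random>

int main() {
    const uint8_t p_dropout_in_uint8_t = 229;   // hypothetical threshold; keeps bytes 0..229
    std::mt19937 gen(42);
    std::uniform_int_distribution<int> byte(0, 255);
    int kept = 0, total = 1 << 20;
    for (int i = 0; i < total; ++i) {
        kept += (uint8_t)byte(gen) <= p_dropout_in_uint8_t;  // same predicate as the kernel
    }
    printf("empirical keep fraction: %.4f\n", (double)kept / total);  // ~230/256 ~= 0.898
    return 0;
}
```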
csrc/block_sparse_attn/src/static_switch.h  (new file, 0 → 100644)
// Inspired by
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
using elem_type = cutlass::half_t; \
return __VA_ARGS__(); \
} else { \
using elem_type = cutlass::bfloat16_t; \
return __VA_ARGS__(); \
} \
}()
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 64) { \
constexpr static int kHeadDim = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 96) { \
constexpr static int kHeadDim = 96; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 128) { \
constexpr static int kHeadDim = 128; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 160) { \
constexpr static int kHeadDim = 160; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 192) { \
constexpr static int kHeadDim = 192; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 224) { \
constexpr static int kHeadDim = 224; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 256) { \
constexpr static int kHeadDim = 256; \
return __VA_ARGS__(); \
} \
}()
#define FWD_BLOCK_HEADDIM_SWITCH(HEADDIM, ...)\
[&] { \
if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 64) { \
constexpr static int kHeadDim = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 128) { \
constexpr static int kHeadDim = 128; \
return __VA_ARGS__(); \
} \
}()
#define BWD_BLOCK_HEADDIM_SWITCH(HEADDIM, ...)\
[&] { \
if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 64) { \
constexpr static int kHeadDim = 64; \
return __VA_ARGS__(); \
} else if (HEADDIM <= 128) { \
constexpr static int kHeadDim = 128; \
return __VA_ARGS__(); \
} \
}()
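These macros turn a runtime value into a constexpr inside an immediately-invoked lambda, so one template specialization is instantiated per case and the switch can still return a value. A minimal usage sketch; run_kernel<> is a hypothetical stand-in for a real launcher and is not part of this commit:

```cuda
// Minimal usage sketch for the switch macros above (run_kernel<> is hypothetical).
#include <cstdio>
#include "static_switch.h"

template <int kHeadDim, bool Is_causal>
void run_kernel() {
    // In the real code this would instantiate and launch a kernel template.
    std::printf("instantiated kHeadDim=%d, Is_causal=%d\n", kHeadDim, (int)Is_causal);
}

void dispatch(int head_dim, bool is_causal) {
    FWD_BLOCK_HEADDIM_SWITCH(head_dim, [&] {
        // kHeadDim is now a constexpr int (32, 64, or 128) chosen from the runtime head_dim.
        BOOL_SWITCH(is_causal, Is_causal, [&] { run_kernel<kHeadDim, Is_causal>(); });
    });
}

int main() { dispatch(64, true); return 0; }
```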
csrc/block_sparse_attn/src/utils.h  (new file, 0 → 100644)
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
#include <cute/algorithm/copy.hpp>
#include <cute/algorithm/gemm.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ uint32_t relu2(const uint32_t x);

template<>
inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
    asm volatile( \
        "{\n" \
        "\t .reg .f16x2 sela;\n" \
        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
        "\t and.b32 %0, sela, %1;\n" \
        "}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
    return res;
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
    asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
    return res;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<typename T>
inline __device__ uint32_t convert_relu2(const float2 x);

template<>
inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
    asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}

template<>
inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
    uint32_t res;
    const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
    const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
    asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
    return res;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
    __device__ inline T operator()(T const &x, T const &y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ inline T operator()(T const &x, T const &y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};
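Allreduce<4> is what the softmax reductions in softmax.h rely on: the four threads that co-own one accumulator row exchange partial values with XOR-butterfly shuffles (offsets 2, then 1) until all four hold the reduced result. A standalone kernel sketch of the same pattern, without the cute/cutlass machinery:

```cuda
// Standalone sketch of a 4-lane XOR-butterfly allreduce (same pattern as Allreduce<4>::run).
#include <cstdio>

__global__ void quad_allreduce_demo() {
    float x = (float)(threadIdx.x % 4 + 1);          // lanes in each group of 4 hold 1, 2, 3, 4
    #pragma unroll
    for (int offset = 2; offset >= 1; offset /= 2) { // offsets 2, then 1 (THREADS/2 down to 1)
        x += __shfl_xor_sync(0xffffffffu, x, offset);
    }
    if (threadIdx.x < 8) printf("lane %d: %.0f\n", threadIdx.x, x);  // every lane prints 10
}

int main() {
    quad_allreduce_demo<<<1, 32>>>();
    cudaDeviceSynchronize();
    return 0;
}
```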
////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
         typename Tensor2, typename Tensor3, typename Tensor4,
         typename TiledMma, typename TiledCopyA, typename TiledCopyB,
         typename ThrCopyA, typename ThrCopyB>
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const &tCsA,
                            Tensor4 const &tCsB, TiledMma tiled_mma,
                            TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
                            ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
    Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
    CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view));          // M
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));          // N
    if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
    if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
            if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}
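The loop above overlaps data movement with math: while the MMA consumes K-slice i from registers, the copies for slice i+1 are already issued from shared memory, so the copy latency hides behind the tensor-core work. A schematic of the same double-buffering idea; the load/compute stand-ins below are hypothetical placeholders, not real kernel calls:

```cuda
// Schematic of the software pipelining used in gemm() above (stand-in functions are hypothetical).
#include <cstdio>

void load_slice(int i)    { std::printf("load slice %d\n", i); }   // models cute::copy of K-slice i
void compute_slice(int i) { std::printf("mma  slice %d\n", i); }   // models cute::gemm on K-slice i

int main() {
    const int num_slices = 4;       // plays the role of size<2>(tCrA)
    load_slice(0);                  // prologue: fetch the first K-slice
    for (int i = 0; i < num_slices; ++i) {
        if (i < num_slices - 1) load_slice(i + 1);  // issue the next copy before the math...
        compute_slice(i);                           // ...so its latency overlaps the MMA
    }
    return 0;
}
```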
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
         typename TiledMma, typename TiledCopy, typename ThrCopy>
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const &tCsB,
                                      TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
                                      ThrCopy smem_thr_copy_B) {
    CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc));                     // MMA_N
    CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB));                    // MMA_K
    Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
    CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view));          // N
    cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
    #pragma unroll
    for (int i = 0; i < size<2>(tCrA); ++i) {
        if (i < size<2>(tCrA) - 1) {
            cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
        }
        cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
template<typename Layout>
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
    static_assert(decltype(size<0>(acc_layout))::value == 4);
    static_assert(decltype(rank(acc_layout))::value == 3);
    auto l = logical_divide(acc_layout, Shape<_2>{});  // ((2, 2), MMA_M, MMA_N)
    // TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
    // "int_tuple.hpp(74): error: conversion to inaccessible base class"
    // return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
    return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)),
                       make_layout(get<0>(get<0>(l)), get<2>(l)));
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
template<typename MMA_traits, typename Layout>
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
    using X = Underscore;
    static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
    static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
    constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
    static_assert(mma_shape_K == 8 || mma_shape_K == 16);
    constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
    auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{});  // ((2, MMA_M), (2, (2, MMA_N / 2)))
    // TD [2023-08-13]: Same error as above on Cutlass 3.2
    // return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
    //                    get<0, 1>(l),
    //                    get<1, 1, 1>(l));
    return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
                       get<1>(get<0>(l)),
                       get<1>(get<1>(get<1>(l))));
};
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    constexpr int numel = decltype(size(tensor))::value;
    cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
    // HACK: this requires tensor to be "contiguous"
    auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
    return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Engine, typename Layout>
inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
    using value_t = typename Engine::value_type;
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_uint32 = recast<uint32_t>(tensor);
    #pragma unroll
    for (int i = 0; i < size(tensor_uint32); ++i) {
        tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
template <typename To_type, typename Engine, typename Layout>
inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
    using From_type = typename Engine::value_type;
    static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
    static_assert(std::is_same_v<float, From_type>);
    constexpr int numel = decltype(size(tensor))::value;
    static_assert(numel % 2 == 0);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // HACK: this requires tensor to be "contiguous"
    Tensor tensor_float2 = recast<float2>(tensor);
    Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
    #pragma unroll
    for (int i = 0; i < size(out_uint32); ++i) {
        out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
    }
    Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
#else
    Tensor out = flash::convert_type<To_type>(tensor);
    flash::relu_(out);
#endif
    return out;
}
////////////////////////////////////////////////////////////////////////////////////////////////////

// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
    asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
          typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
                            Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                            Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    // There's no case where !Clear_OOB_K && Clear_OOB_MN
    static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        } else if (Clear_OOB_MN) {
            cute::clear(D(_, m, _));
        }
    }
// TD [2023-04-13]: Strange that the code below can cause race condition.
// I think it's because the copies are under an if statement.
// if (Is_even_K) {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(tiled_copy, S(_, m, _), D(_, m, _));
// } else if (Clear_OOB_MN) {
// clear(D(_, m, _));
// }
// }
// } else { // It's slightly faster in this case if iterate over K first
// #pragma unroll
// for (int k = 0; k < size<2>(S); ++k) {
// if (predicate_K(k)) {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
// copy(tiled_copy, S(_, m, k), D(_, m, k));
// } else if (Clear_OOB_MN) {
// clear(D(_, m, k));
// }
// }
// } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN
// if (Clear_OOB_MN || Is_even_MN) {
// clear(D(_, _, k));
// } else {
// #pragma unroll
// for (int m = 0; m < size<1>(S); ++m) {
// if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
// clear(D(_, m, k));
// }
// }
// }
// }
// }
// }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
                                      Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
                                      Tensor<Engine3, Layout3> const &predicate_K,
                                      const int max_MN=0, const int min_MN=0) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || predicate_K(k)) {
                    cute::copy(S(_, m, k), D(_, m, k));
                }
            }
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_K=true, bool Clear_OOB_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
                                               Tensor<Engine1, Layout1> &D,
                                               Tensor<Engine2, Layout2> const &Cos,
                                               Tensor<Engine2, Layout2> const &Sin,
                                               Tensor<Engine3, Layout3> const &identity_MN,
                                               const int max_MN, const int min_MN,
                                               const int dim, const int rotary_dim) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                   // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                   // MMA_K
    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));                 // MMA_K
    static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
    Tensor rCos = make_fragment_like(Cos);
    Tensor rSin = make_fragment_like(Sin);
    Tensor rS = make_fragment_like(S);
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
                    cute::copy(S(_, m, k), rS(_, m, k));
                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
                        cute::copy(Cos(_, m, k), rCos(_, m, k));
                        cute::copy(Sin(_, m, k), rSin(_, m, k));
                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
                        #pragma unroll
                        for (int i = 0; i < size<0>(rS) / 2; ++i) {
                            float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
                            float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
                            S_fp32(2 * i) = real;
                            S_fp32(2 * i + 1) = imag;
                        }
                        // Idk but I need to copy for the convert_type to work
                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
                        cute::copy(S_fp32, S_fp32_copy);
                        using T = typename Engine0::value_type;
                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
                        cute::copy(S_og_type, rS(_, m, k));
                    }
                    cute::copy(rS(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        }
    }
}
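The inner loop applies the standard rotary update to interleaved (even, odd) pairs: (x_{2i}, x_{2i+1}) becomes (x_{2i}·cos − x_{2i+1}·sin, x_{2i}·sin + x_{2i+1}·cos), i.e. a rotation of each pair in the complex plane. A host-side illustration of that per-pair update, with made-up angles:

```cuda
// Host-side illustration of the interleaved rotary update performed above (example values).
#include <cmath>
#include <cstdio>

int main() {
    float x[4] = {1.0f, 0.0f, 0.5f, -0.5f};   // two interleaved (even, odd) pairs
    const float theta[2] = {0.25f, 1.0f};     // per-pair angles (from the cos/sin tables)
    for (int i = 0; i < 2; ++i) {
        float c = cosf(theta[i]), s = sinf(theta[i]);
        float real = x[2 * i] * c - x[2 * i + 1] * s;   // matches the S_fp32(2*i)   update
        float imag = x[2 * i] * s + x[2 * i + 1] * c;   // matches the S_fp32(2*i+1) update
        x[2 * i] = real;
        x[2 * i + 1] = imag;
    }
    printf("%f %f %f %f\n", x[0], x[1], x[2], x[3]);
    return 0;
}
```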
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_even_K=true, bool Clear_OOB_K=true,
          typename Engine0, typename Layout0, typename Engine1, typename Layout1,
          typename Engine2, typename Layout2, typename Engine3, typename Layout3>
inline __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
                                              Tensor<Engine1, Layout1> &D,
                                              Tensor<Engine2, Layout2> const &Cos,
                                              Tensor<Engine2, Layout2> const &Sin,
                                              Tensor<Engine3, Layout3> const &identity_MN,
                                              const int max_MN, const int min_MN,
                                              const int dim, const int rotary_dim) {
    CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
    CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D));                     // MMA
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D));                     // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D));                     // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos));                   // MMA_K
    CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin));                   // MMA_M
    CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin));                   // MMA_K
    CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos));                   // MMA
    CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
    static_assert(decltype(size<0>(Cos))::value % 2 == 0);  // Since we do fast conversion from fp16/bf16 to fp32
    Tensor rCos = make_fragment_like(Cos);
    Tensor rSin = make_fragment_like(Sin);
    Tensor rS = make_fragment_like(S);
    Tensor rS_other = make_fragment_like(rS(_, 0, 0));
    #pragma unroll
    for (int m = 0; m < size<1>(S); ++m) {
        if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
            #pragma unroll
            for (int k = 0; k < size<2>(S); ++k) {
                if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
                    cute::copy(S(_, m, k), rS(_, m, k));
                    if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
                        const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
                        Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
                        cute::copy(gS_other, rS_other);
                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
                        Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
                        Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
                        cute::copy(gCos, rCos(_, m, k));
                        cute::copy(gSin, rSin(_, m, k));
                        // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
                        Tensor S_fp32 = convert_type<float>(rS(_, m, k));
                        Tensor S_other_fp32 = convert_type<float>(rS_other);
                        Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
                        Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
                        #pragma unroll
                        for (int i = 0; i < size<0>(rS); ++i) {
                            S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
                        }
                        // Idk but I need to copy for the convert_type to work
                        Tensor S_fp32_copy = make_fragment_like(S_fp32);
                        cute::copy(S_fp32, S_fp32_copy);
                        using T = typename Engine0::value_type;
                        Tensor S_og_type = convert_type<T>(S_fp32_copy);
                        cute::copy(S_og_type, rS(_, m, k));
                        // if (cute::thread0()) { print_tensor(rS(_, m, k)); }
                    }
                    cute::copy(rS(_, m, k), D(_, m, k));
                } else if (Clear_OOB_K) {
                    cute::clear(D(_, m, k));
                }
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace flash
csrc/cutlass @ a75b4ac4  (new submodule)
Subproject commit a75b4ac483166189a45290783cb0a18af5ff0ea5
setup.py  (new file, 0 → 100644)
# Copyright (c) 2023, Tri Dao.
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/setup.py
import sys
import warnings
import os
import re
import ast
from pathlib import Path
from packaging.version import parse, Version
import platform

from setuptools import setup, find_packages
import subprocess

import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
)


with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

PACKAGE_NAME = "block_sparse_attn"

BASE_WHEEL_URL = (
    "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
)

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"


def get_platform():
    """
    Returns the platform name as used in wheel filenames.
    """
    if sys.platform.startswith("linux"):
        return "linux_x86_64"
    elif sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_x86_64"
    elif sys.platform == "win32":
        return "win_amd64"
    else:
        raise ValueError("Unsupported platform: {}".format(sys.platform))


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])
    return raw_output, bare_metal_version


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def append_nvcc_threads(nvcc_extra_args):
    return nvcc_extra_args + ["--threads", "4"]


cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])

if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_cuda_home_none("block_sparse_attn")
    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                "FlashAttention is only supported on CUDA 11.6 and above. "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )
    # cc_flag.append("-gencode")
    # cc_flag.append("arch=compute_75,code=sm_75")
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_80,code=sm_80")
    if CUDA_HOME is not None:
        if bare_metal_version >= Version("11.8"):
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_90,code=sm_90")

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    ext_modules.append(
        CUDAExtension(
            name="block_sparse_attn_cuda",
            sources=[
                "csrc/block_sparse_attn/flash_api.cpp",
                "csrc/block_sparse_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
                # add by JXGuo
                "csrc/block_sparse_attn/src/flash_fwd_block_hdim32_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_block_hdim32_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_block_hdim32_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_block_hdim32_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_block_hdim64_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_block_hdim64_bf16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_block_hdim128_fp16_sm80.cu",
                "csrc/block_sparse_attn/src/flash_bwd_block_hdim128_bf16_sm80.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + generator_flag,
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_HALF2_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        # "--ptxas-options=-v",
                        # "--ptxas-options=-O2",
                        "-lineinfo",
                        # "-G",
                        # "-g",
                    ]
                    + generator_flag
                    + cc_flag
                ),
            },
            include_dirs=[
                Path(this_dir) / "csrc" / "block_sparse_attn",
                Path(this_dir) / "csrc" / "block_sparse_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )


def get_package_version():
    with open(Path(this_dir) / "block_sparse_attn" / "__init__.py", "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    else:
        return str(public_version)


def get_wheel_url():
    # Determine the version numbers that will be used to determine the correct wheel
    # We're using the CUDA version used to build torch, not the one currently installed
    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
    torch_cuda_version = parse(torch.version.cuda)
    torch_version_raw = parse(torch.__version__)
    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
    # to save CI time. Minor versions should be compatible.
    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2")
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    flash_version = get_package_version()
    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
    find an existing wheel (which is currently the case for all flash attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuits the standard full build pipeline.
    """

    def run(self):
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            os.rename(wheel_filename, wheel_path)
        except urllib.error.HTTPError:
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()


setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "benchmarks",
            "block_sparse_attn.egg-info",
        )
    ),
    author="Junxian Guo",
    author_email="junxian@mit.edu",
    description="Block Sparse Attention",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/mit-han-lab/Block-Sparse-Attention",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension}
    if ext_modules
    else {
        "bdist_wheel": CachedWheelsCommand,
    },
    python_requires=">=3.7",
    install_requires=[
        "torch",
        "einops",
        "packaging",
        "ninja",
    ],
)