Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
05087332
Commit
05087332
authored
Jun 02, 2022
by
Tri Dao
Browse files
Remove softmax fp16 max
parent
14dc326e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
125 deletions
+5
-125
csrc/flash_attn/src/fmha/softmax.h
csrc/flash_attn/src/fmha/softmax.h
+5
-120
csrc/flash_attn/src/fmha/utils.h
csrc/flash_attn/src/fmha/utils.h
+0
-5
No files found.
csrc/flash_attn/src/fmha/softmax.h
View file @
05087332
...
@@ -58,12 +58,6 @@ inline __device__ float apply_exp_(float x, float max) {
...
@@ -58,12 +58,6 @@ inline __device__ float apply_exp_(float x, float max) {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
__half2
apply_exp_
(
__half2
x
,
__half2
max
)
{
return
h2exp
(
__hsub2
(
x
,
max
));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
apply_exp2_
(
float
x
,
float
max
)
{
inline
__device__
float
apply_exp2_
(
float
x
,
float
max
)
{
return
exp2f
(
x
-
max
);
return
exp2f
(
x
-
max
);
// With fast-math, this produces the same PTX instruction as the assembly below
// With fast-math, this produces the same PTX instruction as the assembly below
...
@@ -75,17 +69,9 @@ inline __device__ float apply_exp2_(float x, float max) {
...
@@ -75,17 +69,9 @@ inline __device__ float apply_exp2_(float x, float max) {
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
__half2
apply_exp2_
(
__half2
x
,
__half2
max
)
{
template
<
int
COLS
>
struct
ReadType
{};
return
h2exp2
(
__hsub2
(
x
,
max
));
template
<
>
struct
ReadType
<
4
>
{
using
T
=
float
;};
}
template
<
>
struct
ReadType
<
8
>
{
using
T
=
float2
;};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
COLS
,
bool
half
>
struct
ReadType
{};
template
<
>
struct
ReadType
<
4
,
false
>
{
using
T
=
float
;};
template
<
>
struct
ReadType
<
8
,
false
>
{
using
T
=
float2
;};
template
<
>
struct
ReadType
<
4
,
true
>
{
using
T
=
__half2
;};
template
<
>
struct
ReadType
<
8
,
true
>
{
using
T
=
float2
;};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -118,8 +104,7 @@ struct Smem_tile_reduce {
...
@@ -118,8 +104,7 @@ struct Smem_tile_reduce {
static
constexpr
int
LOOPS
=
Kernel_traits
::
Gmem_tile_o
::
LOOPS
;
static
constexpr
int
LOOPS
=
Kernel_traits
::
Gmem_tile_o
::
LOOPS
;
static_assert
(
LOOPS
==
1
);
static_assert
(
LOOPS
==
1
);
using
read_t
=
typename
ReadType
<
COLS
,
/*half=*/
false
>::
T
;
using
read_t
=
typename
ReadType
<
COLS
>::
T
;
using
read_half_t
=
typename
ReadType
<
COLS
,
/*half=*/
true
>::
T
;
__device__
inline
Smem_tile_reduce
(
float
*
smem_
,
const
int
tidx
)
{
__device__
inline
Smem_tile_reduce
(
float
*
smem_
,
const
int
tidx
)
{
...
@@ -152,17 +137,6 @@ struct Smem_tile_reduce {
...
@@ -152,17 +137,6 @@ struct Smem_tile_reduce {
}
}
}
}
__device__
inline
void
store
(
__half2
(
&
frag
)[
MMAS_M
])
{
__half2
*
smem_write_half_
=
reinterpret_cast
<
__half2
*>
(
smem_write_
);
if
(
qid_
==
0
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
mi
++
)
{
int
offset
=
mi
*
16
*
WARPS_N
;
smem_write_half_
[
offset
+
0
*
8
*
WARPS_N
]
=
frag
[
mi
];
}
}
}
__device__
inline
void
load
(
read_t
(
&
frag
)[
2
*
MMAS_M
])
{
__device__
inline
void
load
(
read_t
(
&
frag
)[
2
*
MMAS_M
])
{
#pragma unroll
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
mi
++
)
{
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
mi
++
)
{
...
@@ -172,15 +146,6 @@ struct Smem_tile_reduce {
...
@@ -172,15 +146,6 @@ struct Smem_tile_reduce {
}
}
}
}
__device__
inline
void
load
(
read_half_t
(
&
frag
)[
MMAS_M
])
{
read_half_t
*
smem_read_half_
=
reinterpret_cast
<
read_half_t
*>
(
smem_read_
);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
mi
++
)
{
int
offset
=
mi
*
16
*
4
;
frag
[
mi
]
=
smem_read_half_
[
offset
+
0
*
8
*
4
];
}
}
__device__
inline
void
load_row
(
read_t
(
&
frag
)[
MMAS_M
],
int
row
)
{
__device__
inline
void
load_row
(
read_t
(
&
frag
)[
MMAS_M
],
int
row
)
{
#pragma unroll
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
mi
++
)
{
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
mi
++
)
{
...
@@ -304,29 +269,6 @@ struct Softmax_base {
...
@@ -304,29 +269,6 @@ struct Softmax_base {
}
}
}
}
// Apply the exp to all the elements.
inline
__device__
void
apply_exp
(
const
__half2
(
&
max
)[
MMAS_M
])
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
constexpr
float
kLog2e
=
M_LOG2E
;
const
float2
max_f
=
__half22float2
(
max
[
mi
]);
const
float
max0_log2e
=
max_f
.
x
*
kLog2e
,
max1_log2e
=
max_f
.
y
*
kLog2e
;
#pragma unroll
for
(
int
ni
=
0
;
ni
<
MMAS_N
*
4
;
++
ni
)
{
float2
elt
=
__half22float2
(
elt_half_
[
mi
][
ni
]);
elt_
[
mi
*
2
+
0
][
ni
]
=
apply_exp2_
(
elt
.
x
*
kLog2e
,
max0_log2e
);
elt_
[
mi
*
2
+
1
][
ni
]
=
apply_exp2_
(
elt
.
y
*
kLog2e
,
max1_log2e
);
// __half2 out = apply_exp_(elt_half_[mi][ni], max[mi]);
// float2 outf = __half22float2(out);
// elt_[mi * 2 + 0][ni] = outf.x;
// elt_[mi * 2 + 1][ni] = outf.y;
}
}
}
// Apply the exp to all the elements.
// Apply the exp to all the elements.
template
<
bool
max_in_base2
=
false
>
template
<
bool
max_in_base2
=
false
>
inline
__device__
void
apply_exp_col
(
const
float
(
&
max
)[
MMAS_N
*
4
])
{
inline
__device__
void
apply_exp_col
(
const
float
(
&
max
)[
MMAS_N
*
4
])
{
...
@@ -527,7 +469,6 @@ struct Softmax_base {
...
@@ -527,7 +469,6 @@ struct Softmax_base {
int
tidx_
;
int
tidx_
;
// The elements.
// The elements.
float
elt_
[
MMAS_M
*
2
][
MMAS_N
*
4
];
float
elt_
[
MMAS_M
*
2
][
MMAS_N
*
4
];
__half2
elt_half_
[
MMAS_M
][
MMAS_N
*
4
];
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
...
@@ -638,34 +579,6 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
...
@@ -638,34 +579,6 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
}
}
}
}
// Scale FP32 fragments
template
<
typename
Mask
>
inline
__device__
void
unpack_noscale_half_and_apply_mask
(
const
Accumulator
(
&
acc
)[
MMAS_M
][
MMAS_N
],
const
Mask
&
mask
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
++
mi
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
MMAS_N
;
++
ni
)
{
float
tmp
[
2
][
4
];
// 1st row - 4 elements per row.
tmp
[
0
][
0
]
=
mask
.
is_valid
(
mi
,
ni
,
0
,
0
)
?
acc
[
mi
][
ni
].
elt
(
0
)
:
-
INFINITY
;
tmp
[
0
][
1
]
=
mask
.
is_valid
(
mi
,
ni
,
0
,
1
)
?
acc
[
mi
][
ni
].
elt
(
1
)
:
-
INFINITY
;
tmp
[
0
][
2
]
=
mask
.
is_valid
(
mi
,
ni
,
0
,
2
)
?
acc
[
mi
][
ni
].
elt
(
4
)
:
-
INFINITY
;
tmp
[
0
][
3
]
=
mask
.
is_valid
(
mi
,
ni
,
0
,
3
)
?
acc
[
mi
][
ni
].
elt
(
5
)
:
-
INFINITY
;
// 2nd row - 4 elements per row.
tmp
[
1
][
0
]
=
mask
.
is_valid
(
mi
,
ni
,
1
,
0
)
?
acc
[
mi
][
ni
].
elt
(
2
)
:
-
INFINITY
;
tmp
[
1
][
1
]
=
mask
.
is_valid
(
mi
,
ni
,
1
,
1
)
?
acc
[
mi
][
ni
].
elt
(
3
)
:
-
INFINITY
;
tmp
[
1
][
2
]
=
mask
.
is_valid
(
mi
,
ni
,
1
,
2
)
?
acc
[
mi
][
ni
].
elt
(
6
)
:
-
INFINITY
;
tmp
[
1
][
3
]
=
mask
.
is_valid
(
mi
,
ni
,
1
,
3
)
?
acc
[
mi
][
ni
].
elt
(
7
)
:
-
INFINITY
;
this
->
elt_half_
[
mi
][
4
*
ni
+
0
]
=
__floats2half2_rn
(
tmp
[
0
][
0
],
tmp
[
1
][
0
]);
this
->
elt_half_
[
mi
][
4
*
ni
+
1
]
=
__floats2half2_rn
(
tmp
[
0
][
1
],
tmp
[
1
][
1
]);
this
->
elt_half_
[
mi
][
4
*
ni
+
2
]
=
__floats2half2_rn
(
tmp
[
0
][
2
],
tmp
[
1
][
2
]);
this
->
elt_half_
[
mi
][
4
*
ni
+
3
]
=
__floats2half2_rn
(
tmp
[
0
][
3
],
tmp
[
1
][
3
]);
}
}
}
template
<
bool
zero_init
=
true
,
typename
Operator
>
template
<
bool
zero_init
=
true
,
typename
Operator
>
__device__
inline
void
thread_reduce_
(
float
(
&
frag
)[
2
*
MMAS_M
],
Operator
&
op
)
{
__device__
inline
void
thread_reduce_
(
float
(
&
frag
)[
2
*
MMAS_M
],
Operator
&
op
)
{
#pragma unroll
#pragma unroll
...
@@ -678,18 +591,6 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
...
@@ -678,18 +591,6 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
}
}
}
}
template
<
typename
Operator
>
__device__
inline
void
thread_reduce_
(
__half2
(
&
frag
)[
MMAS_M
],
Operator
&
op
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
MMAS_M
;
mi
++
)
{
frag
[
mi
]
=
this
->
elt_half_
[
mi
][
0
];
#pragma unroll
for
(
int
ni
=
1
;
ni
<
4
*
MMAS_N
;
ni
++
)
{
frag
[
mi
]
=
op
(
frag
[
mi
],
this
->
elt_half_
[
mi
][
ni
]);
}
}
}
template
<
bool
zero_init
=
true
,
typename
Operator
>
template
<
bool
zero_init
=
true
,
typename
Operator
>
__device__
inline
void
reduce_
(
float
(
&
frag
)[
2
*
MMAS_M
],
Operator
&
op
,
Smem_tile_red
&
smem_red
)
{
__device__
inline
void
reduce_
(
float
(
&
frag
)[
2
*
MMAS_M
],
Operator
&
op
,
Smem_tile_red
&
smem_red
)
{
thread_reduce_
<
zero_init
>
(
frag
,
op
);
thread_reduce_
<
zero_init
>
(
frag
,
op
);
...
@@ -701,28 +602,12 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
...
@@ -701,28 +602,12 @@ struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
quad_allreduce
(
frag
,
tmp
,
op
);
quad_allreduce
(
frag
,
tmp
,
op
);
}
}
template
<
typename
Operator
>
__device__
inline
void
reduce_
(
__half2
(
&
frag
)[
MMAS_M
],
Operator
&
op
,
Smem_tile_red
&
smem_red
)
{
thread_reduce_
(
frag
,
op
);
quad_reduce
(
frag
,
frag
,
op
);
smem_red
.
store
(
frag
);
__syncthreads
();
typename
Smem_tile_red
::
read_half_t
tmp
[
MMAS_M
];
smem_red
.
load
(
tmp
);
quad_allreduce
(
frag
,
tmp
,
op
);
}
template
<
bool
zero_init
=
true
>
template
<
bool
zero_init
=
true
>
__device__
inline
void
reduce_max
(
float
(
&
frag
)[
2
*
MMAS_M
]){
__device__
inline
void
reduce_max
(
float
(
&
frag
)[
2
*
MMAS_M
]){
MaxOp
<
float
>
max
;
MaxOp
<
float
>
max
;
reduce_
<
zero_init
>
(
frag
,
max
,
smem_max_
);
reduce_
<
zero_init
>
(
frag
,
max
,
smem_max_
);
}
}
__device__
inline
void
reduce_max
(
__half2
(
&
frag
)[
MMAS_M
]){
MaxOp
<
__half2
>
max
;
reduce_
(
frag
,
max
,
smem_max_
);
}
__device__
inline
void
reduce_sum
(
float
(
&
frag
)[
2
*
MMAS_M
]){
__device__
inline
void
reduce_sum
(
float
(
&
frag
)[
2
*
MMAS_M
]){
SumOp
<
float
>
sum
;
SumOp
<
float
>
sum
;
reduce_
(
frag
,
sum
,
smem_sum_
);
reduce_
(
frag
,
sum
,
smem_sum_
);
...
...
csrc/flash_attn/src/fmha/utils.h
View file @
05087332
...
@@ -1024,11 +1024,6 @@ struct MaxOp<float> {
...
@@ -1024,11 +1024,6 @@ struct MaxOp<float> {
__device__
inline
float
operator
()(
float
const
&
x
,
float
const
&
y
)
{
return
max
(
x
,
y
);
}
__device__
inline
float
operator
()(
float
const
&
x
,
float
const
&
y
)
{
return
max
(
x
,
y
);
}
};
};
template
<
>
struct
MaxOp
<
__half2
>
{
__device__
inline
__half2
operator
()(
__half2
const
&
x
,
__half2
const
&
y
)
{
return
__hmax2
(
x
,
y
);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
>
template
<
typename
T
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment