Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
5813dcc1
Commit
5813dcc1
authored
Jan 26, 2026
by
zhanghj2
Browse files
添加softmax
parent
0e1300f7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
606 additions
and
0 deletions
+606
-0
csrc/softmax.h
csrc/softmax.h
+606
-0
No files found.
csrc/softmax.h
0 → 100644
View file @
5813dcc1
// Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/softmax.h
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "utils.h"
namespace
flash
{
using
namespace
cute
;
////////////////////////////////////////////////////////////////////////////////////////////////////
/// Per-thread row reduction: folds every column of `tensor` into `summary`, one entry per row.
///
/// \tparam zero_init  When true, `summary(row)` is seeded from column 0 of `tensor`, discarding
///                    whatever it held before. When false, the value already in `summary(row)`
///                    is combined with column 0 through `op`.
/// \param tensor   Rank-2 tensor to reduce (rank is checked at compile time).
/// \param summary  Rank-1 tensor with as many entries as `tensor` has rows; updated in place.
/// \param op       Binary functor, invoked as `op(accumulated, element)`.
///
/// No cross-thread communication happens here; only values owned by the calling thread are read.
template <bool zero_init = true,
          typename Engine0, typename Layout0,
          typename Engine1, typename Layout1,
          typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor,
                                               Tensor<Engine1, Layout1> &summary,
                                               Operator &op) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
    #pragma unroll
    for (int row = 0; row < size<0>(tensor); ++row) {
        // Column 0 either seeds the accumulator or is folded into the carried-over value.
        if constexpr (zero_init) {
            summary(row) = tensor(row, 0);
        } else {
            summary(row) = op(summary(row), tensor(row, 0));
        }
        // Remaining columns are always folded in, left to right.
        #pragma unroll
        for (int col = 1; col < size<1>(tensor); ++col) {
            summary(row) = op(summary(row), tensor(row, col));
        }
    }
}
/// Element-wise cross-lane all-reduce: for every index i, `dst(i)` receives
/// `Allreduce<64>::run(src(i), op)`.
///
/// \param dst  Destination tensor; must have the same static size as `src`.
/// \param src  Source tensor (may alias `dst`).
/// \param op   Binary reduction functor forwarded to `Allreduce`.
///
/// NOTE(review): despite the "quad" in the name, the width handed to `Allreduce` is 64 —
/// presumably one full wavefront. Confirm against the `Allreduce` definition in utils.h.
template <typename Engine0, typename Layout0,
          typename Engine1, typename Layout1,
          typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst,
                                                Tensor<Engine1, Layout1> &src,
                                                Operator &op) {
    CUTE_STATIC_ASSERT_V(size(dst) == size(src));
    #pragma unroll
    for (int idx = 0; idx < size(dst); ++idx) {
        dst(idx) = Allreduce<64>::run(src(idx), op);
    }
}
/// Block-level all-reduce of `dst(0)` across 64-thread groups, staged through `smem_reduce`.
///
/// Steps (indices are element offsets into `smem_reduce`):
///   1. The first 16 threads of every 64-thread group store their `dst(0)` at `row * 4 + warp_id`,
///      where `row = threadIdx.x % 16` and `warp_id = threadIdx.x / 64`.
///   2. After a barrier, threads 0..15 combine the four partials of their row with `op` and
///      store the result at `64 + row`.
///   3. After a second barrier, every thread loads `smem_reduce(64 + row)` back into `dst(0)`.
///
/// Requirements that follow from the code:
///   * every thread of the block must reach this call (it executes `__syncthreads()` twice);
///   * only element 0 of `dst` is reduced and overwritten;
///   * with four groups, offsets 0..63 hold partials and 64..79 hold results, so the buffer
///     needs at least 80 elements.
/// NOTE(review): step 2 always reads exactly four partials per row, so this presumably expects
/// blockDim.x == 256 (four groups of 64). Confirm against the launch configuration.
///
/// \param dst          Per-thread value holder; `dst(0)` is both input and output.
/// \param smem_reduce  Scratch tensor visible to the whole block (presumably shared memory).
/// \param op           Binary reduction functor.
template <typename Engine0, typename Layout0,
          typename Engine1, typename Layout1,
          typename Operator>
__device__ __forceinline__ void warp_allreduce_(Tensor<Engine0, Layout0> &dst,
                                                Tensor<Engine1, Layout1> &smem_reduce,
                                                Operator &op) {
    const int tidx = threadIdx.x;
    const int row = tidx % 16;
    const int warp_id = tidx / 64;

    // Layout note from the original author (translated): two layouts are possible here —
    // contiguous writes with strided reads, or the other way round. The trade-off is unclear,
    // but the performance impact is small. The layout used keeps the four partials of a row
    // adjacent (row * 4 + warp_id).
    if ((tidx % 64) / 16 == 0) {
        smem_reduce(row * 4 + warp_id * 1) = dst(0);
    }
    __syncthreads();

    if (tidx < 16) {
        smem_reduce(row + 64) = op(op(smem_reduce(row * 4), smem_reduce(row * 4 + 1)),
                                   op(smem_reduce(row * 4 + 2), smem_reduce(row * 4 + 3)));
    }
    __syncthreads();

    dst(0) = smem_reduce(row + 64);
}
/// Block-level all-reduce of `dst(0)` that pairs 64-thread group w with group w + 4
/// (0-4, 1-5, 2-6, 3-7), staged through `smem_reduce`.
///
/// Row index: `row = threadIdx.x % 16 + (group % 4) * 16`, i.e. 0..63.
///   1. The first 16 threads of every group store `dst(0)` at `row * 2 + group / 4`.
///   2. After a barrier, those threads in groups 0..3 combine the two partials of their row
///      with `op` and store the result at `128 + row`.
///   3. After a second barrier, every thread reloads `dst(0)` from `128 + row`.
///
/// Every thread of the block must reach this call (two `__syncthreads()`); only `dst(0)` is
/// touched. With eight groups the buffer needs at least 192 elements (0..127 partials,
/// 128..191 results).
/// NOTE(review): presumably expects blockDim.x == 512 (eight groups of 64) — confirm.
template <typename Engine0, typename Layout0,
          typename Engine1, typename Layout1,
          typename Operator>
__device__ __forceinline__ void warp_allreduce_tp1(Tensor<Engine0, Layout0> &dst,
                                                   Tensor<Engine1, Layout1> &smem_reduce,
                                                   Operator &op) {
    const int tid = threadIdx.x;
    const int group = tid / 64;
    // Only the first 16 threads of each 64-thread group publish a value.
    const bool publishes = ((tid % 64) / 16) == 0;
    const int row = tid % 16 + (group % 4) * 16;

    if (publishes) {
        smem_reduce(row * 2 + (group / 4)) = dst(0);
    }
    __syncthreads();

    if (publishes && group < 4) {
        smem_reduce(128 + row) = op(smem_reduce(row * 2), smem_reduce(row * 2 + 1));
    }
    __syncthreads();

    dst(0) = smem_reduce(128 + row);
}
/// Block-level all-reduce of `dst(0)` that pairs 64-thread group w with group w + 2,
/// staged through `smem_reduce`.
///
/// Row index: `row = threadIdx.x % 16 + (group % 2) * 16`, i.e. 0..31.
///   1. The first 16 threads of every group store `dst(0)` at `row * 2 + group / 2`.
///   2. After a barrier, those threads in groups 0..1 combine the two partials of their row
///      with `op` and store the result at `64 + row`.
///   3. After a second barrier, every thread reloads `dst(0)` from `64 + row`.
///
/// Every thread of the block must reach this call (two `__syncthreads()`); only `dst(0)` is
/// touched. With four groups the buffer needs at least 96 elements (0..63 partials,
/// 64..95 results).
/// NOTE(review): presumably expects blockDim.x == 256 (four groups of 64) — confirm.
template <typename Engine0, typename Layout0,
          typename Engine1, typename Layout1,
          typename Operator>
__device__ __forceinline__ void warp_allreduce_tp4(Tensor<Engine0, Layout0> &dst,
                                                   Tensor<Engine1, Layout1> &smem_reduce,
                                                   Operator &op) {
    const int tid = threadIdx.x;
    const int group = tid / 64;
    // Only the first 16 threads of each 64-thread group publish a value.
    const bool publishes = ((tid % 64) / 16) == 0;
    const int row = tid % 16 + (group % 2) * 16;

    if (publishes) {
        smem_reduce(row * 2 + (group / 2)) = dst(0);
    }
    __syncthreads();

    if (publishes && group < 2) {
        smem_reduce(row + 64) = op(smem_reduce(row * 2), smem_reduce(row * 2 + 1));
    }
    __syncthreads();

    dst(0) = smem_reduce(row + 64);
}
/// Two-stage row reduction: first each thread reduces its own columns (`thread_reduce_`),
/// then the per-thread results are combined across lanes in place (`quad_allreduce_`).
///
/// \tparam zero_init  Forwarded to `thread_reduce_`: seed `summary` from the data (true) or
///                    fold the data into the existing `summary` values (false).
/// \param tensor   Rank-2 input tensor.
/// \param summary  Rank-1 result tensor, one entry per row of `tensor`.
/// \param op       Binary reduction functor used by both stages.
template <bool zero_init = true,
          typename Engine0, typename Layout0,
          typename Engine1, typename Layout1,
          typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const &tensor,
                                        Tensor<Engine1, Layout1> &summary,
                                        Operator &op) {
    // Stage 1: private reduction over the columns this thread holds.
    thread_reduce_<zero_init>(tensor, summary, op);
    // Stage 2: cross-lane all-reduce, source and destination are the same tensor.
    quad_allreduce_(summary, summary, op);
}
/// Row-wise maximum of `tensor`, written to `max`, using `MaxOp<float>` through `reduce_`
/// (per-thread reduction followed by the cross-lane all-reduce).
///
/// \tparam zero_init  True: `max` is overwritten. False: the existing `max` values take part
///                    in the maximum (running maximum across calls).
template <bool zero_init = true,
          typename Engine0, typename Layout0,
          typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const &tensor,
                                           Tensor<Engine1, Layout1> &max) {
    MaxOp<float> take_max;
    reduce_<zero_init>(tensor, max, take_max);
}
/// Row-wise sum of `tensor`, accumulated into `sum` with `SumOp<float>`.
///
/// Only the per-thread stage (`thread_reduce_`) runs here; no cross-lane reduction is
/// performed, so after this call `sum` holds each thread's partial sums.
///
/// \tparam zero_init  True: `sum` is overwritten. False: the new partial sums are added to
///                    the values already in `sum`.
template <bool zero_init = true,
          typename Engine0, typename Layout0,
          typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const &tensor,
                                           Tensor<Engine1, Layout1> &sum) {
    SumOp<float> add;
    thread_reduce_<zero_init>(tensor, sum, add);
}
// Apply the exp to all the elements.
//
// In place, for every (row, col):
//     tensor(row, col) = exp2(tensor(row, col) * scale - max_scaled(row))
// where max_scaled(row) is 0 when max(row) == -inf, and otherwise
//     max(row) * (Scale_max ? scale : log2(e)).
//
// tensor : rank-2 tensor, overwritten with the exponentials.
// max    : rank-1 tensor with one entry per row of `tensor`.
// scale  : multiplier applied to every element before the exponential.
template <bool Scale_max = true,
          typename Engine0, typename Layout0,
          typename Engine1, typename Layout1>
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor,
                                                 Tensor<Engine1, Layout1> const &max,
                                                 const float scale) {
    static_assert(Layout0::rank == 2, "Only support 2D Tensor");
    static_assert(Layout1::rank == 1, "Only support 1D Tensor");
    CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));

    // If we don't have float around M_LOG2E the multiplication is done in fp64.
    const float max_multiplier = Scale_max ? scale : float(M_LOG2E);

    #pragma unroll
    for (int row = 0; row < size<0>(tensor); ++row) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        float max_scaled = 0.f;
        if (max(row) != -INFINITY) {
            max_scaled = max(row) * max_multiplier;
        }
        #pragma unroll
        for (int col = 0; col < size<1>(tensor); ++col) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - max * log_2(e)).
            // This allows the compiler to fuse the multiply and subtract.
            // The portable form of this line is exp2f(tensor(row, col) * scale - max_scaled);
            // this port calls the AMD GPU builtin directly instead.
            tensor(row, col) = __builtin_amdgcn_exp2f(tensor(row, col) * scale - max_scaled);
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
kNRows
>
struct
Softmax
{
using
TensorT
=
decltype
(
make_tensor
<
float
>
(
Shape
<
Int
<
kNRows
>>
{}));
TensorT
row_max
,
row_sum
;
// Default constructor. The body is empty: it assigns nothing to row_max or row_sum.
// The rescale methods overwrite both when instantiated with Is_first == true.
__forceinline__ __device__ Softmax() {};
/// One online-softmax step for a new tile of scores.
///
/// Is_first == true:
///   row_max <- block-wide row maximum of the scores; the scores are replaced in place by
///   exp2(score * softmax_scale_log2 - row_max * softmax_scale_log2); row_sum <- per-thread
///   partial row sums of those exponentials.
/// Is_first == false:
///   row_max is updated as a running maximum; row_sum and every element of acc_o are rescaled
///   by exp2((old_max - new_max) * softmax_scale_log2); then the scores are exponentiated and
///   their per-thread partial sums are added to row_sum.
///
/// \tparam Is_first   Selects the first-tile path (no previous state is read).
/// \tparam Check_inf  When true, a new maximum of -inf is treated as 0 when computing the
///                    rescale factor, avoiding (-inf) - (-inf).
/// \tparam is_tp1     Selects `warp_allreduce_tp1` instead of `warp_allreduce_` for the
///                    block-level maximum.
/// \param acc_s                   Score accumulator; overwritten with the exponentials.
/// \param acc_o                   Output accumulator; rescaled in place when !Is_first.
/// \param sRow_max_reduce_buffer  Scratch tensor used by the block-level all-reduce.
/// \param softmax_scale_log2      Multiplier applied before exp2 (presumably the softmax
///                                scale times log2(e) — confirm with the caller).
///
/// Must be called by every thread of the block: the all-reduce helpers use __syncthreads().
/// Note that those helpers only reduce element 0 of row_max.
template <bool Is_first, bool Check_inf = false, bool is_tp1 = false,
          typename Tensor0, typename Tensor1, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o,
                                                  Tensor2 &sRow_max_reduce_buffer,
                                                  float softmax_scale_log2) {
    MaxOp<float> max_op;
    // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
    Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
    static_assert(decltype(size<0>(scores))::value == kNRows);

    if (Is_first) {
        flash::reduce_max</*zero_init=*/true>(scores, row_max);
        // No `template` keyword on these calls: they have no explicit template-argument-list,
        // and `flash` is a namespace, so the disambiguator is neither needed nor well-formed.
        if constexpr (is_tp1) {
            flash::warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
        } else {
            flash::warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
        }
        flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
        flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
    } else {
        // Keep the previous maximum so the already-accumulated state can be rescaled.
        Tensor scores_max_prev = make_fragment_like(row_max);
        cute::copy(row_max, scores_max_prev);
        flash::reduce_max</*zero_init=*/false>(scores, row_max);
        if constexpr (is_tp1) {
            flash::warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
        } else {
            flash::warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
        }
        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
        #pragma unroll
        for (int mi = 0; mi < size(row_max); ++mi) {
            float scores_max_cur = !Check_inf
                ? row_max(mi)
                : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
            // Portable equivalent: exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2).
            float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
            row_sum(mi) *= scores_scale;
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
                acc_o_rowcol(mi, ni) *= scores_scale;
            }
        }
        flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
        // We don't do the reduce across threads here since we don't need to use the row_sum.
        // We do that reduce at the end when we need to normalize the softmax.
        flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
    }
};
/// Online-softmax step for the FP8 path, where the output accumulator is passed as eight
/// separate `v4f` values instead of a tensor.
///
/// Is_first == true:
///   row_max <- block-wide row maximum; the scores are exponentiated in place; row_sum <-
///   per-thread partial sums. The c*_* accumulators are not touched.
/// Is_first == false:
///   row_max is updated as a running maximum; row_sum(0) and all four components of each
///   c*_* accumulator are multiplied by exp2((old_max - new_max) * softmax_scale_log2);
///   then the scores are exponentiated and their partial sums are added to row_sum.
///
/// \tparam Check_inf  When true, a new maximum of -inf is treated as 0 for the rescale factor.
/// \param acc_s                   Score accumulator; overwritten with the exponentials.
/// \param sRow_max_reduce_buffer  Scratch tensor used by `warp_allreduce_`.
/// \param softmax_scale_log2      Multiplier applied before exp2.
/// \param c0_0 ... c3_1           Output accumulators, rescaled in place when !Is_first.
///
/// The rescale path is written for a single row (static_assert kNRows == 1).
/// Must be called by every thread of the block (the all-reduce uses __syncthreads()).
template <bool Is_first, bool Check_inf = false, typename Tensor0, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_fp8(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer,
                                                      float softmax_scale_log2,
                                                      v4f &c0_0, v4f &c0_1, v4f &c1_0, v4f &c1_1,
                                                      v4f &c2_0, v4f &c2_1, v4f &c3_0, v4f &c3_1) {
    MaxOp<float> max_op;
    // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
    Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
    static_assert(decltype(size<0>(scores))::value == kNRows);

    if (Is_first) {
        flash::reduce_max</*zero_init=*/true>(scores, row_max);
        // No `template` keyword: the call has no explicit template-argument-list.
        flash::warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
        flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
        flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
    } else {
        Tensor scores_max_prev = make_fragment_like(row_max);
        cute::copy(row_max, scores_max_prev);
        flash::reduce_max</*zero_init=*/false>(scores, row_max);
        flash::warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);

        static_assert(1 == kNRows);
        {
            // Single row: the per-row loop of the tensor-based variant collapses to mi == 0.
            constexpr int mi = 0;
            float scores_max_cur = !Check_inf
                ? row_max(mi)
                : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
            // Portable equivalent: exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2).
            float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
            row_sum(mi) *= scores_scale;

            // Rescale the previously accumulated output, component by component.
            c0_0.x *= scores_scale; c0_0.y *= scores_scale; c0_0.z *= scores_scale; c0_0.w *= scores_scale;
            c0_1.x *= scores_scale; c0_1.y *= scores_scale; c0_1.z *= scores_scale; c0_1.w *= scores_scale;
            c1_0.x *= scores_scale; c1_0.y *= scores_scale; c1_0.z *= scores_scale; c1_0.w *= scores_scale;
            c1_1.x *= scores_scale; c1_1.y *= scores_scale; c1_1.z *= scores_scale; c1_1.w *= scores_scale;
            c2_0.x *= scores_scale; c2_0.y *= scores_scale; c2_0.z *= scores_scale; c2_0.w *= scores_scale;
            c2_1.x *= scores_scale; c2_1.y *= scores_scale; c2_1.z *= scores_scale; c2_1.w *= scores_scale;
            c3_0.x *= scores_scale; c3_0.y *= scores_scale; c3_0.z *= scores_scale; c3_0.w *= scores_scale;
            c3_1.x *= scores_scale; c3_1.y *= scores_scale; c3_1.z *= scores_scale; c3_1.w *= scores_scale;
        }
        flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
        // We don't do the reduce across threads here since we don't need to use the row_sum.
        // We do that reduce at the end when we need to normalize the softmax.
        flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
    }
};
/// Final softmax normalisation: completes the row sums, scales the output accumulator and
/// returns the log-sum-exp per row.
///
/// Steps:
///   1. row_sum is all-reduced across lanes (`quad_allreduce_`) and then across thread groups
///      through `sRow_sum_reduce_buffer` (`warp_allreduce_tp1` when is_tp1, otherwise
///      `warp_allreduce_`).
///   2. For every row, acc_o is multiplied by 1 / sum (times rp_dropout when Is_dropout).
///      If the sum is 0 or NaN the factor is 1 instead.
///   3. lse(row) = row_max(row) * softmax_scale + log(sum); for a 0 / NaN sum it is
///      -inf when Split, +inf otherwise.
///
/// \param acc_o                   Output accumulator, scaled in place.
/// \param sRow_sum_reduce_buffer  Scratch tensor for the block-level all-reduce.
/// \param softmax_scale           Factor applied to row_max in the LSE.
/// \param rp_dropout              Extra output factor, used only when Is_dropout.
/// \return Fragment shaped like row_sum holding the per-row LSE.
///
/// Must be called by every thread of the block (the all-reduce uses __syncthreads()).
template <bool Is_dropout = false, bool Split = false, bool is_tp1 = false,
          typename Tensor0, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, Tensor1 &sRow_sum_reduce_buffer,
                                                         float softmax_scale, float rp_dropout = 1.0f) {
    SumOp<float> sum_op;
    quad_allreduce_(row_sum, row_sum, sum_op);
    // No `template` keyword on these calls: they have no explicit template-argument-list.
    if constexpr (is_tp1) {
        flash::warp_allreduce_tp1(row_sum, sRow_sum_reduce_buffer, sum_op);
    } else {
        flash::warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
    }

    TensorT lse = make_fragment_like(row_sum);
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
    #pragma unroll
    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
        const float sum = row_sum(mi);
        // `sum != sum` is true only for NaN.
        const bool sum_unusable = (sum == 0.f || sum != sum);
        const float inv_sum = sum_unusable ? 1.f : 1.f / sum;
        lse(mi) = sum_unusable ? (Split ? -INFINITY : INFINITY)
                               : row_max(mi) * softmax_scale + __logf(sum);
        const float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
        #pragma unroll
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
            acc_o_rowcol(mi, ni) *= scale;
        }
    }
    return lse;
};
/// Online-softmax step for the prefill path. Same structure as `softmax_rescale_o`, always
/// using `warp_allreduce_` for the block-level maximum.
///
/// Is_first == true:
///   row_max <- block-wide row maximum; the scores are exponentiated in place; row_sum <-
///   per-thread partial sums.
/// Is_first == false:
///   row_max is updated as a running maximum; row_sum and acc_o are rescaled by
///   exp2((old_max - new_max) * softmax_scale_log2); then the scores are exponentiated and
///   their partial sums are added to row_sum.
///
/// \tparam Check_inf  Accepted for signature parity but never consulted: the -inf guard on
///                    the new maximum is applied unconditionally in this variant.
/// \param acc_s                   Score accumulator; overwritten with the exponentials.
/// \param acc_o                   Output accumulator; rescaled in place when !Is_first.
/// \param sRow_max_reduce_buffer  Scratch tensor used by `warp_allreduce_`.
/// \param softmax_scale_log2      Multiplier applied before exp2.
///
/// Must be called by every thread of the block (the all-reduce uses __syncthreads()).
template <bool Is_first, bool Check_inf = false,
          typename Tensor0, typename Tensor1, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_prefill(Tensor0 &acc_s, Tensor1 &acc_o,
                                                          Tensor2 &sRow_max_reduce_buffer,
                                                          float softmax_scale_log2) {
    MaxOp<float> max_op;
    // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
    Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
    static_assert(decltype(size<0>(scores))::value == kNRows);

    if (Is_first) {
        flash::reduce_max</*zero_init=*/true>(scores, row_max);
        // No `template` keyword: the call has no explicit template-argument-list.
        flash::warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
        flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
        flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
    } else {
        Tensor scores_max_prev = make_fragment_like(row_max);
        cute::copy(row_max, scores_max_prev);
        flash::reduce_max</*zero_init=*/false>(scores, row_max);
        flash::warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
        // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
        Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
        static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
        #pragma unroll
        for (int mi = 0; mi < size(row_max); ++mi) {
            // Unconditional -inf guard (the original expression was `!true ? ... : ...`).
            float scores_max_cur = row_max(mi) == -INFINITY ? 0.0f : row_max(mi);
            // Portable equivalent: exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2).
            float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
            row_sum(mi) *= scores_scale;
            #pragma unroll
            for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
                acc_o_rowcol(mi, ni) *= scores_scale;
            }
        }
        flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
        // We don't do the reduce across threads here since we don't need to use the row_sum.
        // We do that reduce at the end when we need to normalize the softmax.
        flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
    }
};
/// Final softmax normalisation for the prefill path.
///
/// Steps:
///   1. row_sum is all-reduced across lanes (`quad_allreduce_`) and across thread groups
///      through `sRow_sum_reduce_buffer` (`warp_allreduce_`).
///   2. For every row, acc_o is multiplied by 1 / sum (times rp_dropout when Is_dropout).
///      If the sum is 0 or NaN the factor is 1 instead.
///   3. lse(row) = row_max(row) * softmax_scale + log2(sum); for a 0 / NaN sum it is
///      -inf when Split, +inf otherwise.
///
/// NOTE(review): this variant uses `__log2f(sum)`, whereas the other normalize_* methods of
/// this struct use `__logf(sum)`. That is only consistent if callers pass a log2-domain
/// `softmax_scale` here — confirm with the call site before relying on the returned LSE.
///
/// \param acc_o                   Output accumulator, scaled in place.
/// \param sRow_sum_reduce_buffer  Scratch tensor for the block-level all-reduce.
/// \param softmax_scale           Factor applied to row_max in the LSE.
/// \param rp_dropout              Extra output factor, used only when Is_dropout.
/// \return Fragment shaped like row_sum holding the per-row LSE.
///
/// Must be called by every thread of the block (the all-reduce uses __syncthreads()).
template <bool Is_dropout = false, bool Split = false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_prefill(Tensor0 &acc_o, Tensor1 &sRow_sum_reduce_buffer,
                                                                 float softmax_scale, float rp_dropout = 1.0f) {
    SumOp<float> sum_op;
    quad_allreduce_(row_sum, row_sum, sum_op);
    // No `template` keyword: the call has no explicit template-argument-list.
    flash::warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);

    TensorT lse = make_fragment_like(row_sum);
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
    #pragma unroll
    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
        const float sum = row_sum(mi);
        // `sum != sum` is true only for NaN.
        const bool sum_unusable = (sum == 0.f || sum != sum);
        const float inv_sum = sum_unusable ? 1.f : 1.f / sum;
        lse(mi) = sum_unusable ? (Split ? -INFINITY : INFINITY)
                               : row_max(mi) * softmax_scale + __log2f(sum);
        const float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
        #pragma unroll
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
            acc_o_rowcol(mi, ni) *= scale;
        }
    }
    return lse;
};
/// Final softmax normalisation for the FP8 path with a tensor output accumulator.
///
/// Steps:
///   1. row_sum is all-reduced across lanes (`quad_allreduce_`) and across thread groups
///      through `sRow_sum_reduce_buffer` (`warp_allreduce_`).
///   2. For every row, acc_o is multiplied by descale_v / sum (times rp_dropout when
///      Is_dropout). If the sum is 0 or NaN the factor is 1 (descale_v is not applied).
///   3. lse(row) = row_max(row) * softmax_scale + log(sum); for a 0 / NaN sum it is
///      -inf when Split, +inf otherwise.
///
/// \param acc_o                   Output accumulator, scaled in place.
/// \param sRow_sum_reduce_buffer  Scratch tensor for the block-level all-reduce.
/// \param softmax_scale           Factor applied to row_max in the LSE.
/// \param descale_v               Numerator of the output scale (presumably the dequantisation
///                                factor of V — confirm with the caller).
/// \param rp_dropout              Extra output factor, used only when Is_dropout.
/// \return Fragment shaped like row_sum holding the per-row LSE.
///
/// Must be called by every thread of the block (the all-reduce uses __syncthreads()).
template <bool Is_dropout = false, bool Split = false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8(Tensor0 &acc_o, Tensor1 &sRow_sum_reduce_buffer,
                                                             float softmax_scale, float descale_v,
                                                             float rp_dropout = 1.0f) {
    SumOp<float> sum_op;
    quad_allreduce_(row_sum, row_sum, sum_op);
    // No `template` keyword: the call has no explicit template-argument-list.
    flash::warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);

    TensorT lse = make_fragment_like(row_sum);
    Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
    static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
    #pragma unroll
    for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
        const float sum = row_sum(mi);
        // `sum != sum` is true only for NaN.
        const bool sum_unusable = (sum == 0.f || sum != sum);
        const float inv_sum = sum_unusable ? 1.f : descale_v / sum;
        lse(mi) = sum_unusable ? (Split ? -INFINITY : INFINITY)
                               : row_max(mi) * softmax_scale + __logf(sum);
        const float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
        #pragma unroll
        for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) {
            acc_o_rowcol(mi, ni) *= scale;
        }
    }
    return lse;
};
/// Online-softmax step for the FP8 path where the output accumulator is an array of `v4f`.
///
/// Is_first == true:
///   row_max <- block-wide row maximum; the scores are exponentiated in place; row_sum <-
///   per-thread partial sums. `acco_f32` is not touched.
/// Is_first == false:
///   row_max is updated as a running maximum; row_sum(0) and all components of
///   acco_f32[0..15] are multiplied by exp2((old_max - new_max) * softmax_scale_log2); then
///   the scores are exponentiated and their partial sums are added to row_sum.
///
/// \tparam Check_inf  When true, a new maximum of -inf is treated as 0 for the rescale factor.
/// \tparam is_tp1     Selects `warp_allreduce_tp1` instead of `warp_allreduce_`.
/// \param acc_s                   Score accumulator; overwritten with the exponentials.
/// \param sRow_max_reduce_buffer  Scratch tensor used by the block-level all-reduce.
/// \param softmax_scale_log2      Multiplier applied before exp2.
/// \param acco_f32                Output accumulator; must point to at least 16 `v4f` values.
///
/// The rescale path is written for a single row (static_assert kNRows == 1).
/// Must be called by every thread of the block (the all-reduce uses __syncthreads()).
template <bool Is_first, bool Check_inf = false, bool is_tp1 = false,
          typename Tensor0, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_fp8_tp1(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer,
                                                          float softmax_scale_log2, v4f *acco_f32) {
    MaxOp<float> max_op;
    // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
    Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
    static_assert(decltype(size<0>(scores))::value == kNRows);

    if constexpr (Is_first) {
        flash::reduce_max</*zero_init=*/true>(scores, row_max);
        // No `template` keyword on these calls: they have no explicit template-argument-list.
        if constexpr (is_tp1) {
            flash::warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
        } else {
            flash::warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
        }
        flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
        flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
    } else {
        Tensor scores_max_prev = make_fragment_like(row_max);
        cute::copy(row_max, scores_max_prev);
        flash::reduce_max</*zero_init=*/false>(scores, row_max);
        if constexpr (is_tp1) {
            flash::warp_allreduce_tp1(row_max, sRow_max_reduce_buffer, max_op);
        } else {
            flash::warp_allreduce_(row_max, sRow_max_reduce_buffer, max_op);
        }

        static_assert(1 == kNRows);
        {
            // Single row: the per-row loop of the tensor-based variant collapses to mi == 0.
            constexpr int mi = 0;
            float scores_max_cur = !Check_inf
                ? row_max(mi)
                : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
            // Portable equivalent: exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2).
            float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
            row_sum(mi) *= scores_scale;

            // Rescale the previously accumulated output; fixed trip count, unrolled like the
            // other fixed-size loops in this file.
            #pragma unroll
            for (int i = 0; i < 16; i++) {
                acco_f32[i].x *= scores_scale;
                acco_f32[i].y *= scores_scale;
                acco_f32[i].z *= scores_scale;
                acco_f32[i].w *= scores_scale;
            }
        }
        flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
        // We don't do the reduce across threads here since we don't need to use the row_sum.
        // We do that reduce at the end when we need to normalize the softmax.
        flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
    }
};
/// Online-softmax step for the FP8 path using `warp_allreduce_tp4` for the block-level
/// maximum, with the output accumulator passed as an array of `v4f`.
///
/// Is_first == true:
///   row_max <- block-wide row maximum; the scores are exponentiated in place; row_sum <-
///   per-thread partial sums. `acco_f32` is not touched.
/// Is_first == false:
///   row_max is updated as a running maximum; row_sum(0) and all components of
///   acco_f32[0..15] are multiplied by exp2((old_max - new_max) * softmax_scale_log2); then
///   the scores are exponentiated and their partial sums are added to row_sum.
///
/// \tparam Check_inf  When true, a new maximum of -inf is treated as 0 for the rescale factor.
/// \param acc_s                   Score accumulator; overwritten with the exponentials.
/// \param sRow_max_reduce_buffer  Scratch tensor used by `warp_allreduce_tp4`.
/// \param softmax_scale_log2      Multiplier applied before exp2.
/// \param acco_f32                Output accumulator; must point to at least 16 `v4f` values.
///
/// The rescale path is written for a single row (static_assert kNRows == 1).
/// Must be called by every thread of the block (the all-reduce uses __syncthreads()).
template <bool Is_first, bool Check_inf = false, typename Tensor0, typename Tensor2>
__forceinline__ __device__ void softmax_rescale_o_fp8_tp4(Tensor0 &acc_s, Tensor2 &sRow_max_reduce_buffer,
                                                          float softmax_scale_log2, v4f *acco_f32) {
    MaxOp<float> max_op;
    // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
    Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
    static_assert(decltype(size<0>(scores))::value == kNRows);

    if constexpr (Is_first) {
        flash::reduce_max</*zero_init=*/true>(scores, row_max);
        // No `template` keyword: the call has no explicit template-argument-list.
        flash::warp_allreduce_tp4(row_max, sRow_max_reduce_buffer, max_op);
        flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
        flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
    } else {
        Tensor scores_max_prev = make_fragment_like(row_max);
        cute::copy(row_max, scores_max_prev);
        flash::reduce_max</*zero_init=*/false>(scores, row_max);
        flash::warp_allreduce_tp4(row_max, sRow_max_reduce_buffer, max_op);

        static_assert(1 == kNRows);
        {
            // Single row: the per-row loop of the tensor-based variant collapses to mi == 0.
            constexpr int mi = 0;
            float scores_max_cur = !Check_inf
                ? row_max(mi)
                : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
            // Portable equivalent: exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2).
            float scores_scale = __builtin_amdgcn_exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
            row_sum(mi) *= scores_scale;

            // Rescale the previously accumulated output; fixed trip count, unrolled like the
            // other fixed-size loops in this file.
            #pragma unroll
            for (int i = 0; i < 16; i++) {
                acco_f32[i].x *= scores_scale;
                acco_f32[i].y *= scores_scale;
                acco_f32[i].z *= scores_scale;
                acco_f32[i].w *= scores_scale;
            }
        }
        flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
        // We don't do the reduce across threads here since we don't need to use the row_sum.
        // We do that reduce at the end when we need to normalize the softmax.
        flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
    }
};
/// Final softmax normalisation for the FP8 path with a `v4f` array as output accumulator.
///
/// Steps:
///   1. row_sum is all-reduced across lanes (`quad_allreduce_`) and across thread groups
///      through `sRow_sum_reduce_buffer` (`warp_allreduce_tp1` when is_tp1, otherwise
///      `warp_allreduce_`).
///   2. acco_f[0..15] is multiplied by descale_v / sum (times rp_dropout when Is_dropout).
///      If the sum is 0 or NaN the factor is 1 (descale_v is not applied).
///   3. lse(0) = row_max(0) * softmax_scale + log(sum); for a 0 / NaN sum it is
///      -inf when Split, +inf otherwise.
///
/// Only row 0 is processed; other entries of the returned fragment are not written here.
///
/// \param acco_f                  Output accumulator; must point to at least 16 `v4f` values.
/// \param sRow_sum_reduce_buffer  Scratch tensor for the block-level all-reduce.
/// \param softmax_scale           Factor applied to row_max in the LSE.
/// \param descale_v               Numerator of the output scale.
/// \param rp_dropout              Extra output factor, used only when Is_dropout.
/// \return Fragment shaped like row_sum holding the LSE in element 0.
///
/// Must be called by every thread of the block (the all-reduce uses __syncthreads()).
template <bool Is_dropout = false, bool Split = false, bool is_tp1 = false, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8_tp1(v4f *acco_f, Tensor1 &sRow_sum_reduce_buffer,
                                                                 float softmax_scale, float descale_v,
                                                                 float rp_dropout = 1.0f) {
    SumOp<float> sum_op;
    quad_allreduce_(row_sum, row_sum, sum_op);
    // No `template` keyword on these calls: they have no explicit template-argument-list.
    if constexpr (is_tp1) {
        flash::warp_allreduce_tp1(row_sum, sRow_sum_reduce_buffer, sum_op);
    } else {
        flash::warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
    }

    TensorT lse = make_fragment_like(row_sum);
    {
        // The original loop ran for mi in [0, 1); it is written out for the single row.
        constexpr int mi = 0;
        const float sum = row_sum(mi);
        // `sum != sum` is true only for NaN.
        const bool sum_unusable = (sum == 0.f || sum != sum);
        const float inv_sum = sum_unusable ? 1.f : descale_v / sum;
        lse(mi) = sum_unusable ? (Split ? -INFINITY : INFINITY)
                               : row_max(mi) * softmax_scale + __logf(sum);
        const float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
        #pragma unroll
        for (int i = 0; i < 16; i++) {
            acco_f[i].x *= scale;
            acco_f[i].y *= scale;
            acco_f[i].z *= scale;
            acco_f[i].w *= scale;
        }
    }
    return lse;
};
/// Final softmax normalisation for the FP8 path using `warp_allreduce_tp4`, with a `v4f`
/// array as output accumulator.
///
/// Steps:
///   1. row_sum is all-reduced across lanes (`quad_allreduce_`) and across thread groups
///      through `sRow_sum_reduce_buffer` (`warp_allreduce_tp4`).
///   2. acco_f[0..15] is multiplied by descale_v / sum (times rp_dropout when Is_dropout).
///      If the sum is 0 or NaN the factor is 1 (descale_v is not applied).
///   3. lse(0) = row_max(0) * softmax_scale + log(sum); for a 0 / NaN sum it is
///      -inf when Split, +inf otherwise.
///
/// Only row 0 is processed; other entries of the returned fragment are not written here.
///
/// \tparam is_tp1  Not used by this variant; kept so the template parameter list matches
///                 `normalize_softmax_lse_fp8_tp1`.
/// \param acco_f                  Output accumulator; must point to at least 16 `v4f` values.
/// \param sRow_sum_reduce_buffer  Scratch tensor for the block-level all-reduce.
/// \param softmax_scale           Factor applied to row_max in the LSE.
/// \param descale_v               Numerator of the output scale.
/// \param rp_dropout              Extra output factor, used only when Is_dropout.
/// \return Fragment shaped like row_sum holding the LSE in element 0.
///
/// Must be called by every thread of the block (the all-reduce uses __syncthreads()).
template <bool Is_dropout = false, bool Split = false, bool is_tp1 = false, typename Tensor1>
__forceinline__ __device__ TensorT normalize_softmax_lse_fp8_tp4(v4f *acco_f, Tensor1 &sRow_sum_reduce_buffer,
                                                                 float softmax_scale, float descale_v,
                                                                 float rp_dropout = 1.0f) {
    SumOp<float> sum_op;
    quad_allreduce_(row_sum, row_sum, sum_op);
    // No `template` keyword: the call has no explicit template-argument-list.
    flash::warp_allreduce_tp4(row_sum, sRow_sum_reduce_buffer, sum_op);

    TensorT lse = make_fragment_like(row_sum);
    {
        // The original loop ran for mi in [0, 1); it is written out for the single row.
        constexpr int mi = 0;
        const float sum = row_sum(mi);
        // `sum != sum` is true only for NaN.
        const bool sum_unusable = (sum == 0.f || sum != sum);
        const float inv_sum = sum_unusable ? 1.f : descale_v / sum;
        lse(mi) = sum_unusable ? (Split ? -INFINITY : INFINITY)
                               : row_max(mi) * softmax_scale + __logf(sum);
        const float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
        #pragma unroll
        for (int i = 0; i < 16; i++) {
            acco_f[i].x *= scale;
            acco_f[i].y *= scale;
            acco_f[i].z *= scale;
            acco_f[i].w *= scale;
        }
    }
    return lse;
};
};
}
// namespace flash
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment