gaoqiong / composable_kernel_ROCM · Commits

Commit 1d784873, authored May 07, 2024 by Jun Liu
Merge branch 'develop' into amd-develop
Parents: d25889b1, 851c3ed1

Changes: 35 · Showing 15 changed files with 656 additions and 61 deletions (+656 / -61)
include/ck_tile/ops/fmha/block/block_position_encoding.hpp  (+189 / -0)
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp  (+72 / -13)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp  (+19 / -0)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp  (+1 / -1)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp  (+40 / -11)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp  (+44 / -14)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp  (+13 / -9)
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp  (+40 / -11)
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp  (+3 / -2)
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp  (+6 / -0)
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp  (+6 / -0)
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp  (+6 / -0)
test/CMakeLists.txt  (+1 / -0)
test/position_embedding/CMakeLists.txt  (+1 / -0)
test/position_embedding/position_embedding.cpp  (+215 / -0)
include/ck_tile/ops/fmha/block/block_position_encoding.hpp  (new file, mode 100644)  (+189 / -0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"

#include <cmath>
#include <vector>

namespace ck_tile {

enum struct PositionEncodingEnum
{
    NO    = 0,
    ALIBI = 1,
};

/*
VERTICAL:
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5
[0] 1 2 3 4 5

TOP_LEFT:
[0] 1 2 3 4 5
 1 [0] 1 2 3 4
 2  1 [0] 1 2 3
 3  2  1 [0] 1 2

FROM_BOTTOM_RIGHT:
 2  1 [0] 1 2 3
 3  2  1 [0] 1 2
 4  3  2  1 [0] 1
 5  4  3  2  1 [0]
*/
enum struct AlibiMode
{
    VERTICAL          = 0,
    FROM_TOP_LEFT     = 1, // keep sync with mask enum
    FROM_BOTTOM_RIGHT = 2,
};

template <typename DataType, bool RowMajor = true>
struct Alibi
{
    // RowMajor here means if pixel within the same thread are along the row, or col
    // this may impact the performance of update(), while the result are the same.
    // e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false
    CK_TILE_HOST_DEVICE Alibi(DataType slope_,
                              index_t y_total_,
                              index_t x_total_,
                              AlibiMode mode_ = AlibiMode::VERTICAL)
    {
        slope = mode_ == AlibiMode::VERTICAL ? slope_ : -slope_;

        shift_left_up = [&]() {
            if(RowMajor)
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
            }
            else
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
            }
        }();

        shift_right_down = [&]() {
            if(RowMajor)
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(x_total_ - y_total_, 0) : 0;
            }
            else
            {
                return mode_ == AlibiMode::FROM_BOTTOM_RIGHT ? max(y_total_ - x_total_, 0) : 0;
            }
        }();

        mode = mode_;
    }

    CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
    {
        if constexpr(RowMajor)
        {
            // at least 3 instructions per row
            index_t current_zero_point =
                mode == AlibiMode::VERTICAL ? shift_right_down : row_idx + shift_right_down;
            // for every threads, most of the pixels are along the row, below operation should be
            // the main hot spot.
            auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
                                                       bit_cast<uint32_t>(col_idx + shift_left_up),
                                                       0));
            pixel += slope * position;
        }
        else
        {
            // at least 3 instructions per col;
            index_t current_zero_point = mode == AlibiMode::VERTICAL
                                             ? row_idx + col_idx + shift_right_down
                                             : col_idx + shift_right_down;
            // for every threads, most of the pixels are along the col, below operation should be
            // the main hot spot.
            auto position = type_convert<DataType>(sad(bit_cast<uint32_t>(current_zero_point),
                                                       bit_cast<uint32_t>(row_idx + shift_left_up),
                                                       0));
            pixel += slope * position;
        }
    }

    DataType slope;           // float?
    index_t shift_left_up;    // always positive
    index_t shift_right_down; // always positive
    AlibiMode mode;
};

template <typename DataType>
struct EmptyPositionEncoding
{
    CK_TILE_HOST_DEVICE void update(DataType& /*pixel*/, index_t /*row_idx*/, index_t /*col_idx*/)
    {
    }
};

//
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask
// local is left_size >=0 or right_size >=0
template <typename DataType, bool RowMajor = true>
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
                                                 index_t window_left_size,
                                                 index_t window_right_size,
                                                 index_t y_total,
                                                 index_t x_total,
                                                 GenericAttentionMaskEnum mask_enum)
{
    // assume mask_enum will never be NO_MASK, since if we do not have mask, it's
    // totally OK to use constexpr
    bool is_causal = window_left_size < 0 && window_right_size == 0;
    AlibiMode alibi_mode =
        is_causal ? AlibiMode::VERTICAL
                  : static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
    return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
}

// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
// Do we need a device version?
template <typename DataType>
CK_TILE_HOST std::vector<DataType> get_alibi_slopes(ck_tile::index_t nheads)
{
    auto get_slopes_power_of_2 = [](ck_tile::index_t n) {
        float start = std::powf(
            static_cast<float>(2),
            -std::powf(static_cast<float>(2), -static_cast<float>((integer_log2_floor(n) - 3))));
        std::vector<DataType> rtn;
        for(auto i = 0; i < n; i++)
        {
            rtn.push_back(static_cast<DataType>(start * std::powf(start, i)));
        }
        return rtn;
    };

    if(is_power_of_two_integer(nheads))
    {
        // power of 2 calculation
        return get_slopes_power_of_2(nheads);
    }
    else
    {
        ck_tile::index_t closest_power_of_2 = 1 << integer_log2_floor(nheads);
        auto v0 = get_slopes_power_of_2(closest_power_of_2);
        auto v1 = get_slopes_power_of_2(closest_power_of_2 * 2);
        auto v1_sliced = [&](auto vec, ck_tile::index_t rem) {
            std::vector<DataType> sliced;
            for(ck_tile::index_t i = 0; i < static_cast<ck_tile::index_t>(vec.size()); i++)
            {
                if(i % 2 == 0)
                    sliced.push_back(vec[i]);
            }
            std::vector<DataType> sliced_2(sliced.begin(), sliced.begin() + rem);
            return sliced_2;
        }(v1, nheads - closest_power_of_2);
        v0.insert(v0.end(), v1_sliced.begin(), v1_sliced.end());
        return v0;
    }
}

} // namespace ck_tile
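A quick host-side illustration of the ALIBI recipe this new header implements: for a power-of-two head count n, head i (0-based) gets slope 2^(-8(i+1)/n), and update() then adds slope * distance-from-the-row's-zero-point to each score (the slope is negated in the constructor for the top-left / bottom-right modes, so the contribution acts as a penalty). The sketch below re-derives both pieces in plain C++ outside ck_tile; the helper names alibi_slopes_pow2 and apply_alibi_top_left are illustrative, not library API.

// Standalone sketch of the ALIBI math above (hypothetical helpers, not ck_tile API).
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Slopes for a power-of-two head count: head i gets 2^(-8*(i+1)/nheads).
std::vector<float> alibi_slopes_pow2(int nheads)
{
    std::vector<float> slopes;
    const float start =
        std::pow(2.0f, -std::pow(2.0f, -(std::log2(static_cast<float>(nheads)) - 3.0f)));
    for(int i = 0; i < nheads; ++i)
        slopes.push_back(start * std::pow(start, static_cast<float>(i))); // start^(i+1)
    return slopes;
}

// FROM_TOP_LEFT flavor: the zero point of row r sits on the diagonal (column r),
// so the penalty grows linearly with |c - r| (the absolute difference that sad() computes on GPU).
float apply_alibi_top_left(float score, float slope, int r, int c)
{
    return score - slope * static_cast<float>(std::abs(c - r));
}

int main()
{
    const auto s = alibi_slopes_pow2(8); // 1/2, 1/4, ..., 1/256
    std::printf("head 0 slope = %g, score(row=0, col=3) penalty = %g\n", s[0], s[0] * 3.0f);
    std::printf("biased score  = %g\n", apply_alibi_top_left(1.0f, s[0], 0, 3));
    return 0;
}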
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp  (+72 / -13)

@@ -5,6 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include <string>
 #include <type_traits>

@@ -33,6 +34,7 @@ struct FmhaFwdKernel
     using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
     using LSEDataType  = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
     using ODataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
+    using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
     using VLayout      = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;

@@ -41,7 +43,7 @@ struct FmhaFwdKernel
     static constexpr bool kPadSeqLenK       = FmhaPipeline::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ      = FmhaPipeline::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
-    static constexpr bool kHasBias          = FmhaPipeline::kHasBias;
+    static constexpr auto BiasEnum          = FmhaPipeline::BiasEnum;
     static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
     using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;

@@ -81,7 +83,8 @@ struct FmhaFwdKernel
         "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) +
             "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
         (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
             _SS_(FmhaPipeline::name) + "_" +
         "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") +
             (pn.empty() ? "" : "_" + pn) +
-        (kHasBias ? "_bias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") +
-        (kStoreLSE ? "_lse" : "") + (kDoFp8StaticQuant ? "_squant" : "");
+        (BiasEnum == BlockAttentionBiasEnum::NO_BIAS
+             ? _SS_("")
+             : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
+        (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "") +
+        (kDoFp8StaticQuant ? "_squant" : "");
 #undef _SS_
 #undef _TS_
         // clang-format on

@@ -136,6 +139,13 @@ struct FmhaFwdKernel
         ck_tile::index_t batch_stride_bias = 0;
     };

+    struct FmhaFwdAlibiKargs
+    {
+        // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
+        const void* alibi_slope_ptr;
+        ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
+    };
+
     struct FmhaFwdMaskKargs
     {
         // ck_tile::index_t window_size_left, window_size_right;

@@ -162,7 +172,11 @@ struct FmhaFwdKernel
     struct FmhaFwdBatchModeKargs
         : FmhaFwdCommonKargs,
-          std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>,
+          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
+                             FmhaFwdBatchModeBiasKargs,
+                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
+                                                FmhaFwdAlibiKargs,
+                                                FmhaFwdEmptyKargs<0>>>,
           std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
           std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
           std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>

@@ -175,7 +189,11 @@ struct FmhaFwdKernel
     struct FmhaFwdGroupModeKargs
         : FmhaFwdCommonKargs,
-          std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>,
+          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
+                             FmhaFwdCommonBiasKargs,
+                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
+                                                FmhaFwdAlibiKargs,
+                                                FmhaFwdEmptyKargs<0>>>,
           std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
           std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
           std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>
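Review note: the two kargs hunks above select the bias-related argument block at compile time via nested std::conditional_t, so a kernel instantiated without ALIBI carries no alibi_slope fields at all. A minimal sketch of the same pattern with made-up types (not the actual ck_tile definitions):

// Minimal model of the conditional-kargs mixin selection used above (hypothetical types).
#include <type_traits>

enum class BiasKind { NO_BIAS, ELEMENTWISE_BIAS, ALIBI };

struct CommonKargs { int seqlen_q; int seqlen_k; };
struct BiasKargs   { const void* bias_ptr; int stride_bias; };
struct AlibiKargs  { const void* alibi_slope_ptr; int alibi_slope_stride; };
template <int> struct EmptyKargs {}; // distinct empty bases, like FmhaFwdEmptyKargs<N>

template <BiasKind Bias>
struct Kargs
    : CommonKargs,
      std::conditional_t<Bias == BiasKind::ELEMENTWISE_BIAS,
                         BiasKargs,
                         std::conditional_t<Bias == BiasKind::ALIBI, AlibiKargs, EmptyKargs<0>>>
{
};

int main()
{
    Kargs<BiasKind::ALIBI> k{};
    k.seqlen_q        = 128;
    k.alibi_slope_ptr = nullptr; // this member only exists for the ALIBI instantiation
    (void)k;
    return 0;
}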
@@ -255,13 +273,18 @@ struct FmhaFwdKernel
             batch_stride_v,
             batch_stride_o};

-        if constexpr(kHasBias)
+        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
             kargs.bias_ptr          = bias_ptr;
             kargs.stride_bias       = stride_bias;
             kargs.nhead_stride_bias = nhead_stride_bias;
             kargs.batch_stride_bias = batch_stride_bias;
         }
+        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+        {
+            kargs.alibi_slope_ptr    = bias_ptr;
+            kargs.alibi_slope_stride = stride_bias;
+        }
         if constexpr(kHasMask)
         {
             kargs.window_size_left = window_size_left;

@@ -345,12 +368,17 @@ struct FmhaFwdKernel
             reinterpret_cast<const int32_t*>(seqstart_k_ptr),
             reinterpret_cast<const int32_t*>(seqlen_k_ptr)};

-        if constexpr(kHasBias)
+        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
             kargs.bias_ptr          = bias_ptr;
             kargs.stride_bias       = stride_bias;
             kargs.nhead_stride_bias = nhead_stride_bias;
         }
+        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+        {
+            kargs.alibi_slope_ptr    = bias_ptr;
+            kargs.alibi_slope_stride = stride_bias;
+        }
         if constexpr(kHasMask)
         {
             kargs.window_size_left = window_size_left;

@@ -421,14 +449,10 @@ struct FmhaFwdKernel
             {
                 batch_offset_v = key_start;
             }
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 batch_offset_bias = query_start * kargs.stride_bias + key_start;
             }
-            else
-            {
-                batch_offset_bias = key_start;
-            }
             if constexpr(kStoreLSE)
             {
                 batch_offset_lse = query_start;

@@ -461,7 +485,7 @@ struct FmhaFwdKernel
             batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
             batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
             batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
             }

@@ -585,7 +609,7 @@ struct FmhaFwdKernel
         const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
             constexpr auto bias_dram_window_lengths =
                 make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 const BiasDataType* bias_ptr =
                     reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +

@@ -654,6 +678,39 @@ struct FmhaFwdKernel
             return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
         }();

+        // WA i_batch capture structure binding before c++20
+        auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+            {
+                // data loading, shared by entire wg
+                // TODO: how to use s_read?
+                SaccDataType slope = *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
+                                       i_batch_ * kargs.alibi_slope_stride + i_nhead_);
+#if CK_TILE_FMHA_FWD_FAST_EXP2
+                slope *= ck_tile::log2e_v<>;
+#endif
+                if constexpr(kHasMask)
+                {
+                    return make_alibi_from_lr_mask<SaccDataType, true>(slope,
+                                                                       kargs.window_size_left,
+                                                                       kargs.window_size_right,
+                                                                       kargs.seqlen_q,
+                                                                       kargs.seqlen_k,
+                                                                       kargs.mask_type);
+                }
+                else
+                {
+                    return Alibi<SaccDataType, true>{
+                        slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::VERTICAL};
+                }
+            }
+            else
+            {
+                return EmptyPositionEncoding<SaccDataType>{};
+            }
+        }();
+
         auto o_acc_tile = [&]() {
             if constexpr(kDoFp8StaticQuant)
             {
@@ -672,6 +729,7 @@ struct FmhaFwdKernel
                     scales{kargs.scale_p},                               // p_compute_element_func
                     composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
                     mask,
+                    position_encoding,
                     kargs.scale_s,
                     smem_ptr);
             }

@@ -683,6 +741,7 @@ struct FmhaFwdKernel
                     bias_dram_window,
                     lse_dram_window,
                     mask,
+                    position_encoding,
                     kargs.scale_s,
                     smem_ptr);
             }
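Review note on the slope *= ck_tile::log2e_v<> line inside the position_encoding lambda above: with CK_TILE_FMHA_FWD_FAST_EXP2 the pipelines evaluate softmax with exp2 instead of exp, relying on exp(x) = 2^(x * log2 e), so the ALIBI penalty has to be pre-scaled into the same base-2 domain as the scaled scores. A one-line sanity check of that identity (the constant is written out here rather than taken from ck_tile::log2e_v<>):

// exp(x) == exp2(x * log2(e)) -- the identity behind the FAST_EXP2 slope rescale.
#include <cassert>
#include <cmath>

int main()
{
    const double log2e = 1.4426950408889634; // log2(e)
    const double xs[]  = {-3.0, -0.73, 0.0, 1.5};
    for(double x : xs)
        assert(std::abs(std::exp(x) - std::exp2(x * log2e)) < 1e-12);
    return 0;
}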
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp  (+19 / -0)

@@ -13,4 +13,23 @@ enum class BlockFmhaPipelineEnum
     QSKSVS,
 };

+template <BlockFmhaPipelineEnum>
+struct BlockFmhaPipelineEnumToStr;
+
+template <>
+struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS>
+{
+    static constexpr const char* name = "qr";
+};
+template <>
+struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC>
+{
+    static constexpr const char* name = "qr_async";
+};
+template <>
+struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QSKSVS>
+{
+    static constexpr const char* name = "qs";
+};
+
 } // namespace ck_tile
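The new BlockFmhaPipelineEnumToStr trait is a plain compile-time lookup from a pipeline enum value to its short tag, usable wherever kernel-name strings are assembled; a usage sketch:

// Usage sketch: compile-time lookup of the pipeline tag defined above.
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"

constexpr const char* qr_tag =
    ck_tile::BlockFmhaPipelineEnumToStr<ck_tile::BlockFmhaPipelineEnum::QRKSVS>::name;
static_assert(qr_tag[0] == 'q' && qr_tag[1] == 'r' && qr_tag[2] == '\0', "QRKSVS maps to \"qr\"");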
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp  (+1 / -1)

@@ -45,7 +45,7 @@ struct BlockFmhaPipelineProblem
     static constexpr bool kPadSeqLenK       = Traits::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ      = Traits::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV      = Traits::kPadHeadDimV;
-    static constexpr bool kHasBias          = Traits::kHasBias;
+    static constexpr auto BiasEnum          = Traits::BiasEnum;
     static constexpr bool kStoreLSE         = Traits::kStoreLSE;
     static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
     static constexpr index_t kBlockPerCu    = Traits::kBlockPerCu;
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp  (+40 / -11)

@@ -4,6 +4,7 @@
 #pragma once

 #include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce.hpp"

@@ -46,7 +47,7 @@ struct BlockFmhaPipelineQRKSVS
     static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
-    static constexpr bool kHasBias     = Problem::kHasBias;
+    static constexpr auto BiasEnum     = Problem::BiasEnum;
     static constexpr bool kStoreLSE    = Problem::kStoreLSE;

     // last dimension vector length used to create tensor view(and decide buffer_load vector length)

@@ -82,7 +83,7 @@ struct BlockFmhaPipelineQRKSVS
         }
         else if constexpr(kK0BlockLength <= 128)
         {
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                 return 1;
             else
                 return 2;

@@ -113,7 +114,8 @@ struct BlockFmhaPipelineQRKSVS
               typename LSEElementFunction,
               typename SAccElementFunction,
               typename PComputeElementFunction,
-              typename OAccElementFunction>
+              typename OAccElementFunction,
+              typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
     operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
                const QElementFunction& q_element_func,

@@ -129,6 +131,7 @@ struct BlockFmhaPipelineQRKSVS
                const PComputeElementFunction& p_compute_element_func,
                const OAccElementFunction& o_acc_element_func,
                FmhaMask mask,
+               PositionEncoding position_encoding,
                float scale_s,
                void* smem_ptr) const
     {

@@ -270,13 +273,13 @@ struct BlockFmhaPipelineQRKSVS
                 k_block_tile = load_tile(k_dram_window);
             }

-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 __builtin_amdgcn_sched_barrier(
                     0); // prevent from messing up the order of global loads
             }
             const auto bias_tile = load_tile(bias_dram_window); // load bias tile
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 __builtin_amdgcn_sched_barrier(
                     0); // prevent from messing up the order of global loads

@@ -322,7 +325,7 @@ struct BlockFmhaPipelineQRKSVS
             }

             // STAGE 2, scale_s, add bias, mask, softmax
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);

@@ -338,6 +341,25 @@ struct BlockFmhaPipelineQRKSVS
                     s_acc,
                     bias_tile);
             }
+            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+            {
+                const auto k_origin    = k_dram_block_window.get_window_origin();
+                constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
+                s_acc                  = tile_elementwise_in(s_acc_element_func, s_acc);
+                sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
+                    sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
+                        const auto tile_idx = get_x_indices_from_distributed_indices(
+                            s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
+                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
+                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
+                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
+
+                        s_acc(i_j_idx) *= scale_s;
+                        position_encoding.update(s_acc(i_j_idx), row, col);
+                    });
+                });
+            }
             else
             {
                 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
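Review note on the ALIBI branch above: it converts each distributed index of the S accumulator back to a global (row, col) coordinate, applies scale_s, and then lets position_encoding.update() add slope * distance. A scalar model of what update() contributes for one element in the RowMajor path (see Alibi in block_position_encoding.hpp); the function name here is illustrative only:

// Scalar model of Alibi::update() for a single (row, col) element, RowMajor path.
#include <cstdlib>

float alibi_update_scalar(float pixel, float slope, int row, int col,
                          int shift_left_up, int shift_right_down, bool vertical)
{
    // zero point of this row; for VERTICAL mode it is independent of the row index
    const int zero_point = vertical ? shift_right_down : row + shift_right_down;
    // |a - b| -- the absolute difference that sad(a, b, 0) computes on the device
    const int position = std::abs((col + shift_left_up) - zero_point);
    return pixel + slope * static_cast<float>(position); // slope is already sign-adjusted
}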
@@ -382,7 +404,8 @@ struct BlockFmhaPipelineQRKSVS
             static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                 /// NOTICE: bias might be materialized mask including -inf values, need
                 /// consideration
-                if constexpr(kHasBias || FmhaMask::IsMasking)
+                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                             FmhaMask::IsMasking)
                 {
                     return raw_m == -numeric<SMPLComputeDataType>::infinity()
                                ? type_convert<SMPLComputeDataType>(0.f)

@@ -403,7 +426,8 @@ struct BlockFmhaPipelineQRKSVS
                 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                     constexpr auto i_j_idx = make_tuple(idx0, idx1);
 #if CK_TILE_FMHA_FWD_FAST_EXP2
-                    if constexpr(kHasBias)
+                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                     {
                         p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                     }

@@ -427,7 +451,8 @@ struct BlockFmhaPipelineQRKSVS
                 constexpr auto i_idx = make_tuple(idx0);
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                 const auto tmp = [&]() {
-                    if constexpr(kHasBias)
+                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                     {
                         return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                     }

@@ -519,7 +544,8 @@ struct BlockFmhaPipelineQRKSVS
             sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                 constexpr auto i_idx = make_tuple(idx0);
 #if CK_TILE_FMHA_FWD_FAST_EXP2
-                if constexpr(kHasBias)
+                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                             BiasEnum == BlockAttentionBiasEnum::ALIBI)
                 {
                     lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
                 }

@@ -563,7 +589,8 @@ struct BlockFmhaPipelineQRKSVS
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
               typename BiasDramBlockWindowTmp,
-              typename LSEDramBlockWindowTmp>
+              typename LSEDramBlockWindowTmp,
+              typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
     operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
                const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile

@@ -571,6 +598,7 @@ struct BlockFmhaPipelineQRKSVS
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
                LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
                FmhaMask mask,
+               PositionEncoding position_encoding,
                float scale_s,
                void* smem_ptr) const
     {

@@ -588,6 +616,7 @@ struct BlockFmhaPipelineQRKSVS
                           identity{},
                           identity{},
                           mask,
+                          position_encoding,
                           scale_s,
                           smem_ptr);
     }
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp  (+44 / -14)

@@ -5,6 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce.hpp"

@@ -51,7 +52,7 @@ struct BlockFmhaPipelineQRKSVSAsync
     static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
     static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
-    static constexpr bool kHasBias     = Problem::kHasBias;
+    static constexpr auto BiasEnum     = Problem::BiasEnum;
     static constexpr bool kStoreLSE    = Problem::kStoreLSE;

     // last dimension vector length used to create tensor view(and decide buffer_load vector length)

@@ -79,21 +80,22 @@ struct BlockFmhaPipelineQRKSVSAsync
     {
         if constexpr(kK0BlockLength <= 32)
         {
-            if constexpr(kPadSeqLenK && kHasBias && FmhaMask::IsMasking)
+            if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
+                         FmhaMask::IsMasking)
                 return 1;
             else
                 return 2;
         }
         else if constexpr(kK0BlockLength <= 64)
         {
-            if constexpr(kPadSeqLenK && kHasBias)
+            if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                 return 2;
             else
                 return 3;
         }
         else if constexpr(kK0BlockLength <= 128)
         {
-            if constexpr(kPadSeqLenK && kHasBias)
+            if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                 return 1;
             else
                 return 2;

@@ -124,7 +126,8 @@ struct BlockFmhaPipelineQRKSVSAsync
               typename LSEElementFunction,
               typename SAccElementFunction,
               typename PComputeElementFunction,
-              typename OAccElementFunction>
+              typename OAccElementFunction,
+              typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
     operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
                const QElementFunction& q_element_func,

@@ -140,6 +143,7 @@ struct BlockFmhaPipelineQRKSVSAsync
                const PComputeElementFunction& p_compute_element_func,
                const OAccElementFunction& o_acc_element_func,
                FmhaMask mask,
+               PositionEncoding position_encoding,
                float scale_s,
                void* smem_ptr) const
     {

@@ -247,8 +251,8 @@ struct BlockFmhaPipelineQRKSVSAsync
         const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);

-        // check early exit if masked and no work to do.
-        if constexpr(FmhaMask::IsMasking)
+        // check early exit
+        if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
         {
             if(num_total_loop <= 0)
             {

@@ -367,7 +371,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             __builtin_amdgcn_sched_barrier(1);

             // STAGE 2, scale_s, add bias, mask, softmax
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                 tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);

@@ -383,6 +387,25 @@ struct BlockFmhaPipelineQRKSVSAsync
                     s_acc,
                     bias_tile);
             }
+            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+            {
+                const auto k_origin    = k_dram_block_window.get_window_origin();
+                constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
+                s_acc                  = tile_elementwise_in(s_acc_element_func, s_acc);
+                sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
+                    sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
+                        const auto tile_idx = get_x_indices_from_distributed_indices(
+                            s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
+                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
+                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
+                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
+
+                        s_acc(i_j_idx) *= scale_s;
+                        position_encoding.update(s_acc(i_j_idx), row, col);
+                    });
+                });
+            }
             else
             {
                 s_acc = tile_elementwise_in(s_acc_element_func, s_acc);

@@ -463,8 +486,9 @@ struct BlockFmhaPipelineQRKSVSAsync
             static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                 /// NOTICE: bias might be materialized mask including -inf values, need
-                /// consideration
-                if constexpr(kHasBias || FmhaMask::IsMasking)
+                /// consideration. alibi does not have this problem
+                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                             FmhaMask::IsMasking)
                 {
                     return raw_m == -numeric<SMPLComputeDataType>::infinity()
                                ? type_convert<SMPLComputeDataType>(0.f)

@@ -485,7 +509,8 @@ struct BlockFmhaPipelineQRKSVSAsync
                 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                     constexpr auto i_j_idx = make_tuple(idx0, idx1);
 #if CK_TILE_FMHA_FWD_FAST_EXP2
-                    if constexpr(kHasBias)
+                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                     {
                         p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                     }

@@ -509,7 +534,8 @@ struct BlockFmhaPipelineQRKSVSAsync
                 constexpr auto i_idx = make_tuple(idx0);
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                 const auto tmp = [&]() {
-                    if constexpr(kHasBias)
+                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                     {
                         return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                     }

@@ -617,7 +643,8 @@ struct BlockFmhaPipelineQRKSVSAsync
             sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                 constexpr auto i_idx = make_tuple(idx0);
 #if CK_TILE_FMHA_FWD_FAST_EXP2
-                if constexpr(kHasBias)
+                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                             BiasEnum == BlockAttentionBiasEnum::ALIBI)
                 {
                     lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
                 }

@@ -661,7 +688,8 @@ struct BlockFmhaPipelineQRKSVSAsync
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
               typename BiasDramBlockWindowTmp,
-              typename LSEDramBlockWindowTmp>
+              typename LSEDramBlockWindowTmp,
+              typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
     operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
                const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile

@@ -669,6 +697,7 @@ struct BlockFmhaPipelineQRKSVSAsync
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
                LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
                FmhaMask mask,
+               PositionEncoding position_encoding,
                float scale_s,
                void* smem_ptr) const
     {

@@ -686,6 +715,7 @@ struct BlockFmhaPipelineQRKSVSAsync
                           identity{},
                           identity{},
                           mask,
+                          position_encoding,
                           scale_s,
                           smem_ptr);
     }
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp  (+13 / -9)

@@ -4,6 +4,7 @@
 #pragma once

 #include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/reduce/block/block_reduce.hpp"

@@ -46,7 +47,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
     static constexpr bool kPadSeqLenK  = Problem::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
-    static constexpr bool kHasBias     = Problem::kHasBias;
+    static constexpr auto BiasEnum     = Problem::BiasEnum;
     static constexpr bool kStoreLSE    = Problem::kStoreLSE;

     // last dimension vector length used to create tensor view(and decide buffer_load vector length)

@@ -82,7 +83,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
         }
         else if constexpr(kK0BlockLength <= 128)
         {
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                 return 1;
             else
                 return 2;

@@ -105,7 +106,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
               typename BiasDramBlockWindowTmp,
-              typename LSEDramBlockWindowTmp>
+              typename LSEDramBlockWindowTmp,
+              typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
     operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
                const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile

@@ -113,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
                LSEDramBlockWindowTmp& /*lse_dram_window_tmp*/,           // not supported
                FmhaMask mask,
+               PositionEncoding /*position_encoding*/,
                float scale_s,
                float descale_qk,
                float descale_sv,

@@ -249,13 +252,13 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
                 k_block_tile = load_tile(k_dram_window);
             }

-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 __builtin_amdgcn_sched_barrier(
                     0); // prevent from messing up the order of global loads
             }
             const auto bias_tile = load_tile(bias_dram_window); // load bias tile
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 __builtin_amdgcn_sched_barrier(
                     0); // prevent from messing up the order of global loads

@@ -300,7 +303,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
             }

             // STAGE 2, scale_s, add bias, mask, softmax
-            if constexpr(kHasBias)
+            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 tile_elementwise_inout(
                     [&](auto& x, const auto& y) {

@@ -356,7 +359,8 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
             static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                 /// NOTICE: bias might be materialized mask including -inf values, need
                 /// consideration
-                if constexpr(kHasBias || FmhaMask::IsMasking)
+                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                             FmhaMask::IsMasking)
                 {
                     return raw_m == -numeric<SMPLComputeDataType>::infinity()
                                ? type_convert<SMPLComputeDataType>(0.f)

@@ -377,7 +381,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
                 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                     constexpr auto i_j_idx = make_tuple(idx0, idx1);
 #if CK_TILE_FMHA_FWD_FAST_EXP2
-                    if constexpr(kHasBias)
+                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     {
                         p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                     }

@@ -401,7 +405,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
                 constexpr auto i_idx = make_tuple(idx0);
 #if CK_TILE_FMHA_FWD_FAST_EXP2
                 const auto tmp = [&]() {
-                    if constexpr(kHasBias)
+                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                     {
                         return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                     }
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @
1d784873
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -45,7 +46,7 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -45,7 +46,7 @@ struct BlockFmhaPipelineQSKSVS
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
bool
kHasBias
=
Problem
::
kHas
Bias
;
static
constexpr
auto
BiasEnum
=
Problem
::
Bias
Enum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
index_t
kBlockPerCu
=
[]()
{
static
constexpr
index_t
kBlockPerCu
=
[]()
{
...
@@ -63,7 +64,7 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -63,7 +64,7 @@ struct BlockFmhaPipelineQSKSVS
}
}
else
if
constexpr
(
kK0BlockLength
<=
128
)
else
if
constexpr
(
kK0BlockLength
<=
128
)
{
{
if
constexpr
(
kHas
Bias
)
if
constexpr
(
Bias
Enum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
return
1
;
else
else
return
2
;
return
2
;
...
@@ -99,7 +100,8 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -99,7 +100,8 @@ struct BlockFmhaPipelineQSKSVS
typename
LSEElementFunction
,
typename
LSEElementFunction
,
typename
SAccElementFunction
,
typename
SAccElementFunction
,
typename
PComputeElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
>
typename
OAccElementFunction
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
QElementFunction
&
q_element_func
,
...
@@ -115,6 +117,7 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -115,6 +117,7 @@ struct BlockFmhaPipelineQSKSVS
const
PComputeElementFunction
&
p_compute_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
FmhaMask
mask
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
float
scale_s
,
void
*
smem_ptr
)
const
void
*
smem_ptr
)
const
{
{
...
@@ -265,13 +268,13 @@ struct BlockFmhaPipelineQSKSVS
...
@@ -265,13 +268,13 @@ struct BlockFmhaPipelineQSKSVS
k_block_tile
=
load_tile
(
k_dram_window
);
k_block_tile
=
load_tile
(
k_dram_window
);
}
}
if
constexpr
(
kHas
Bias
)
if
constexpr
(
Bias
Enum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
__builtin_amdgcn_sched_barrier
(
__builtin_amdgcn_sched_barrier
(
0
);
// prevent from messing up the order of global loads
0
);
// prevent from messing up the order of global loads
}
}
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
const
auto
bias_tile
=
load_tile
(
bias_dram_window
);
// load bias tile
if
constexpr
(
kHas
Bias
)
if
constexpr
(
Bias
Enum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
{
__builtin_amdgcn_sched_barrier
(
__builtin_amdgcn_sched_barrier
(
0
);
                0); // prevent from messing up the order of global loads
...
@@ -313,7 +316,7 @@ struct BlockFmhaPipelineQSKSVS
            }
            // STAGE 2, scale_s, add bias, mask, softmax
-           if constexpr(kHasBias)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
...
@@ -329,6 +332,25 @@ struct BlockFmhaPipelineQSKSVS
                    s_acc,
                    bias_tile);
            }
+           else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
+           {
+               const auto k_origin    = k_dram_block_window.get_window_origin();
+               constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
+               s_acc                  = tile_elementwise_in(s_acc_element_func, s_acc);
+               sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
+                   sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
+                       const auto tile_idx = get_x_indices_from_distributed_indices(
+                           s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
+                       const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
+                       const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
+                       constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                       s_acc(i_j_idx) *= scale_s;
+                       position_encoding.update(s_acc(i_j_idx), row, col);
+                   });
+               });
+           }
            else
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
...
@@ -373,7 +395,8 @@ struct BlockFmhaPipelineQSKSVS
        static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
            /// NOTICE: bias might be materialized mask including -inf values, need
            /// consideration
-           if constexpr(kHasBias || FmhaMask::IsMasking)
+           if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                        FmhaMask::IsMasking)
            {
                return raw_m == -numeric<SMPLComputeDataType>::infinity()
                           ? type_convert<SMPLComputeDataType>(0.f)
...
@@ -394,7 +417,8 @@ struct BlockFmhaPipelineQSKSVS
            sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
-               if constexpr(kHasBias)
+               if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                            BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                }
...
@@ -418,7 +442,8 @@ struct BlockFmhaPipelineQSKSVS
            constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
            const auto tmp = [&]() {
-               if constexpr(kHasBias)
+               if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                            BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                }
...
@@ -510,7 +535,8 @@ struct BlockFmhaPipelineQSKSVS
            sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
-               if constexpr(kHasBias)
+               if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                            BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    lse(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
                }
...
@@ -554,7 +580,8 @@ struct BlockFmhaPipelineQSKSVS
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
-             typename LSEDramBlockWindowTmp>
+             typename LSEDramBlockWindowTmp,
+             typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
...
@@ -562,6 +589,7 @@ struct BlockFmhaPipelineQSKSVS
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
               FmhaMask mask,
+              PositionEncoding position_encoding,
               float scale_s,
               void* smem_ptr) const
    {
...
@@ -579,6 +607,7 @@ struct BlockFmhaPipelineQSKSVS
                          identity{},
                          identity{},
                          mask,
+                          position_encoding,
                          scale_s,
                          smem_ptr);
    }
...
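For orientation, the added ALIBI branch above derives each element's global (row, col) from the Q/K block-window origins plus the in-tile index, applies scale_s first, and only then lets position_encoding adjust the score. The standalone sketch below mirrors that bookkeeping on a plain 2D array; the tile extent, origins, and the slope * |row - col| bias term (top-left ALiBi layout, with the sign folded into the slope) are illustrative assumptions, not code from this commit.

// Illustrative only: plain-C++ analogue of the ALIBI branch above.
#include <cstdio>
#include <cstdlib>

int main()
{
    constexpr int kM0 = 4, kN0 = 4;              // assumed tile extent
    constexpr int q_origin = 8, k_origin = 12;   // assumed block-window origins
    const float scale_s = 0.125f, slope = -0.5f; // sign convention is the caller's choice

    float s_acc[kM0][kN0] = {}; // stands in for the raw QK^T accumulator tile

    for(int idx0 = 0; idx0 < kM0; ++idx0)
        for(int idx1 = 0; idx1 < kN0; ++idx1)
        {
            const int row = q_origin + idx0; // global query position
            const int col = k_origin + idx1; // global key position
            s_acc[idx0][idx1] *= scale_s;    // scale before the positional bias, as in the kernel
            s_acc[idx0][idx1] += slope * std::abs(row - col); // ALiBi-style distance penalty
        }

    std::printf("biased score at (0, 0): %f\n", s_acc[0][0]);
    return 0;
}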
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp
View file @
1d784873
...
@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
+#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
namespace ck_tile {
...
@@ -11,7 +12,7 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
          bool kPadSeqLenK_ /* padding for seqlen_k */,
          bool kPadHeadDimQ_ /* paddding for hdim_q */,
          bool kPadHeadDimV_ /* paddding for hdim_v */,
-         bool kHasBias_,
+         BlockAttentionBiasEnum BiasEnum_,
          bool kStoreLSE_,
          bool kDoFp8StaticQuant_,
          index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
...
@@ -21,7 +22,7 @@ struct TileFmhaTraits
    static constexpr bool kPadSeqLenK = kPadSeqLenK_;
    static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
    static constexpr bool kPadHeadDimV = kPadHeadDimV_;
-   static constexpr bool kHasBias = kHasBias_;
+   static constexpr auto BiasEnum = BiasEnum_;
    static constexpr bool kStoreLSE = kStoreLSE_;
    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
    static constexpr index_t kBlockPerCu = kBlockPerCu_;
...
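With this change a caller selects the bias handling through BlockAttentionBiasEnum rather than a plain on/off flag, which is what lets the pipelines above distinguish ELEMENTWISE_BIAS from ALIBI. A rough sketch of how an instantiation migrates follows; the padding flags chosen here are arbitrary placeholders, not values taken from this commit.

#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"

// Before: the fifth template argument was a bool (kHasBias_):
//   using OldTraits = ck_tile::TileFmhaTraits<false, false, false, false,
//                                             /*kHasBias_=*/true,
//                                             /*kStoreLSE_=*/false,
//                                             /*kDoFp8StaticQuant_=*/false>;

// After: the same slot takes the bias kind, so ALIBI can now be requested:
using NewTraits = ck_tile::TileFmhaTraits<false, false, false, false,
                                          ck_tile::BlockAttentionBiasEnum::ALIBI,
                                          /*kStoreLSE_=*/false,
                                          /*kDoFp8StaticQuant_=*/false>;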
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_data/xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_input_f16_comp_bf8_f8_instance.cpp
View file @
1d784873
...
@@ -26,6 +26,8 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_
                                                               BF8,
                                                               F8>>>& instances)
{
+#if CK_BUILD_DEPRECATED
+#pragma message "These instances are getting deprecated"
    // 1. Default
    add_device_operation_instances(
        instances,
...
@@ -44,6 +46,10 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_
            Empty_Tuple,
            NDHWGC,
            ConvBwdDataFilter1x1Stride1Pad0>{});
+#else
+#pragma message "These instances were deprecated"
+    std::ignore = instances;
+#endif
}
} // namespace instance
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp
View file @
1d784873
...
@@ -23,6 +23,8 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_
                                                               BF8,
                                                               F8>>>& instances)
{
+#if CK_BUILD_DEPRECATED
+#pragma message "These instances are getting deprecated"
    // 1. Default
    add_device_operation_instances(
        instances,
...
@@ -41,6 +43,10 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_
            GKZYXC,
            NDHWGK,
            ConvBwdWeightFilter1x1Stride1Pad0>{});
+#else
+#pragma message "These instances were deprecated"
+    std::ignore = instances;
+#endif
}
} // namespace instance
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_fp8_instance.cpp
View file @
1d784873
...
@@ -24,6 +24,8 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance
                                                               PassThrough,
                                                               F8>>>& instances)
{
+#if CK_BUILD_DEPRECATED
+#pragma message "These instances are getting deprecated"
    add_device_operation_instances(
        instances,
        device_grouped_conv_fwd_xdl_f16_comp_f8_instances<3,
...
@@ -48,6 +50,10 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_f8_instance
            Empty_Tuple,
            NDHWGK,
            ConvFwd1x1S1P0>{});
+#else
+#pragma message "These instances were deprecated"
+    std::ignore = instances;
+#endif
}
} // namespace instance
...
test/CMakeLists.txt
View file @
1d784873
...
@@ -181,3 +181,4 @@ add_subdirectory(wrapper)
if(GPU_TARGETS MATCHES "gfx11")
    add_subdirectory(wmma_op)
endif()
+add_subdirectory(position_embedding)
test/position_embedding/CMakeLists.txt
0 → 100644
View file @
1d784873
add_test_executable(test_position_embedding position_embedding.cpp)
test/position_embedding/position_embedding.cpp
0 → 100644
View file @
1d784873
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha.hpp"

#ifndef TEST_ALIBI_VERBOSE
#define TEST_ALIBI_VERBOSE 0
#endif

template <typename DataType>
struct attention_score
{
    ck_tile::index_t rows, cols;
    std::vector<DataType> pixels;

    attention_score(ck_tile::index_t rows_,
                    ck_tile::index_t cols_,
                    DataType init_v_ = static_cast<DataType>(0))
        : rows(rows_), cols(cols_), pixels(rows_ * cols_, init_v_)
    {
    }

    auto& operator()(ck_tile::index_t i_row, ck_tile::index_t i_col)
    {
        return pixels[i_row * cols + i_col];
    }

    void print()
    {
        for(auto i_row = 0; i_row < rows; i_row++)
        {
            for(auto i_col = 0; i_col < cols; i_col++)
            {
                std::cout << pixels[i_row * cols + i_col] << " ";
            }
            std::cout << std::endl;
        }
    }
};

template <bool RowMajor, typename DataType>
void alibi_traverse_with_slope(attention_score<DataType>& score,
                               DataType slope,
                               ck_tile::AlibiMode mode = ck_tile::AlibiMode::VERTICAL)
{
    using Alibi = ck_tile::Alibi<DataType, RowMajor>;
    auto alibi  = Alibi{slope, score.rows, score.cols, mode};

    for(ck_tile::index_t i_row = 0; i_row < score.rows; i_row++)
    {
        for(ck_tile::index_t i_col = 0; i_col < score.cols; i_col++)
        {
            alibi.update(score(i_row, i_col), i_row, i_col);
        }
    }
}

std::string alibi_mode_to_str(ck_tile::AlibiMode mode)
{
    if(mode == ck_tile::AlibiMode::VERTICAL)
        return std::string("alibi_verti");
    else if(mode == ck_tile::AlibiMode::FROM_TOP_LEFT)
        return std::string("alibi_top-l");
    else if(mode == ck_tile::AlibiMode::FROM_BOTTOM_RIGHT)
        return std::string("alibi_bot-r");
    return "";
}

template <bool RowMajor, typename DataType>
bool test_alibi_traverse_with_slope(ck_tile::index_t rows,
                                    ck_tile::index_t cols,
                                    DataType slope,
                                    ck_tile::AlibiMode mode,
                                    const std::vector<DataType>& expected)
{
    attention_score<DataType> score{rows, cols};
    alibi_traverse_with_slope<RowMajor, DataType>(score, slope, mode);
    bool is_match = std::equal(score.pixels.begin(), score.pixels.end(), expected.begin());
#if TEST_ALIBI_VERBOSE
    std::cout << "---------" << alibi_mode_to_str(mode) << ", " << rows << "x" << cols << "("
              << (RowMajor ? "row_major" : "col_major") << ")"
              << (is_match ? ", valid:y" : ", valid:n") << std::endl;
    score.print();
#endif
    return is_match;
}

template <typename DataType>
bool test_alibi_slope_generation(ck_tile::index_t nheads, const std::vector<DataType>& expected)
{
    auto slopes   = ck_tile::get_alibi_slopes<DataType>(nheads);
    bool is_match = std::equal(slopes.begin(),
                               slopes.end(),
                               expected.begin(),
                               expected.end(),
                               [](const DataType& lhs, const DataType& rhs) {
                                   constexpr float rtol = 1e-6;
                                   auto error           = std::abs(lhs - rhs);
                                   return error < rtol * std::abs(rhs);
                               });
#if TEST_ALIBI_VERBOSE
    std::cout << "-------------------- slopes " << nheads << ", " << (is_match ? "y" : "n")
              << std::endl;
    for(ck_tile::index_t i = 0; i < nheads; i++)
    {
        std::cout << slopes[i] << " ";
    }
    std::cout << std::endl;
#endif
    return is_match;
}

int main()
{
    using dtype = int32_t;
    dtype slope = static_cast<dtype>(1);
    bool rtn    = true;

    // clang-format off
    rtn &= test_alibi_traverse_with_slope<true,  dtype>(4, 6, slope, ck_tile::AlibiMode::VERTICAL,
        {0, 1, 2, 3, 4, 5,   0, 1, 2, 3, 4, 5,   0, 1, 2, 3, 4, 5,   0, 1, 2, 3, 4, 5});
    rtn &= test_alibi_traverse_with_slope<true,  dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT,
        {0, 1, 2, 3, 4, 5,   1, 0, 1, 2, 3, 4,   2, 1, 0, 1, 2, 3,   3, 2, 1, 0, 1, 2});
    rtn &= test_alibi_traverse_with_slope<true,  dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT,
        {0, 1, 2, 3,   1, 0, 1, 2,   2, 1, 0, 1,   3, 2, 1, 0,   4, 3, 2, 1,   5, 4, 3, 2});
    rtn &= test_alibi_traverse_with_slope<true,  dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT,
        {0, 1, 2,   1, 0, 1,   2, 1, 0});
    rtn &= test_alibi_traverse_with_slope<true,  dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT,
        {2, 1, 0, 1, 2, 3,   3, 2, 1, 0, 1, 2,   4, 3, 2, 1, 0, 1,   5, 4, 3, 2, 1, 0});
    rtn &= test_alibi_traverse_with_slope<true,  dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT,
        {2, 3, 4, 5,   1, 2, 3, 4,   0, 1, 2, 3,   1, 0, 1, 2,   2, 1, 0, 1,   3, 2, 1, 0});
    rtn &= test_alibi_traverse_with_slope<true,  dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT,
        {0, 1, 2,   1, 0, 1,   2, 1, 0});

    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::VERTICAL,
        {0, 1, 2, 3, 4, 5,   0, 1, 2, 3, 4, 5,   0, 1, 2, 3, 4, 5,   0, 1, 2, 3, 4, 5});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT,
        {0, 1, 2, 3, 4, 5,   1, 0, 1, 2, 3, 4,   2, 1, 0, 1, 2, 3,   3, 2, 1, 0, 1, 2});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT,
        {0, 1, 2, 3,   1, 0, 1, 2,   2, 1, 0, 1,   3, 2, 1, 0,   4, 3, 2, 1,   5, 4, 3, 2});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT,
        {0, 1, 2,   1, 0, 1,   2, 1, 0});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT,
        {2, 1, 0, 1, 2, 3,   3, 2, 1, 0, 1, 2,   4, 3, 2, 1, 0, 1,   5, 4, 3, 2, 1, 0});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT,
        {2, 3, 4, 5,   1, 2, 3, 4,   0, 1, 2, 3,   1, 0, 1, 2,   2, 1, 0, 1,   3, 2, 1, 0});
    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT,
        {0, 1, 2,   1, 0, 1,   2, 1, 0});

    rtn &= test_alibi_slope_generation<float>(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625});
    rtn &= test_alibi_slope_generation<float>(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006,
                                                   0.17677669529663692, 0.12500000000000006, 0.08838834764831849, 0.06250000000000004,
                                                   0.044194173824159244, 0.03125000000000002, 0.022097086912079626, 0.01562500000000001,
                                                   0.011048543456039816, 0.007812500000000007, 0.005524271728019908, 0.003906250000000004});
    rtn &= test_alibi_slope_generation<float>(1, {0.00390625});
    rtn &= test_alibi_slope_generation<float>(5, {0.25, 0.0625, 0.015625, 0.00390625, 0.5});
    rtn &= test_alibi_slope_generation<float>(6, {0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125});
    rtn &= test_alibi_slope_generation<float>(7, {0.25, 0.0625, 0.015625, 0.00390625, 0.5, 0.125, 0.03125});
    rtn &= test_alibi_slope_generation<float>(9, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625, 0.7071067811865476});
    // clang-format on

    return rtn ? 0 : -1;
}
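The slope vectors asserted above follow the usual ALiBi recipe: for a power-of-two head count n the slopes are 2^(-8i/n) for i = 1..n, and for other head counts the sequence for the largest power of two below n is padded with every other slope of the doubled sequence. The helper below is a minimal reference sketch of that recipe (the name reference_alibi_slopes is ours, not part of ck_tile); it reproduces the expected vectors in this test, though ck_tile::get_alibi_slopes may be implemented differently.

#include <cmath>
#include <vector>

// Reference recipe only; not the ck_tile implementation.
std::vector<float> reference_alibi_slopes(int nheads)
{
    auto pow2_slopes = [](int n) {
        std::vector<float> s;
        const float start = std::pow(2.0f, -8.0f / static_cast<float>(n));
        float v = start;
        for(int i = 0; i < n; ++i, v *= start)
            s.push_back(v); // 2^(-8*1/n), 2^(-8*2/n), ..., 2^(-8)
        return s;
    };

    int closest = 1; // largest power of two not exceeding nheads
    while(closest * 2 <= nheads)
        closest *= 2;

    std::vector<float> slopes = pow2_slopes(closest);
    if(closest < nheads)
    {
        // pad with every other slope of the 2*closest sequence, e.g. nheads=5 -> {..., 0.5}
        auto extra = pow2_slopes(2 * closest);
        for(int i = 0; i < nheads - closest; ++i)
            slopes.push_back(extra[2 * i]);
    }
    return slopes;
}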