GitLab: gaoqiong / composable_kernel_ROCM / Commits / e536d321
"test/vscode:/vscode.git/clone" did not exist on "85af926ddc5f3c8fb438001743e65ec3a039ceec"
Commit e536d321, authored Sep 04, 2024 by illsilin

merge from public repo

Parents: 829e0eb3, 52410b49
Changes: 76 files in total; this page (1 of 4) shows 20 changed files with 1541 additions and 66 deletions (+1541 -66).
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp  +82 -34
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp  +1 -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp  +30 -14
include/ck/tensor_operation/gpu/device/tensor_layout.hpp  +33 -1
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp  +52 -2
include/ck_tile/core/config.hpp  +9 -0
include/ck_tile/core/numeric/bfloat16.hpp  +34 -0
include/ck_tile/core/numeric/math.hpp  +10 -3
include/ck_tile/core/tensor/tile_window.hpp  +52 -1
include/ck_tile/core/utility/type_traits.hpp  +17 -0
include/ck_tile/host.hpp  +1 -0
include/ck_tile/host/host_tensor.hpp  +9 -0
include/ck_tile/host/kernel_launch.hpp  +5 -5
include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp  +73 -0
include/ck_tile/ops/fmha.hpp  +6 -2
include/ck_tile/ops/fmha/block/block_position_encoding.hpp  +19 -3
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp  +108 -0
include/ck_tile/ops/fmha/block/page_block_navigator.hpp  +279 -0
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp  +679 -0
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp  +42 -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp  (view file @ e536d321)

@@ -102,10 +102,9 @@ __global__ void
     // offset base pointer for each work-group
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
     const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
-    const long_index_t e_batch_offset =
+    const long_index_t e_group_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
-    const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
+    const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
     const long_index_t e_n_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));

@@ -118,14 +117,14 @@ __global__ void
         DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();
     static_for<0, NumDTensor, 1>{}(
-        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
+        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });

     if constexpr(isMultiA || isMultiB)
     {
         AsPointer p_as_grid_grp;
         BsPointer p_bs_grid_grp;

-        const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
+        const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);

         // compute_ptr_offset_of_n_ not need BatchStrideB so
         // in case of MultiA is false but isMultiB is true

@@ -136,27 +135,27 @@ __global__ void
         static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
         static_for<0, NumATensor, 1>{}([&](auto i) {
-            p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i];
+            p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i];
         });
     }
     else
     {
         const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
         static_for<0, 1, 1>{}(
-            [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; });
+            [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; });
     }

-    const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
+    const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);

     static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
     static_for<0, NumBTensor, 1>{}(
-        [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
+        [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; });

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid_grp,
                                                   p_bs_grid_grp,
                                                   p_ds_grid_grp,
-                                                  p_e_grid + e_batch_offset + e_n_offset,
+                                                  p_e_grid + e_group_offset + e_n_offset,
                                                   p_shared,
                                                   a_element_op,
                                                   b_element_op,

@@ -169,19 +168,19 @@ __global__ void
     }
     else
     {
-        const long_index_t a_batch_offset =
+        const long_index_t a_group_offset =
             amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
-        const long_index_t b_batch_offset =
+        const long_index_t b_group_offset =
             amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
         const long_index_t a_n_offset =
             amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));

-        GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid + a_batch_offset + a_n_offset,
-                                                      p_bs_grid + b_batch_offset,
+        GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid + a_group_offset + a_n_offset,
+                                                      p_bs_grid + b_group_offset,
                                                       p_ds_grid_grp,
-                                                      p_e_grid + e_batch_offset + e_n_offset,
+                                                      p_e_grid + e_group_offset + e_n_offset,
                                                       p_shared,
                                                       a_element_op,
                                                       b_element_op,

@@ -283,7 +282,8 @@ template <index_t NDimSpatial,
           // in tuple for MultiAB), unpack if tuple was passed
           typename BComputeDataType = AComputeDataType,
-          LoopScheduler LoopSched   = make_default_loop_scheduler()>
+          LoopScheduler LoopSched   = make_default_loop_scheduler(),
+          index_t NumGroupsToMerge  = 1>
 struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
                                              ALayout,

@@ -302,6 +302,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
 {
     using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;

+    static_assert(NumGroupsToMerge >= 1);
+
     static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
     static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;

@@ -318,7 +320,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                                           ConvForwardSpecialization,
                                           true /*SplitN*/,
                                           ADataType,
-                                          EDataType>;
+                                          EDataType,
+                                          NumGroupsToMerge>;

     static constexpr auto matrix_padder =
         MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

@@ -517,7 +520,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
             {
                 static_for<0, NumATensor, 1>{}([&](auto i) {
                     // Init compute_ptr_offset_of_groups_ for multiple AB
-                    compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
+                    compute_ptr_offset_of_groups_.BatchStrideA_(i) =
+                        a_g_n_c_wis_strides[0] * NumGroupsToMerge;

                     // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
                     // type is not tuple)

@@ -545,7 +549,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                 });
                 static_for<0, NumBTensor, 1>{}([&](auto i) {
                     // Init compute_ptr_offset_of_groups_ for multiple AB
-                    compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
+                    compute_ptr_offset_of_groups_.BatchStrideB_(i) =
+                        b_g_k_c_xs_strides[0] * NumGroupsToMerge;

                     using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
                     // It is possible that one of the AB is a pointer and one is a tuple.

@@ -565,8 +570,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
             }
             else
             {
-                compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-                compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+                compute_ptr_offset_of_groups_.BatchStrideA_ =
+                    a_g_n_c_wis_strides[0] * NumGroupsToMerge;
+                compute_ptr_offset_of_groups_.BatchStrideB_ =
+                    b_g_k_c_xs_strides[0] * NumGroupsToMerge;
                 compute_ptr_offset_of_n_.BatchStrideA_ =
                     a_g_n_c_wis_strides[1] * conv_N_per_block_;

                 // p_as and p_bs are pointers

@@ -583,7 +590,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                 p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);

                 // D batch stride
-                compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
+                compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
+                    ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge;
                 compute_ptr_offset_of_n_.BatchStrideDs_(i) =
                     ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;

@@ -602,7 +610,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                 ds_grid_desc_m_n_(i) =
                     DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
             });

-            compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+            compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge;
             compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;

             // populate desc for Ds/E

@@ -726,7 +734,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                 arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
             const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
-            const index_t gdy = arg.num_group_;
+            const index_t gdy = arg.num_group_ / NumGroupsToMerge;
             const index_t gdz = num_workgroups_per_Conv_N;

             const auto K =

@@ -850,6 +858,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     {
         namespace ctc = tensor_layout::convolution;

+        const index_t G = arg.b_g_k_c_xs_lengths_[I0];
+        const index_t K = arg.b_g_k_c_xs_lengths_[I1];
+        const index_t C = arg.b_g_k_c_xs_lengths_[I2];
+
         // check device
         if(get_device_name() == "gfx908")
         {

@@ -898,6 +910,42 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                 }
             }
         }
+        else if constexpr(ConvForwardSpecialization ==
+                          ConvolutionForwardSpecialization::Filter3x3)
+        {
+            if(C != 1)
+            {
+                return false;
+            }
+            for(index_t i = 0; i < NDimSpatial; ++i)
+            {
+                const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
+                if(filter_spatial_dim != I3)
+                {
+                    return false;
+                }
+            }
+            if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
+            {
+                return false;
+            }
+        }
+        if constexpr(NumGroupsToMerge > 1)
+        {
+            if(!(C == 1))
+            {
+                return false;
+            }
+            if(G % NumGroupsToMerge != 0)
+            {
+                return false;
+            }
+            if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
+            {
+                return false;
+            }
+        }
+
         // check vector access of A
         // FIXME: layout

@@ -907,11 +955,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
            is_same_v<ALayout, ctc::NDHWGC>)
         {
-            const index_t C = arg.a_g_n_c_wis_lengths_[2];
-
+            // Check access per C
             if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
             {
-                return false;
+                // If not possible, check access per G
+                if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
+                     is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
+                     G % ABlockTransferSrcScalarPerVector == 0))
+                {
+                    return false;
+                }
             }
         }
         else

@@ -928,8 +981,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                      is_same_v<BLayout, ctc::KZYXGC>)
         {
-            const index_t C = arg.b_g_k_c_xs_lengths_[2];
-
             if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
             {
                 return false;

@@ -953,8 +1004,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
                is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
             {
-                const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
-
                 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
                 {
                     valid = false;

@@ -999,8 +1048,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
            is_same_v<ELayout, ctc::NDHWGK>)
         {
-            const index_t K = arg.e_g_n_k_wos_lengths_[2];
-
             if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
             {
                 return false;

@@ -1298,7 +1345,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
             << BBlockTransferSrcScalarPerVector << ", "
             << CDEBlockTransferScalarPerVector_NPerBlock << ", "
             << CShuffleMXdlPerWavePerShuffle << ", "
-            << CShuffleNXdlPerWavePerShuffle
+            << CShuffleNXdlPerWavePerShuffle << ", "
+            << NumGroupsToMerge
             << ">";
         // clang-format on
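The thread running through this file is the new NumGroupsToMerge parameter: each workgroup now covers several consecutive groups, so the per-group pointer strides are scaled up and the grid's group dimension shrinks. A minimal sketch of that arithmetic, in plain C++ with made-up values (not CK code):

    #include <cassert>
    #include <cstdint>

    int main()
    {
        constexpr int64_t NumGroupsToMerge = 4; // the new template parameter, >= 1
        const int64_t num_group            = 32;   // G from the problem description
        const int64_t group_stride         = 1024; // e.g. a_g_n_c_wis_strides[0]

        assert(num_group % NumGroupsToMerge == 0); // enforced in IsSupportedArgument

        // stride between consecutive workgroups along blockIdx.y spans merged groups
        const int64_t batch_stride = group_stride * NumGroupsToMerge;

        // grid shrinks accordingly: gdy = num_group_ / NumGroupsToMerge
        const int64_t gdy = num_group / NumGroupsToMerge;

        return (batch_stride == 4096 && gdy == 8) ? 0 : 1;
    }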
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp  (view file @ e536d321)

@@ -713,7 +713,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
                 return false;
             }
         }
-        if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
+        if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
         {
             return false;
         }
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp  (view file @ e536d321)

@@ -12,7 +12,7 @@ namespace device {
 // 1d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NWGK_GKXC_NWGC()
+constexpr bool is_NWGC_GKXC_NWGK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&

@@ -20,7 +20,7 @@ constexpr bool is_NWGC_GKXC_NWGK()
 }

 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNWK_GKXC_GNWC()
+constexpr bool is_GNWC_GKXC_GNWK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&

@@ -28,7 +28,7 @@ constexpr bool is_GNWC_GKXC_GNWK()
 }

 // 2d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NHWGK_GKYXC_NHWGC()
+constexpr bool is_NHWGC_GKYXC_NHWGK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&

@@ -36,15 +36,23 @@ constexpr bool is_NHWGC_GKYXC_NHWGK()
 }

 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNHWK_GKYXC_GNHWC()
+constexpr bool is_GNHWC_GKYXC_GNHWK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
            is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
 }

+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NGCHW_GKYXC_NGKHW()
+{
+    return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
+           is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
+           is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
+}
+
 // 3d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
+constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&

@@ -52,7 +60,7 @@ constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
 }

 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
+constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&

@@ -60,19 +68,27 @@ constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
 }

+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
+{
+    return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
+           is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
+           is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
+}
+
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC()
+constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
 {
-    return is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
-           is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-           is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>();
+    return is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
+           is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+           is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>();
 }

 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC()
+constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
 {
-    return is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>() ||
-           is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>() ||
-           is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>();
+    return is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>() ||
+           is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
+           is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
 }

 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
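The renamed predicates are plain constexpr layout matchers (the rename corrects the function names to match the layouts they actually test). A self-contained sketch of the pattern, using stand-in tags rather than the real tensor_layout types:

    #include <type_traits>

    // minimal stand-ins for the CK layout tags, for illustration only
    struct NWGC {};
    struct GKXC {};
    struct NWGK {};

    template <typename InLayout, typename WeiLayout, typename OutLayout>
    constexpr bool is_NWGC_GKXC_NWGK()
    {
        return std::is_same_v<InLayout, NWGC> && std::is_same_v<WeiLayout, GKXC> &&
               std::is_same_v<OutLayout, NWGK>;
    }

    static_assert(is_NWGC_GKXC_NWGK<NWGC, GKXC, NWGK>());  // channel-last 1d conv layouts
    static_assert(!is_NWGC_GKXC_NWGK<NWGK, GKXC, NWGC>()); // the old, misnamed argument order

    int main() { return 0; }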
include/ck/tensor_operation/gpu/device/tensor_layout.hpp  (view file @ e536d321)

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -115,6 +115,23 @@ struct NDHWGC : public BaseTensorLayout
     static constexpr const char* name = "NDHWGC";
 };

+// input tensor
+// packed NGCW/NGCHW/NGCDHW
+struct NGCW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGCW";
+};
+
+struct NGCHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGCHW";
+};
+
+struct NGCDHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGCDHW";
+};
+
 // input tensor
 // strided layout
 struct G_NW_C : public BaseTensorLayout

@@ -325,6 +342,21 @@ struct NDHWGK : public BaseTensorLayout
     static constexpr const char* name = "NDHWGK";
 };

+struct NGKW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGKW";
+};
+
+struct NGKHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGKHW";
+};
+
+struct NGKDHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGKDHW";
+};
+
 // output tensor
 // strided layout
 struct G_NW_K : public BaseTensorLayout
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp  (view file @ e536d321)

@@ -41,6 +41,55 @@ __global__ void
         elementwise_op);
 }

+template <typename GridwiseElementwiseFunctor,
+          typename InAGridDescTuple,
+          typename InBGridDescTuple,
+          typename OutAGridDescTuple,
+          typename OutBGridDescTuple,
+          typename InDataTypePointerTuple,
+          typename OutDataTypePointerTuple,
+          typename Block2TileMapA,
+          typename Block2TileMapB,
+          typename ElementwiseOperation>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a,
+                                const InBGridDescTuple in_grid_desc_tuple_b,
+                                const OutAGridDescTuple out_grid_desc_tuple_a,
+                                const OutBGridDescTuple out_grid_desc_tuple_b,
+                                const InDataTypePointerTuple p_in_global_tuple_a,
+                                const InDataTypePointerTuple p_in_global_tuple_b,
+                                const OutDataTypePointerTuple p_out_global_tuple_a,
+                                const OutDataTypePointerTuple p_out_global_tuple_b,
+                                const Block2TileMapA block_2_tile_map_a,
+                                const Block2TileMapB block_2_tile_map_b,
+                                const ElementwiseOperation elementwise_op,
+                                const index_t a_grid_size)
+{
+    if(get_block_1d_id() < a_grid_size)
+    {
+        GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
+                                        out_grid_desc_tuple_a,
+                                        p_in_global_tuple_a,
+                                        p_out_global_tuple_a,
+                                        block_2_tile_map_a,
+                                        elementwise_op,
+                                        get_block_1d_id());
+    }
+    else
+    {
+        GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
+                                        out_grid_desc_tuple_b,
+                                        p_in_global_tuple_b,
+                                        p_out_global_tuple_b,
+                                        block_2_tile_map_b,
+                                        elementwise_op,
+                                        get_block_1d_id() - a_grid_size);
+    }
+}
+
 template <typename GridwiseElementwiseFunctor,
           typename InGridDescTuple,
           typename OutGridDescTuple,

@@ -133,7 +182,8 @@ struct GridwiseElementwise
                     const InDataTypePointerTuple& p_in_global_tuple,
                     const OutDataTypePointerTuple& p_out_global_tuple,
                     const Block2TileMap& block_2_tile_map,
-                    const ElementwiseOperation& elementwise_op)
+                    const ElementwiseOperation& elementwise_op,
+                    const index_t block_id = get_block_1d_id())
     {
         constexpr auto src_datas = generate_tuple(

@@ -169,7 +219,7 @@ struct GridwiseElementwise
             Number<NumOutput>{});

         const auto block_work_idx =
-            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));

         const index_t m0_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
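The new kernel_elementwise_dual fuses two independent elementwise problems into a single launch by splitting the 1d block id at a_grid_size; blocks below the split run problem A and the rest run problem B with a rebased id. A sketch of that dispatch logic in isolation (illustrative only, not CK code):

    #include <cstdio>

    void dispatch(int block_id, int a_grid_size)
    {
        if(block_id < a_grid_size)
            std::printf("A: block %d\n", block_id);               // Run(..., block_id)
        else
            std::printf("B: block %d\n", block_id - a_grid_size); // Run(..., block_id - a_grid_size)
    }

    int main()
    {
        for(int b = 0; b < 6; ++b)
            dispatch(b, 4); // A gets blocks 0-3, B gets rebased blocks 0-1
        return 0;
    }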
include/ck_tile/core/config.hpp  (view file @ e536d321)

@@ -46,6 +46,7 @@
 #define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
+#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3

 #ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
 #define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE

@@ -156,6 +157,14 @@
 #endif
 #endif

+#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
+#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
+#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
+#else
+#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
+#endif
+#endif
+
 #ifndef CK_TILE_DEBUG_LOG
 #define CK_TILE_DEBUG_LOG 0
 #endif
include/ck_tile/core/numeric/bfloat16.hpp  (view file @ e536d321)

@@ -17,6 +17,7 @@ enum class bf16_rounding_mode
     standard = 0, // rtn
     truncate_with_nan,
     truncate,
+    standard_asm,
 };

 template <bf16_rounding_mode rounding =

@@ -148,6 +149,37 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
     return uint16_t(u.int32 >> 16);
 }

+CK_TILE_HOST constexpr uint16_t float_to_bf16_rtn_asm(float f)
+{
+    return float_to_bf16_rtn_raw(f);
+}
+
+CK_TILE_DEVICE uint16_t float_to_bf16_rtn_asm(float f)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {f};
+    static constexpr uint32_t FP32_NAN            = 0x7fff0000;
+    static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
+    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
+    uint32x2_t check_nan;
+    uint32_t tmp;
+    asm volatile("\n \
+            v_cmp_u_f32 %0, %2, %2 \n \
+            v_bfe_u32 %1, %2, 16, 1 \n \
+            v_add3_u32 %1, %2, %1, %3 \n \
+            v_cndmask_b32 %2, %1, %4, %0 \n \
+            v_lshrrev_b32 %2, 16, %2 \n \
+            "
+                 : "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
+                 : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
+    return uint16_t(u.int32);
+}
+
 // Truncate instead of rounding, preserving SNaN
 CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_truc_nan_raw(float f)

@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
 {
     if constexpr(rounding == bf16_rounding_mode::standard)
         return float_to_bf16_rtn_raw(f);
+    else if constexpr(rounding == bf16_rounding_mode::standard_asm)
+        return float_to_bf16_rtn_asm(f);
     else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
         return float_to_bf16_truc_nan_raw(f);
     else
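For orientation, the inline-asm sequence implements the usual round-to-nearest-even bf16 conversion with a NaN override: v_cmp_u_f32 detects NaN, v_bfe_u32 extracts the keep-bit, v_add3_u32 adds the 0x7fff rounding bias plus that bit, and the final shift keeps the top 16 bits. A portable host-side equivalent (a sketch, not the CK implementation):

    #include <cstdint>
    #include <cstring>

    uint16_t float_to_bf16_rne(float f)
    {
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        if(f != f)                          // v_cmp_u_f32: unordered means NaN
            return uint16_t(0x7fff0000u >> 16);
        const uint32_t lsb = (u >> 16) & 1; // v_bfe_u32: LSB of the kept mantissa
        u += 0x7fffu + lsb;                 // v_add3_u32: round-to-nearest-even bias
        return uint16_t(u >> 16);           // v_lshrrev_b32: keep the top 16 bits
    }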
include/ck_tile/core/numeric/math.hpp  (view file @ e536d321)

@@ -536,13 +536,20 @@ float log(float x) { return __logf(x); };
 CK_TILE_HOST
 float log(float x) { return std::logf(x); };

 CK_TILE_DEVICE
-uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
 {
     // TODO: this is hacky, we use u16
     return __builtin_amdgcn_sad_u16(x, y, acc);
 }

-CK_TILE_HOST
-uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+CK_TILE_DEVICE
+uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
+{
+    /// TODO: replace inline asm when intrinsic is available
+    uint32_t res;
+    asm volatile("v_sad_u32 %0, %1, %2, %3" : "=v"(res) : "v"(x), "v"(y), "v"(acc));
+    return res;
+}
+
+CK_TILE_HOST
+uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
 {
     return (x > y ? (x - y) : (y - x)) + acc;
 }
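The renamed helpers both compute the sum-of-absolute-difference primitive, |x - y| + acc, which v_sad_u32 evaluates in a single instruction on the GPU. A host-side illustration of the semantics (matches the fallback above):

    #include <cassert>
    #include <cstdint>

    uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
    {
        return (x > y ? (x - y) : (y - x)) + acc;
    }

    int main()
    {
        assert(sad_u32(10, 3, 0) == 7);
        assert(sad_u32(3, 10, 5) == 12); // absolute difference, plus accumulator
        return 0;
    }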
include/ck_tile/core/tensor/tile_window.hpp  (view file @ e536d321)

@@ -214,6 +214,12 @@ struct tile_window_with_static_distribution
     CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }

+    CK_TILE_DEVICE constexpr void
+    set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
+    {
+        bottom_tensor_view_.buf_.p_data_ = data;
+    }
+
     // move thread's window adaptor coordinate and bottom tensor coordinate
     // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
     CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(

@@ -393,7 +399,8 @@ struct tile_window_with_static_distribution
                     bottom_tensor_thread_coord,
                     bool_constant<oob_conditional_check>{},
                     pre_nop_);
-#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
+    CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
-                asm volatile("");
+                asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
 #endif

@@ -843,6 +850,17 @@ struct tile_window_with_static_lengths
     CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }

+    CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
+    {
+        window_origin_ = new_window_origin;
+    }
+
+    CK_TILE_DEVICE constexpr void
+    set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
+    {
+        bottom_tensor_view_.buf_.p_data_ = data;
+    }
+
     // move window-origin
     CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }

@@ -871,6 +889,39 @@ make_tile_window(const TensorView_& tensor_view,
         tensor_view, window_lengths, origin};
 }

+// duplicate tile window and replace its origin
+template <typename TensorView, typename WindowLengths>
+CK_TILE_DEVICE constexpr auto
+make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                 const multi_index<TensorView::get_num_of_dimension()>& origin)
+{
+    return tile_window_with_static_lengths<TensorView, WindowLengths>{
+        tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin};
+}
+
+template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
+CK_TILE_DEVICE constexpr auto
+make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                 const multi_index<TensorView::get_num_of_dimension()>& origin,
+                 const StaticTileDistribution& tile_distribution)
+{
+    return make_tile_window(tile_window.get_bottom_tensor_view(),
+                            tile_window.get_window_lengths(),
+                            origin,
+                            tile_distribution);
+}
+
+template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
+CK_TILE_DEVICE constexpr auto
+make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths>& tile_window,
+                 const StaticTileDistribution& tile_distribution)
+{
+    return make_tile_window(tile_window.get_bottom_tensor_view(),
+                            tile_window.get_window_lengths(),
+                            tile_window.get_window_origin(),
+                            tile_distribution);
+}
+
 template <typename TensorView_, typename WindowLengths_>
 CK_TILE_DEVICE void
 move_tile_window(tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
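The new make_tile_window overloads duplicate an existing window while replacing its origin (or attaching a tile distribution) instead of rebuilding it from the raw tensor view. The pattern, reduced to generic C++ with hypothetical types rather than the CK ones:

    #include <array>

    template <typename View, typename Lengths>
    struct window
    {
        View view;
        Lengths lengths;
        std::array<int, 2> origin;
    };

    // same view and lengths, new origin; the original window is untouched
    template <typename View, typename Lengths>
    window<View, Lengths> make_window(const window<View, Lengths>& w, std::array<int, 2> origin)
    {
        return {w.view, w.lengths, origin};
    }

    int main()
    {
        window<int, int> w{0, 0, {0, 0}};
        auto w2 = make_window(w, {64, 0}); // rebased copy
        return w2.origin[0] == 64 ? 0 : 1;
    }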
include/ck_tile/core/utility/type_traits.hpp  (view file @ e536d321)

@@ -22,6 +22,23 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
 template <typename T>
 using remove_pointer_t = typename std::remove_pointer<T>::type;

+template <typename From, typename To>
+struct copy_const
+{
+    static_assert(!std::is_const_v<From>);
+    using type = To;
+};
+
+template <typename From, typename To>
+struct copy_const<const From, To>
+{
+    using type = std::add_const_t<typename copy_const<From, To>::type>;
+};
+
+template <typename From, typename To>
+using copy_const_t = typename copy_const<From, To>::type;
+
 namespace detail {
 template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
 struct detector
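copy_const transfers const-qualification from one type onto another, which the new page-block navigator uses to accept either `void*` or `const void*` block pointers. A self-contained copy of the trait with a couple of static_assert checks:

    #include <type_traits>

    template <typename From, typename To>
    struct copy_const
    {
        static_assert(!std::is_const_v<From>);
        using type = To;
    };

    template <typename From, typename To>
    struct copy_const<const From, To>
    {
        using type = std::add_const_t<typename copy_const<From, To>::type>;
    };

    template <typename From, typename To>
    using copy_const_t = typename copy_const<From, To>::type;

    static_assert(std::is_same_v<copy_const_t<int, float>, float>);
    static_assert(std::is_same_v<copy_const_t<const int, float>, const float>);

    int main() { return 0; }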
include/ck_tile/host.hpp  (view file @ e536d321)

@@ -15,6 +15,7 @@
 #include "ck_tile/host/reference/reference_batched_elementwise.hpp"
 #include "ck_tile/host/reference/reference_batched_gemm.hpp"
 #include "ck_tile/host/reference/reference_batched_masking.hpp"
+#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
 #include "ck_tile/host/reference/reference_batched_softmax.hpp"
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
include/ck_tile/host/host_tensor.hpp  (view file @ e536d321)

@@ -155,7 +155,12 @@ struct HostTensorDescriptor
         return space;
     }

+    std::size_t get_length(std::size_t dim) const { return mLens[dim]; }
+
     const std::vector<std::size_t>& get_lengths() const { return mLens; }

+    std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; }
+
     const std::vector<std::size_t>& get_strides() const { return mStrides; }

     template <typename... Is>

@@ -325,8 +330,12 @@ struct HostTensor
     {
     }

+    std::size_t get_length(std::size_t dim) const { return mDesc.get_length(dim); }
+
     decltype(auto) get_lengths() const { return mDesc.get_lengths(); }

+    std::size_t get_stride(std::size_t dim) const { return mDesc.get_stride(dim); }
+
     decltype(auto) get_strides() const { return mDesc.get_strides(); }

     std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
include/ck_tile/host/kernel_launch.hpp  (view file @ e536d321)

@@ -73,17 +73,17 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
 {
     // clang-format off
     if(!s.time_kernel_)
     {
-        (callables(s),...); hip_check_error(hipGetLastError());
+        (callables(s),...); HIP_CHECK_ERROR(hipGetLastError());
         return 0;
     }

     if(s.is_gpu_timer_)
     {
         gpu_timer timer {};

         // warmup
-        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());

         timer.start(s.stream_id_);
-        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
         timer.stop(s.stream_id_);

         return timer.duration() / s.nrepeat_;

@@ -92,10 +92,10 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
         cpu_timer timer {};

         // warmup
-        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.cold_niters_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());

         timer.start(s.stream_id_);
-        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } hip_check_error(hipGetLastError());
+        for(int i = 0; i < s.nrepeat_; i++) { (callables(s),...); } HIP_CHECK_ERROR(hipGetLastError());
         timer.stop(s.stream_id_);

         return timer.duration() / s.nrepeat_;
include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp  (new file, 0 → 100644, view file @ e536d321)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <cassert>
#include <thread>

namespace ck_tile {

template <typename DataType, typename ComputeDataType = float>
CK_TILE_HOST void
reference_batched_rotary_position_embedding(const HostTensor<DataType>& input_bsd,
                                            const HostTensor<DataType>& cos_sd,
                                            const HostTensor<DataType>& sin_sd,
                                            bool interleaved,
                                            HostTensor<DataType>& output_bsd,
                                            bool use_1_row_sin_cos = false)
{
    assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
    assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
           cos_sd.get_length(1) == sin_sd.get_length(1));

    const index_t rotary_dim = cos_sd.get_length(1) * 2;
    assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));

    output_bsd.ForEach([&](auto& self, auto i) {
        const index_t i_d = i[2];
        if(rotary_dim <= i_d)
        {
            self(i) = input_bsd(i);
            return;
        }
        assert(i_d < rotary_dim);

        const index_t i_s         = i[1];
        const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s);

        const ComputeDataType cos = type_convert<ComputeDataType>(
            interleaved ? cos_sd(i_s_cos_sin, i_d / 2)
                        : cos_sd(i_s_cos_sin, i_d % cos_sd.get_length(1)));
        const ComputeDataType sin = type_convert<ComputeDataType>(
            interleaved ? sin_sd(i_s_cos_sin, i_d / 2)
                        : sin_sd(i_s_cos_sin, i_d % sin_sd.get_length(1)));

        const ComputeDataType half_rotated_input = [&] {
            const index_t i_b = i[0];
            if(interleaved)
            {
                const bool is_even         = (i_d % 2 == 0);
                const index_t pos          = i_d + (is_even ? 1 : -1);
                const ComputeDataType sign = (is_even ? -1 : 1);
                return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
            }
            else
            {
                const index_t half_rdim    = (rotary_dim / 2);
                const index_t pos          = (i_d + half_rdim) % rotary_dim;
                const ComputeDataType sign = (pos < half_rdim ? 1 : -1);
                return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
            }
        }();

        ComputeDataType result =
            type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;
        self(i) = type_convert<DataType>(result);
    });
}
} // namespace ck_tile
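The interleaved branch above rotates adjacent pairs: out[2i] = x[2i]*cos - x[2i+1]*sin and out[2i+1] = x[2i+1]*cos + x[2i]*sin. A tiny standalone check of that identity (not CK code; a rotation preserves the pair's norm):

    #include <cassert>
    #include <cmath>

    int main()
    {
        const float x0 = 1.0f, x1 = 0.5f; // one input pair
        const float c = std::cos(0.5f), s = std::sin(0.5f);

        const float y0 = x0 * c - x1 * s; // even element
        const float y1 = x1 * c + x0 * s; // odd element

        assert(std::fabs((y0 * y0 + y1 * y1) - (x0 * x0 + x1 * x1)) < 1e-6f);
        return 0;
    }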
include/ck_tile/ops/fmha.hpp  (view file @ e536d321)

@@ -7,7 +7,11 @@
 #include "ck_tile/ops/fmha/block/block_dropout.hpp"
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
 #include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
+#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
+#include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"

@@ -21,11 +25,11 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
include/ck_tile/ops/fmha/block/block_position_encoding.hpp  (view file @ e536d321)

@@ -43,9 +43,12 @@ enum struct AlibiMode
     FROM_BOTTOM_RIGHT = 2,
 };

-template <typename DataType, bool RowMajor = true>
+template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
 struct Alibi
 {
+    static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32,
+                  "for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t");
+
     // RowMajor here means if pixel within the same thread are along the row, or col
     // this may impact the performance of update(), while the result are the same.
     // e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false

@@ -79,6 +82,19 @@ struct Alibi
         mode = mode_;
     }

+    CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); }
+
+    CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+    {
+        if constexpr(LogMaxSadOprndSize <= 16)
+        {
+            return sad_u16(
+                static_cast<uint16_t>(x), static_cast<uint16_t>(y), static_cast<uint16_t>(acc));
+        }
+        return sad_u32(x, y, acc);
+    }
+
     CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
     {
         if constexpr(RowMajor)

@@ -128,7 +144,7 @@ struct EmptyPositionEncoding
 // can convert from the FA style left/right to our generic coordinate
 // if left_size < 0 && right_size = 0, it is normal causal mask
 // local is left_size >=0 or right_size >=0
-template <typename DataType, bool RowMajor = true>
+template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
 CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
                                                  index_t window_left_size,
                                                  index_t window_right_size,

@@ -142,7 +158,7 @@ CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
     AlibiMode alibi_mode = is_causal
                                ? AlibiMode::VERTICAL
                                : static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
-    return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
+    return Alibi<DataType, RowMajor, LogMaxSadOprndSize>{slope, y_total, x_total, alibi_mode};
 }

 // https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
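The added LogMaxSadOprndSize parameter lets Alibi pick the SAD flavor at compile time: row/column coordinates that fit in 16 bits can use the cheaper u16 instruction, larger ones fall back to u32. A sketch of that selection in isolation (illustrative, not the CK code):

    #include <cstdint>

    template <unsigned LogMaxSadOprndSize = 16>
    uint32_t alibi_distance(uint32_t row, uint32_t col)
    {
        if constexpr(LogMaxSadOprndSize <= 16)
        {
            // sad_u16 path: operands are known to fit in 16 bits
            return static_cast<uint16_t>(row > col ? row - col : col - row);
        }
        else
        {
            // sad_u32 path
            return row > col ? row - col : col - row;
        }
    }

    int main() { return alibi_distance<16>(7, 2) == 5 ? 0 : 1; }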
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp  (new file, 0 → 100644, view file @ e536d321)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>

namespace ck_tile {

// This class is used for codegen pattern matching
enum class RotaryEmbeddingEnum
{
    NONE         = 0,
    INTERLEAVED  = 1, // combine dimensions 0 & 1, 2 & 3, etc
    HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};

template <RotaryEmbeddingEnum>
struct RotaryEmbeddingEnumToStr;

template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::NONE>
{
    static constexpr const char* name = "";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::INTERLEAVED>
{
    static constexpr const char* name = "inter";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
{
    static constexpr const char* name = "half";
};

template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
struct BlockRotaryEmbedding
{
    template <typename DistributedTensor,
              typename OtherDramBlockWindow,
              typename RotaryCosDramBlockWindow,
              typename RotarySinDramBlockWindow>
    CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
                                          OtherDramBlockWindow other_window,
                                          RotaryCosDramBlockWindow rotary_cos_window,
                                          RotarySinDramBlockWindow rotary_sin_window,
                                          index_t rotary_dim,
                                          index_t thread_end)
    {
        using DataType = typename remove_cvref_t<DistributedTensor>::DataType;

        if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
        {
            auto rotary_cos_tile = load_tile(rotary_cos_window);
            auto rotary_sin_tile = load_tile(rotary_sin_window);

            if(thread_end <= rotary_dim)
            {
                constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
                static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
                    const auto left  = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
                    const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);

                    const auto cos =
                        type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
                    const auto sin =
                        type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);

                    tile.thread_buf_[idx]     = type_convert<DataType>(left * cos - right * sin);
                    tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
                });
            }
        }
        else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
        {
            if(thread_end <= rotary_dim)
            {
                const bool is_left = (thread_end <= (rotary_dim / 2));

                move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
                auto other_tile = load_tile(other_window);

                move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
                auto rotary_cos_tile = load_tile(rotary_cos_window);

                move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
                auto rotary_sin_tile = load_tile(rotary_sin_window);

                constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
                static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
                    const auto curr  = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
                    const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);

                    const auto cos =
                        type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
                    const auto sin =
                        type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);

                    tile.thread_buf_[idx] =
                        type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
                });
            }
        }
    }
};
} // namespace ck_tile
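The HALF_ROTATED branch pairs element i with element (i + rotary_dim/2) % rotary_dim, applying -sin on the first half and +sin on the second, which matches the host reference earlier in this commit. A standalone illustration (not CK code):

    #include <cassert>
    #include <cmath>
    #include <vector>

    int main()
    {
        const int rotary_dim = 4;
        std::vector<float> x{1.0f, 2.0f, 3.0f, 4.0f}, out(4);
        const float c = std::cos(0.5f), s = std::sin(0.5f);

        for(int i = 0; i < rotary_dim; ++i)
        {
            const int other  = (i + rotary_dim / 2) % rotary_dim; // the paired element
            const float sign = (i < rotary_dim / 2) ? -1.0f : 1.0f;
            out[i]           = x[i] * c + sign * x[other] * s;
        }
        assert(std::fabs(out[0] - (x[0] * c - x[2] * s)) < 1e-6f);
        return 0;
    }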
include/ck_tile/ops/fmha/block/page_block_navigator.hpp  (new file, 0 → 100644, view file @ e536d321)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"

namespace ck_tile {

// assume that we have only 1 page-block/tensor view
template <typename TensorView>
struct TrivialPageBlockNavigator
{
    using DataType     = typename TensorView::DataType;
    using WindowOrigin = multi_index<2>;

    CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_)
        : tensor_view(tensor_view_)
    {
    }

    template <typename WindowLengths>
    CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
                                                        const WindowOrigin& window_origin) const
    {
        return make_tuple(/*block_index=*/0,
                          ck_tile::make_tile_window(tensor_view, window_lengths, window_origin));
    }

    template <typename WindowLengths, typename TileDistribution>
    CK_TILE_HOST_DEVICE constexpr auto
    make_tile_window(const WindowLengths& window_lengths,
                     const WindowOrigin& window_origin,
                     const TileDistribution& tile_distribution) const
    {
        return make_tuple(/*block_index=*/0,
                          ck_tile::make_tile_window(
                              tensor_view, window_lengths, window_origin, tile_distribution));
    }

    template <typename TileWindow>
    CK_TILE_HOST_DEVICE static index_t
    move_tile_window(index_t /*block_index*/,
                     TileWindow& tile_window,
                     const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
    {
        ck_tile::move_tile_window(tile_window, step);
        return /*block_index=*/0;
    }

    CK_TILE_HOST_DEVICE static constexpr WindowOrigin
    to_local_window_origin(const WindowOrigin& global_window_origin)
    {
        return global_window_origin;
    }

    CK_TILE_HOST_DEVICE static constexpr WindowOrigin
    to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
    {
        return local_window_origin;
    }

    private:
    TensorView tensor_view;
};

// default page-block navigator, assume that tensor view size is same as page-block size or smaller
// if tile window on last page-block
template <typename DataType_, index_t VirtualDim, typename TensorView>
struct PageBlockNavigator
{
    using DataType = DataType_;
    static_assert(std::is_same_v<DataType, typename TensorView::DataType>);
    static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
    using WindowOrigin = multi_index<2>;

    CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t<DataType, void>* physical_blocks_,
                                                     long_index_t block_stride_,
                                                     long_index_t fixed_offset_,
                                                     const int32_t* physical_block_indices_,
                                                     index_t num_blocks_,
                                                     index_t page_block_size_,
                                                     const TensorView& complete_view_,
                                                     const TensorView& last_view_)
        : physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
          block_stride(block_stride_),
          fixed_offset(fixed_offset_),
          physical_block_indices(physical_block_indices_),
          num_blocks(num_blocks_),
          page_block_size(page_block_size_),
          complete_view(complete_view_),
          last_view(last_view_)
    {
    }

    template <typename WindowLengths>
    CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
                                              const WindowOrigin& window_origin) const
    {
        const index_t block_index              = get_block_index(window_origin);
        const WindowOrigin local_window_origin = to_local_window_origin(window_origin);

        auto new_tile_window =
            ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
                                      window_lengths,
                                      local_window_origin);
        new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));

        return make_tuple(block_index, new_tile_window);
    }

    template <typename WindowLengths, typename TileDistribution>
    CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
                                              const WindowOrigin& window_origin,
                                              const TileDistribution& tile_distribution) const
    {
        const index_t block_index              = get_block_index(window_origin);
        const WindowOrigin local_window_origin = to_local_window_origin(window_origin);

        auto new_tile_window =
            ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
                                      window_lengths,
                                      local_window_origin,
                                      tile_distribution);
        new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));

        return make_tuple(block_index, new_tile_window);
    }

    template <typename TileWindow>
    CK_TILE_HOST_DEVICE index_t
    move_tile_window(index_t block_index,
                     TileWindow& tile_window,
                     const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
    {
        ck_tile::move_tile_window(tile_window, step);

        const WindowOrigin global_window_origin =
            to_global_window_origin(block_index, tile_window.get_window_origin());
        const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
        const index_t new_block_index          = get_block_index(global_window_origin);

        /// TODO: only update necessary attributes
        tile_window.bottom_tensor_view_.desc_ =
            (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
        tile_window.set_window_origin(local_window_origin);
        tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));

        return new_block_index;
    }

    CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
    {
        return block_index == num_blocks - 1;
    }

    template <typename TileWindow>
    CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index,
                                            const TileWindow& tile_window) const
    {
        const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
        const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});
        return (block_index < num_blocks - 1) && (page_block_size < origin + length);
    }

    template <typename TileWindow>
    CK_TILE_HOST_DEVICE void
    move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
    {
        const multi_index<2> step = [&]() {
            const index_t origin_diff = (block_index - new_block_index) * page_block_size;
            if constexpr(VirtualDim == 0)
            {
                return make_multi_index(origin_diff, 0);
            }
            else
            {
                return make_multi_index(0, origin_diff);
            }
        }();

        /// TODO: only update necessary attributes
        tile_window.bottom_tensor_view_.desc_ =
            (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
        tile_window.set_window_origin(tile_window.get_window_origin() + step);
        tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
    }

    CK_TILE_HOST_DEVICE WindowOrigin
    to_local_window_origin(const WindowOrigin& global_window_origin) const
    {
        if constexpr(VirtualDim == 0)
        {
            const index_t length              = global_window_origin.at(number<0>{});
            const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
            return make_multi_index(length - page_block_size * num_complete_blocks,
                                    global_window_origin.at(number<1>{}));
        }
        else
        {
            const index_t length              = global_window_origin.at(number<1>{});
            const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
            return make_multi_index(global_window_origin.at(number<0>{}),
                                    length - page_block_size * num_complete_blocks);
        }
    }

    CK_TILE_HOST_DEVICE WindowOrigin
    to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
    {
        if constexpr(VirtualDim == 0)
        {
            return make_multi_index(block_index * page_block_size +
                                        local_window_origin.at(number<0>{}),
                                    local_window_origin.at(number<1>{}));
        }
        else
        {
            return make_multi_index(local_window_origin.at(number<0>{}),
                                    block_index * page_block_size +
                                        local_window_origin.at(number<1>{}));
        }
    }

    private:
    CK_TILE_HOST_DEVICE DataType* get_block_ptr(index_t block_index) const
    {
        return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
    }

    CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
    {
        return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}),
                                    page_block_size);
    }

    DataType* physical_blocks;
    long_index_t block_stride;
    long_index_t fixed_offset;
    const int32_t* physical_block_indices;
    index_t num_blocks;
    index_t page_block_size;

    TensorView complete_view;
    TensorView last_view;
};

template <typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view)
{
    return TrivialPageBlockNavigator<TensorView>(tensor_view);
}

template <typename DataType, index_t VirtualDim, typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t<DataType, void>* physical_blocks,
                                                   long_index_t block_stride,
                                                   long_index_t fixed_offset,
                                                   const int32_t* physical_block_indices,
                                                   index_t num_blocks,
                                                   index_t page_block_size,
                                                   const TensorView& complete_view,
                                                   const TensorView& last_view)
{
    return PageBlockNavigator<DataType, VirtualDim, TensorView>(physical_blocks,
                                                                block_stride,
                                                                fixed_offset,
                                                                physical_block_indices,
                                                                num_blocks,
                                                                page_block_size,
                                                                complete_view,
                                                                last_view);
}
} // namespace ck_tile
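The navigator's core bookkeeping is the split of a virtual position along VirtualDim into (block_index, local_offset) by page_block_size, and the inverse mapping back. In isolation (plain C++ sketch, not CK code):

    #include <cassert>

    int main()
    {
        const int page_block_size = 128;

        const int global      = 300;                                    // virtual position
        const int block_index = global / page_block_size;               // get_block_index
        const int local       = global - block_index * page_block_size; // to_local_window_origin

        assert(block_index == 2 && local == 44);
        assert(block_index * page_block_size + local == global); // to_global_window_origin
        return 0;
    }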
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp  (new file, 0 → 100644, view file @ e536d321)

[diff collapsed in the page view; the new 679-line kernel is not reproduced here]
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp  (new file, 0 → 100644, view file @ e536d321)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kN1_>
struct FmhaFwdAppendKVTilePartitioner
{
    static constexpr ck_tile::index_t kM0 = kM0_;
    static constexpr ck_tile::index_t kN0 = kN0_;
    static constexpr ck_tile::index_t kK0 = kK0_;
    static constexpr ck_tile::index_t kN1 = kN1_;

    static_assert(kK0 == kN1);

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t seqlen_q,
                                                ck_tile::index_t seqlen_knew)
    {
        // TODO: this may need tuning
        return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0),
                             ck_tile::integer_divide_ceil(seqlen_knew, kN0)),
                    nhead,
                    batch_size);
    }

    CK_TILE_DEVICE auto operator()()
    {
        const index_t i_tile  = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
    }
};
} // namespace ck_tile
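GridSize maps the x dimension to the larger of the Q-tile and new-K-tile counts, with heads on y and batches on z, so one launch covers both the append and the rotary update. A quick standalone check of that arithmetic (made-up tile sizes):

    #include <algorithm>
    #include <cassert>

    int main()
    {
        const int kM0 = 64, kN0 = 64;
        const int seqlen_q = 100, seqlen_knew = 300, nhead = 8, batch = 2;

        const int ceil_q = (seqlen_q + kM0 - 1) / kM0;    // integer_divide_ceil(seqlen_q, kM0)
        const int ceil_k = (seqlen_knew + kN0 - 1) / kN0; // integer_divide_ceil(seqlen_knew, kN0)
        const int gdx    = std::max(ceil_q, ceil_k);

        assert(gdx == 5);                 // max(2, 5)
        assert(nhead == 8 && batch == 2); // gdy, gdz
        return 0;
    }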