Project: gaoqiong/composable_kernel_ROCM

Commit f6ceef78, authored Aug 26, 2024 by ThomasNing

    merge with the develop branch

Parents: 536c5458, 25935b57
Changes: 240. Showing 20 changed files with 1571 additions and 608 deletions (+1571, -608).
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp                    +10    -1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp            +15    -3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp    +18   -10
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp   +417   -91
include/ck/utility/amd_buffer_addressing.hpp                                      +47    -8
include/ck/utility/dynamic_buffer.hpp                                              +4    -2
include/ck/utility/f8_utils.hpp                                                   +11   -11
include/ck/utility/type_convert.hpp                                               +21    -1
include/ck_tile/core/algorithm/coordinate_transform.hpp                            +9   -33
include/ck_tile/core/numeric/vector_type.hpp                                       +9    -0
include/ck_tile/core/tensor/tile_distribution.hpp                                  +3    -2
include/ck_tile/core/tensor/tile_window.hpp                                        +4    -1
include/ck_tile/core/utility/philox_rand.hpp                                      +33    -0
include/ck_tile/ops/fmha.hpp                                                       +3    -8
include/ck_tile/ops/fmha/block/block_dropout.hpp                                 +377   -16
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp                              +555  -332
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp                      +0   -54
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp                                +2    -4
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp               +21   -18
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp                       +12   -13
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -649,6 +649,15 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
                   const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1,
                   const CGridDesc_M_N& c_grid_desc_m_n)
     {
+        constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+
+        if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
+             b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB &&
+             c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB))
+        {
+            return false;
+        }
+
         const auto M  = a_grid_desc_b_k0_m_k1.GetLength(I2);
         const auto N  = b_grid_desc_b_k0_n_k1.GetLength(I2);
         const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1);
...
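The guard added above reflects the 31-bit range of AMD buffer-resource addressing: a tensor whose element-space size times element size exceeds 2 GB cannot be covered by a single buffer descriptor, so such problems are rejected up front. A minimal standalone sketch of the same arithmetic with made-up element counts (the real code queries each descriptor's GetElementSpaceSize()):

#include <cstdint>
#include <cstdio>

int main()
{
    using long_index_t = std::int64_t;
    constexpr long_index_t TwoGB = long_index_t{1} << 31;

    // Hypothetical element-space sizes for an fp16 A tensor and an fp32 C tensor.
    const long_index_t a_elems = 300'000'000;
    const long_index_t c_elems = 600'000'000;

    const bool a_fits = a_elems * long_index_t{sizeof(std::uint16_t)} <= TwoGB; // 0.6 GB -> fits
    const bool c_fits = c_elems * long_index_t{sizeof(float)} <= TwoGB;         // 2.4 GB -> rejected

    std::printf("A fits: %d, C fits: %d\n", a_fits, c_fits);
    return 0;
}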
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

...
@@ -417,6 +417,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
            }
        }();

+        // pad M and N
+        return transform_tensor_descriptor(
+            c_grid_desc_mraw_nraw,
+            make_tuple(make_right_pad_transform(M, MPad - M), make_right_pad_transform(N, NPad - N)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+#if 0
        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
...
@@ -454,6 +461,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
+#endif
    }

    struct Problem
...
@@ -953,7 +961,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
        {
            if(!(karg.M % MPerBlock == 0))
            {
...
@@ -970,7 +979,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     (is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
        {
            if(!(karg.N % NPerBlock == 0))
            {
...
@@ -1105,7 +1115,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
        }

-        if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
-                       is_same<remove_cvref_t<CDataType>, float>::value))
+        if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
+                       is_same<remove_cvref_t<CDataType>, float>::value ||
+                       is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
+                       is_same<remove_cvref_t<CDataType>, int32_t>::value))
        {
            if(!karg.IsReduceAdd())
            {
...
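With the old per-GemmSpec branching disabled behind #if 0, the C descriptor is now always right-padded from (M, N) to (MPad, NPad). For reference, the padded extents are typically the raw extents rounded up to the block tile; a small illustrative helper (names assumed, not taken from this file):

#include <cstdint>

using index_t = std::int32_t;

// Illustrative only: round a raw GEMM dimension up to the block tile so the
// right-pad transform gets a non-negative pad length (MPad - M).
constexpr index_t round_up_to_block(index_t raw, index_t per_block)
{
    return ((raw + per_block - 1) / per_block) * per_block;
}

static_assert(round_up_to_block(1000, 256) == 1024, "1000 is right-padded by 24 rows");
static_assert(round_up_to_block(1024, 256) == 1024, "already aligned, padded by 0");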
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp

...
@@ -36,10 +36,9 @@ __global__ void
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        // __attribute__((amdgpu_waves_per_eu(1, 1)))
-        kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
+        kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
...
@@ -56,7 +55,7 @@ __global__ void
                                                  karg.c_element_op);
#else
    ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
}

template <typename GridwiseGemm,
...
@@ -69,10 +68,9 @@ __global__ void
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        // __attribute__((amdgpu_waves_per_eu(1, 1)))
-        kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
+        kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // Pass two lds pointer is the key to tell compiler that ds_read/write
    // operate on different lds chunk at same time without order dependecy
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...
@@ -93,7 +91,7 @@ __global__ void
                                                      karg.c_element_op);
#else
    ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
}

template <typename ALayout,
...
@@ -454,6 +452,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
            }
        }();

+        // pad M and N
+        return transform_tensor_descriptor(
+            c_grid_desc_mraw_nraw,
+            make_tuple(make_right_pad_transform(M, MPad - M), make_right_pad_transform(N, NPad - N)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+#if 0
        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
...
@@ -491,6 +496,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
+#endif
    }

    __host__ __device__ static auto MakeDsGridDescriptor_M_N(
...
@@ -1016,7 +1022,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
        {
            if(!(karg.M % MPerBlock == 0))
            {
...
@@ -1033,7 +1040,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                      GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     (is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
        {
            if(!(karg.N % NPerBlock == 0))
            {
...
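The renamed _2lds kernel keeps its existing trick: two separate __shared__ arrays tell the compiler that LDS reads of one tile and writes of the next cannot alias, so they can overlap. A stripped-down illustration of that ping-pong pattern (kernel name, sizes and the 256-thread launch are invented for this sketch, not taken from the file):

// HIP/CUDA-style sketch of double-buffered LDS; launch with 256 threads per block.
__global__ void two_lds_sketch(const float* __restrict__ in, float* __restrict__ out, int num_tiles)
{
    __shared__ float lds_0[256];
    __shared__ float lds_1[256];
    float* lds[2] = {lds_0, lds_1};

    float acc = 0.0f;
    for(int t = 0; t < num_tiles; ++t)
    {
        float* buf = lds[t & 1]; // alternate between the two LDS chunks
        buf[threadIdx.x] = in[t * blockDim.x + threadIdx.x];
        __syncthreads();
        acc += buf[(threadIdx.x + 1) % blockDim.x]; // consume the freshly staged tile
        __syncthreads();
    }
    out[blockIdx.x * blockDim.x + threadIdx.x] = acc;
}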
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp

...
@@ -19,7 +19,8 @@ template <index_t NDimSpatial,
          bool SplitN              = false,
          typename ADataType       = float,
          typename CDataType       = float,
-         index_t NumGroupsToMerge = 1>
+         index_t NumGroupsToMerge = 1,
+         typename IndexType       = index_t>
struct TransformConvFwdToGemm
{
    private:
...
@@ -46,10 +47,10 @@ struct TransformConvFwdToGemm
    }

    template <typename ConvDimsType>
-    static index_t GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
-                                   const ConvDimsType& a_g_n_c_wis_strides,
-                                   const ConvDimsType& c_g_n_k_wos_lengths,
-                                   const ConvDimsType& c_g_n_k_wos_strides)
+    static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
+                                     const ConvDimsType& a_g_n_c_wis_strides,
+                                     const ConvDimsType& c_g_n_k_wos_lengths,
+                                     const ConvDimsType& c_g_n_k_wos_strides)
    {
        const long_index_t a_element_space_size =
            calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
...
@@ -59,7 +60,7 @@ struct TransformConvFwdToGemm
            c_element_space_size * sizeof(CDataType));
        constexpr long_index_t TwoGB = (long_index_t{1} << 31);

-        const index_t N = a_g_n_c_wis_lengths[I1];
+        const IndexType N = a_g_n_c_wis_lengths[I1];

        if(element_space_size > TwoGB)
        {
...
@@ -70,7 +71,7 @@ struct TransformConvFwdToGemm
            {
                // Find least divisor of N larger than element_space_size / TwoGB
                // Iterate up to sqrt(N). There are no divisors above this value.
-                for(index_t least_divisor = divisor; least_divisor * least_divisor <= N;
+                for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
                    least_divisor++)
                {
                    if(N % least_divisor == 0)
...
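GetSplitedNSize shrinks the batch dimension when a tensor would exceed the 2 GB descriptor limit: it looks for the smallest divisor of N that is at least as large as the required split factor, so that N still divides evenly and each piece fits. A small host-side sketch of that search (standalone, with assumed sizes; the ceiling division for the starting divisor is my assumption, the real code derives it from element_space_size / TwoGB):

#include <cstdint>
#include <cstdio>

using long_index_t = std::int64_t;

// Smallest divisor of n that is >= min_divisor; falls back to n itself when none is found.
long_index_t least_divisor_at_least(long_index_t n, long_index_t min_divisor)
{
    for(long_index_t d = min_divisor; d * d <= n; ++d)
    {
        if(n % d == 0)
            return d;
    }
    return n;
}

int main()
{
    constexpr long_index_t TwoGB = long_index_t{1} << 31;
    const long_index_t N                  = 96;            // batch size (assumed)
    const long_index_t element_space_size = 5'000'000'000; // bytes for the whole batch (assumed)

    const long_index_t min_div = (element_space_size + TwoGB - 1) / TwoGB; // 3
    const long_index_t d       = least_divisor_at_least(N, min_div);       // 3

    std::printf("run %lld pieces of N = %lld each\n", (long long)d, (long long)(N / d));
    return 0;
}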
@@ -98,6 +99,53 @@ struct TransformConvFwdToGemm
    public:
    __host__ __device__ constexpr TransformConvFwdToGemm() {}

+    template <typename TransformConvFwdToGemmBase>
+    __host__ __device__
+    TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base)
+        : N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
+          Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
+          Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
+          Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
+          Do_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Do_)},
+          Ho_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Ho_)},
+          Wo_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wo_)},
+          Z_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Z_)},
+          Y_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Y_)},
+          X_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.X_)},
+          K_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.K_)},
+          C_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.C_)},
+          DiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DiStride_)},
+          HiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HiStride_)},
+          WiStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WiStride_)},
+          DoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.DoStride_)},
+          HoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.HoStride_)},
+          WoStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.WoStride_)},
+          XStride_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.XStride_)},
+          CStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorA_)},
+          CStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.CStrideTensorB_)},
+          KStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorB_)},
+          KStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.KStrideTensorC_)},
+          NStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorA_)},
+          NStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.NStrideTensorC_)},
+          GStrideTensorA_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorA_)},
+          GStrideTensorB_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorB_)},
+          GStrideTensorC_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.GStrideTensorC_)},
+          ConvStrideD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideD_)},
+          ConvStrideH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideH_)},
+          ConvStrideW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideW_)},
+          ConvDilationD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationD_)},
+          ConvDilationH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationH_)},
+          ConvDilationW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationW_)},
+          InLeftPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadD_)},
+          InLeftPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadH_)},
+          InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
+          InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
+          InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
+          InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
+          ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
+    {
+    }
+
    template <typename ConvDimsType,
              typename ConvSpatialDimsType,
              index_t NDim = NDimSpatial,
...
@@ -126,6 +174,8 @@ struct TransformConvFwdToGemm
          DiStride_{I1},
          HiStride_{I1},
          WiStride_{a_g_n_c_wis_strides[I3]},
+         DoStride_{I1},
+         HoStride_{I1},
          WoStride_{c_g_n_k_wos_strides[I3]},
          XStride_{b_g_k_c_xs_strides[I3]},
          CStrideTensorA_{a_g_n_c_wis_strides[I2]},
...
@@ -133,6 +183,7 @@ struct TransformConvFwdToGemm
          KStrideTensorB_{b_g_k_c_xs_strides[I1]},
          KStrideTensorC_{c_g_n_k_wos_strides[I2]},
          NStrideTensorA_{a_g_n_c_wis_strides[I1]},
+         NStrideTensorC_{c_g_n_k_wos_strides[I1]},
          GStrideTensorA_{a_g_n_c_wis_strides[I0]},
          GStrideTensorB_{b_g_k_c_xs_strides[I0]},
          GStrideTensorC_{c_g_n_k_wos_strides[I0]},
...
@@ -150,10 +201,10 @@ struct TransformConvFwdToGemm
          InRightPadW_{input_right_pads[I0]},
          ZYX_{X_}
    {
-        static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
-                      is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
-        static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
-                      is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
+        static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
+                      is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
+                      is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);

        if constexpr(SplitN)
        {
...
@@ -164,7 +215,6 @@ struct TransformConvFwdToGemm
        {
            N_ = c_g_n_k_wos_lengths[I1];
        }
-        NDoHoWo_ = N_ * Wo_;
    }

    template <typename ConvDimsType,
...
@@ -195,6 +245,8 @@ struct TransformConvFwdToGemm
          DiStride_{I1},
          HiStride_{a_g_n_c_wis_strides[I3]},
          WiStride_{a_g_n_c_wis_strides[I4]},
+         DoStride_{I1},
+         HoStride_{c_g_n_k_wos_strides[I3]},
          WoStride_{c_g_n_k_wos_strides[I4]},
          XStride_{b_g_k_c_xs_strides[I4]},
          CStrideTensorA_{a_g_n_c_wis_strides[I2]},
...
@@ -202,6 +254,7 @@ struct TransformConvFwdToGemm
          KStrideTensorB_{b_g_k_c_xs_strides[I1]},
          KStrideTensorC_{c_g_n_k_wos_strides[I2]},
          NStrideTensorA_{a_g_n_c_wis_strides[I1]},
+         NStrideTensorC_{c_g_n_k_wos_strides[I1]},
          GStrideTensorA_{a_g_n_c_wis_strides[I0]},
          GStrideTensorB_{b_g_k_c_xs_strides[I0]},
          GStrideTensorC_{c_g_n_k_wos_strides[I0]},
...
@@ -219,10 +272,10 @@ struct TransformConvFwdToGemm
          InRightPadW_{input_right_pads[I1]},
          ZYX_{Y_ * X_}
    {
-        static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
-                      is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
-        static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
-                      is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
+        static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
+                      is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
+                      is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);

        if constexpr(SplitN)
        {
...
@@ -233,7 +286,6 @@ struct TransformConvFwdToGemm
        {
            N_ = c_g_n_k_wos_lengths[I1];
        }
-        NDoHoWo_ = N_ * Ho_ * Wo_;
    }

    template <typename ConvDimsType,
...
@@ -264,6 +316,8 @@ struct TransformConvFwdToGemm
          DiStride_{a_g_n_c_wis_strides[I3]},
          HiStride_{a_g_n_c_wis_strides[I4]},
          WiStride_{a_g_n_c_wis_strides[I5]},
+         DoStride_{c_g_n_k_wos_strides[I3]},
+         HoStride_{c_g_n_k_wos_strides[I4]},
          WoStride_{c_g_n_k_wos_strides[I5]},
          XStride_{b_g_k_c_xs_strides[I5]},
          CStrideTensorA_{a_g_n_c_wis_strides[I2]},
...
@@ -271,6 +325,7 @@ struct TransformConvFwdToGemm
          KStrideTensorB_{b_g_k_c_xs_strides[I1]},
          KStrideTensorC_{c_g_n_k_wos_strides[I2]},
          NStrideTensorA_{a_g_n_c_wis_strides[I1]},
+         NStrideTensorC_{c_g_n_k_wos_strides[I1]},
          GStrideTensorA_{a_g_n_c_wis_strides[I0]},
          GStrideTensorB_{b_g_k_c_xs_strides[I0]},
          GStrideTensorC_{c_g_n_k_wos_strides[I0]},
...
@@ -288,10 +343,10 @@ struct TransformConvFwdToGemm
          InRightPadW_{input_right_pads[I2]},
          ZYX_{Z_ * Y_ * X_}
    {
-        static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
-                      is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
-        static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
-                      is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
+        static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
+                      is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
+                      is_same_v<ConvDimsType, ck::Array<IndexType, NDimSpatial + I3>>);

        if constexpr(SplitN)
        {
...
@@ -302,7 +357,122 @@ struct TransformConvFwdToGemm
        {
            N_ = c_g_n_k_wos_lengths[I1];
        }
-        NDoHoWo_ = N_ * Do_ * Ho_ * Wo_;
    }

+    __host__ bool AreDescriptorsSmallerThan2GB() const
+    {
+        constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+
+        const long_index_t in_desc_space_size =
+            I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
+            (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_;
+        const long_index_t out_desc_space_size =
+            I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
+            (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_;
+
+        bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB;
+        bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB;
+
+        return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
+    }
+
+    __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
+                                   CDataType* c_grid_ptr_base) const
+    {
+        // Create copies
+        auto conv_to_gemm_transformer_left  = *this;
+        auto conv_to_gemm_transformer_right = *this;
+        IndexType a_right_offset = 0;
+        IndexType c_right_offset = 0;
+        // Calculate real filter size
+        const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
+        const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
+        const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
+        // Calculate start position in input for right tensor
+        const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
+        const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
+        const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
+        // Calculate last position in input for left tensor
+        const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
+        const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
+        const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
+        // Allow to split if whole left padding will be in left tensor and right padding in right
+        // tensor
+        const bool is_possible_to_split_d = Do_ != 1 &&
+                                            di_right_transformer_start_idx > InLeftPadD_ &&
+                                            di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
+        const bool is_possible_to_split_h = Ho_ != 1 &&
+                                            hi_right_transformer_start_idx > InLeftPadH_ &&
+                                            hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
+        const bool is_possible_to_split_w = Wo_ != 1 &&
+                                            wi_right_transformer_start_idx > InLeftPadW_ &&
+                                            wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
+
+        if(is_possible_to_split_d)
+        {
+            // Apply new sizes
+            // Split output on half
+            conv_to_gemm_transformer_left.Do_  = Do_ / 2;
+            conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
+            // Assign left padding to left convolution
+            conv_to_gemm_transformer_left.InLeftPadD_  = InLeftPadD_;
+            conv_to_gemm_transformer_right.InLeftPadD_ = 0;
+            // Assign right padding to right convolution
+            conv_to_gemm_transformer_left.InRightPadD_  = 0;
+            conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
+            // Calculate new input size
+            conv_to_gemm_transformer_left.Di_  = di_left_transformer_end_idx - InLeftPadD_;
+            conv_to_gemm_transformer_right.Di_ =
+                math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
+                          (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
+            // Calcualte offsets
+            a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
+            c_right_offset = (Do_ / 2) * DoStride_;
+        }
+        else if(is_possible_to_split_h)
+        {
+            conv_to_gemm_transformer_left.Ho_  = Ho_ / 2;
+            conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
+
+            conv_to_gemm_transformer_left.InLeftPadH_  = InLeftPadH_;
+            conv_to_gemm_transformer_right.InLeftPadH_ = 0;
+
+            conv_to_gemm_transformer_left.InRightPadH_  = 0;
+            conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
+
+            conv_to_gemm_transformer_left.Hi_  = hi_left_transformer_end_idx - InLeftPadH_;
+            conv_to_gemm_transformer_right.Hi_ =
+                math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
+                          (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
+
+            a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
+            c_right_offset = (Ho_ / 2) * HoStride_;
+        }
+        else if(is_possible_to_split_w)
+        {
+            conv_to_gemm_transformer_left.Wo_  = Wo_ / 2;
+            conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
+
+            conv_to_gemm_transformer_left.InLeftPadW_  = InLeftPadW_;
+            conv_to_gemm_transformer_right.InLeftPadW_ = 0;
+
+            conv_to_gemm_transformer_left.InRightPadW_  = 0;
+            conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
+
+            conv_to_gemm_transformer_left.Wi_  = wi_left_transformer_end_idx - InLeftPadW_;
+            conv_to_gemm_transformer_right.Wi_ =
+                math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
+                          (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
+
+            a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
+            c_right_offset = (Wo_ / 2) * WoStride_;
+        }
+        // Return left transform, right transformer, right offset to Input and right offset to
+        // Output
+        return ck::make_tuple(conv_to_gemm_transformer_left,
+                              conv_to_gemm_transformer_right,
+                              a_grid_ptr_base + a_right_offset,
+                              c_grid_ptr_base + c_right_offset);
+    }
+
    // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
...
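SplitConvProblem halves one output spatial dimension (D, H, or W) and derives the matching input window for each half so that each half's descriptors can stay under the 2 GB limit; the split is only taken when the left half can absorb all left padding and the right half all right padding. The key arithmetic is the effective (dilated) filter extent and the input range covered by a block of output positions; a tiny standalone check of those formulas with invented values:

#include <cstdio>

int main()
{
    // Assumed 1D example: output position o reads padded-input positions
    // [o*stride, o*stride + x_eff).
    const int X = 3, dilation = 2, stride = 1, left_pad = 2;
    const int Wo = 10, Wi = 10;

    const int x_eff = (X - 1) * dilation + 1; // effective filter width = 5

    // Right half starts at output Wo/2; left half ends at output Wo/2 - 1.
    const int wi_right_start = (Wo / 2) * stride;             // 5
    const int wi_left_end    = (Wo / 2 - 1) * stride + x_eff; // 9

    const bool can_split_w =
        Wo != 1 && wi_right_start > left_pad && wi_left_end <= left_pad + Wi;

    std::printf("x_eff=%d, left half ends at %d, right half starts at %d, can split: %d\n",
                x_eff, wi_left_end, wi_right_start, can_split_w);
    return 0;
}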
@@ -320,20 +490,27 @@ struct TransformConvFwdToGemm
    {
        if constexpr(NumGroupsToMerge == 1)
        {
-            return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
-                                                make_tuple(WiStride_, CStrideTensorA_));
+            const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
+                make_tuple(N_, Wo_, C_), make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
+            return transform_tensor_descriptor(
+                in_gemmm_gemmk_desc,
+                make_tuple(make_merge_transform(make_tuple(N_, Wo_)), make_pass_through_transform(C_)),
+                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
-                make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
-                make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
+                make_tuple(N_, Wo_, NumGroupsToMerge, C_),
+                make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
            return transform_tensor_descriptor(
                in_gemmm_groups_gemmk_desc,
-                make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
+                make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
                           make_pass_through_transform(C_)),
-                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
...
@@ -527,20 +704,29 @@ struct TransformConvFwdToGemm
    {
        if constexpr(NumGroupsToMerge == 1)
        {
-            return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
-                                                make_tuple(WiStride_, CStrideTensorA_));
+            const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
+                make_tuple(N_, Ho_, Wo_, C_),
+                make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
+            return transform_tensor_descriptor(
+                in_gemmm_gemmk_desc,
+                make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
+                           make_pass_through_transform(C_)),
+                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
-                make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
-                make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
+                make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_),
+                make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
            return transform_tensor_descriptor(
                in_gemmm_groups_gemmk_desc,
-                make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
+                make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
                           make_pass_through_transform(C_)),
-                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
...
@@ -759,20 +945,34 @@ struct TransformConvFwdToGemm
    {
        if constexpr(NumGroupsToMerge == 1)
        {
-            return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
-                                                make_tuple(WiStride_, CStrideTensorA_));
+            const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
+                make_tuple(N_, Do_, Ho_, Wo_, C_),
+                make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
+            return transform_tensor_descriptor(
+                in_gemmm_gemmk_desc,
+                make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
+                           make_pass_through_transform(C_)),
+                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
-                make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
-                make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
+                make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_),
+                make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
            return transform_tensor_descriptor(
                in_gemmm_groups_gemmk_desc,
-                make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
-                           make_pass_through_transform(C_)),
-                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
+                make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
+                           make_pass_through_transform(C_)),
+                make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }
...
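These A-descriptor changes keep the same implicit-GEMM view of the convolution input: the GEMM M dimension is the merged (N, Do, Ho, Wo) output space and the GEMM K dimension runs over the filter window and channels; the difference is that N is now carried as its own dimension with NStrideTensorA_ instead of being pre-folded into NDoHoWo_. For orientation, the shapes involved (a bookkeeping-only sketch with assumed sizes, not code from the file):

#include <cstdio>

int main()
{
    // Assumed 2D forward convolution sizes.
    const int N = 8, Ho = 28, Wo = 28; // output spatial extent
    const int Y = 3, X = 3, C = 64;    // filter extent and input channels
    const int K = 128;                 // output channels

    // Implicit-GEMM dimensions produced by the conv -> GEMM transform.
    const long GemmM = static_cast<long>(N) * Ho * Wo; // merged (N, Ho, Wo)
    const long GemmK = static_cast<long>(Y) * X * C;   // merged (Y, X, C)
    const long GemmN = K;

    std::printf("GEMM view: M=%ld, N=%ld, K=%ld\n", GemmM, GemmN, GemmK);
    return 0;
}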
@@ -1119,45 +1319,70 @@ struct TransformConvFwdToGemm
    }

    template <typename CLayout,
-              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
-                                          is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
-                                          is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
+              index_t NDimSp = NDimSpatial,
+              typename std::enable_if<NDimSp == 1 &&
+                                          (is_same_v<CLayout, tensor_layout::convolution::G_K>),
                                      bool>::type = false>
    __host__ __device__ auto MakeCDescriptor_M_N() const
    {
-        return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_));
+        return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_),
+                                            make_tuple(I0, KStrideTensorC_));
    }

    template <typename CLayout,
-              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
-                                          is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
-                                          is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
-                                          is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
-                                          is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
-                                          is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
-                                      bool>::type = false>
+              index_t NDimSp = NDimSpatial,
+              typename std::enable_if<NDimSp == 2 &&
+                                          (is_same_v<CLayout, tensor_layout::convolution::G_K>),
+                                      bool>::type = false>
+    __host__ __device__ auto MakeCDescriptor_M_N() const
+    {
+        return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_),
+                                            make_tuple(I0, KStrideTensorC_));
+    }
+
+    template <typename CLayout,
+              index_t NDimSp = NDimSpatial,
+              typename std::enable_if<NDimSp == 3 &&
+                                          (is_same_v<CLayout, tensor_layout::convolution::G_K>),
+                                      bool>::type = false>
+    __host__ __device__ auto MakeCDescriptor_M_N() const
+    {
+        return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_),
+                                            make_tuple(I0, KStrideTensorC_));
+    }
+
+    template <typename CLayout,
+              index_t NDimSp = NDimSpatial,
+              typename std::enable_if<NDimSp == 1 &&
+                                          (is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
+                                           is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
+                                           is_same_v<CLayout, tensor_layout::convolution::GNWK>),
+                                      bool>::type = false>
    __host__ __device__ auto MakeCDescriptor_M_N() const
    {
+        const IndexType NDoHoWo = N_ * Wo_;
        if constexpr(NumGroupsToMerge == 1)
        {
-            return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_),
+            return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
                                                make_tuple(WoStride_, KStrideTensorC_));
        }
        else
        {
            const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
-                make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1),
-                make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
+                make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
+                make_tuple(NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
            // Padd 1 to NumGroupsToMerge
            const auto padded_desc = transform_tensor_descriptor(
                nhwo_groups_k_1_desc,
-                make_tuple(make_pass_through_transform(NDoHoWo_),
+                make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
                           make_pass_through_transform(NumGroupsToMerge),
                           make_pass_through_transform(K_),
                           make_pad_transform(1, 0, NumGroupsToMerge - 1)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+                make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
            // We need only matrices from diagonal. X_or returns 0 for the same
            // values. So if matrices is not on diagonal then it will be stored in padding.
...
@@ -1167,7 +1392,7 @@ struct TransformConvFwdToGemm
                          NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
            const auto unmerged_padded_desc = transform_tensor_descriptor(
                padded_desc,
-                make_tuple(make_pass_through_transform(NDoHoWo_),
+                make_tuple(make_pass_through_transform(NDoHoWo),
                           make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
                           make_pass_through_transform(K_)),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
...
@@ -1175,45 +1400,146 @@
            // Merge To M, N
            return transform_tensor_descriptor(
                unmerged_padded_desc,
-                make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
+                make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
                           make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
    }

-    // for output bias
    template <typename CLayout,
-              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
+              index_t NDimSp = NDimSpatial,
+              typename std::enable_if<NDimSp == 2 &&
+                                          (is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
+                                           is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
+                                           is_same_v<CLayout, tensor_layout::convolution::GNHWK>),
                                      bool>::type = false>
    __host__ __device__ auto MakeCDescriptor_M_N() const
    {
-        const auto out_gemmm_gemmn_desc =
-            make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_));
-
-        return out_gemmm_gemmn_desc;
+        const IndexType NDoHoWo = N_ * Ho_ * Wo_;
+        if constexpr(NumGroupsToMerge == 1)
+        {
+            return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
+                                                make_tuple(WoStride_, KStrideTensorC_));
+        }
+        else
+        {
+            const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
+                make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
+                make_tuple(NStrideTensorC_, HoStride_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
+            // Padd 1 to NumGroupsToMerge
+            const auto padded_desc = transform_tensor_descriptor(
+                nhwo_groups_k_1_desc,
+                make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
+                           make_pass_through_transform(NumGroupsToMerge),
+                           make_pass_through_transform(K_),
+                           make_pad_transform(1, 0, NumGroupsToMerge - 1)),
+                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+            // We need only matrices from diagonal. X_or returns 0 for the same
+            // values. So if matrices is not on diagonal then it will be stored in padding.
+            // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
+            static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
+                          NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
+                          NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
+            const auto unmerged_padded_desc = transform_tensor_descriptor(
+                padded_desc,
+                make_tuple(make_pass_through_transform(NDoHoWo),
+                           make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
+                           make_pass_through_transform(K_)),
+                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
+            // Merge To M, N
+            return transform_tensor_descriptor(
+                unmerged_padded_desc,
+                make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
+                           make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
+                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+    }
+
+    template <typename CLayout,
+              index_t NDimSp = NDimSpatial,
+              typename std::enable_if<NDimSp == 3 &&
+                                          (is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
+                                           is_same_v<CLayout, tensor_layout::convolution::NDHWGK> ||
+                                           is_same_v<CLayout, tensor_layout::convolution::GNDHWK>),
+                                      bool>::type = false>
+    __host__ __device__ auto MakeCDescriptor_M_N() const
+    {
+        const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
+        if constexpr(NumGroupsToMerge == 1)
+        {
+            return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
+                                                make_tuple(WoStride_, KStrideTensorC_));
+        }
+        else
+        {
+            const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
+                make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
+                make_tuple(NStrideTensorC_, DoStride_, HoStride_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
+            // Padd 1 to NumGroupsToMerge
+            const auto padded_desc = transform_tensor_descriptor(
+                nhwo_groups_k_1_desc,
+                make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
+                           make_pass_through_transform(NumGroupsToMerge),
+                           make_pass_through_transform(K_),
+                           make_pad_transform(1, 0, NumGroupsToMerge - 1)),
+                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
+            // We need only matrices from diagonal. X_or returns 0 for the same
+            // values. So if matrices is not on diagonal then it will be stored in padding.
+            // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
+            static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
+                          NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
+                          NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
+            const auto unmerged_padded_desc = transform_tensor_descriptor(
+                padded_desc,
+                make_tuple(make_pass_through_transform(NDoHoWo),
+                           make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
+                           make_pass_through_transform(K_)),
+                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
+            // Merge To M, N
+            return transform_tensor_descriptor(
+                unmerged_padded_desc,
+                make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
+                           make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
+                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+    }

-    public:
-    index_t N_;
-
-    private:
-    const index_t Di_, Hi_, Wi_;
-    const index_t Do_, Ho_, Wo_;
-    const index_t Z_, Y_, X_;
-    const index_t K_, C_;
-    const index_t DiStride_, HiStride_, WiStride_;
-    const index_t WoStride_;
-    const index_t XStride_;
-    const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
-    const index_t NStrideTensorA_;
-    const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
-    const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_;
-    const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_;
-    const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_;
-    const index_t InRightPadD_, InRightPadH_, InRightPadW_;
-    const index_t ZYX_;
-    index_t NDoHoWo_;
+    IndexType N_;
+    IndexType Di_, Hi_, Wi_;
+    IndexType Do_, Ho_, Wo_;
+    IndexType Z_, Y_, X_;
+    IndexType K_, C_;
+    IndexType DiStride_, HiStride_, WiStride_;
+    IndexType DoStride_, HoStride_, WoStride_;
+    IndexType XStride_;
+    IndexType CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
+    IndexType NStrideTensorA_, NStrideTensorC_;
+    IndexType GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
+    IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_;
+    IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
+    IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
+    IndexType InRightPadD_, InRightPadH_, InRightPadW_;
+    IndexType ZYX_;
};

// wrapper class to call member functions on TransformConvToGemm struct at runtime
...
@@ -1230,17 +1556,17 @@ struct TransformConv
        if(NDimSpatial == 2)
        {
            return conv_fwd_to_gemm
-                .template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>();
+                .template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK, 2>();
        }
        else if(NDimSpatial == 3)
        {
            return conv_fwd_to_gemm
-                .template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>();
+                .template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK, 3>();
        }
        else if(NDimSpatial == 1)
        {
            return conv_fwd_to_gemm
-                .template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>();
+                .template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK, 1>();
        }
    }
};
...
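The NumGroupsToMerge path relies on an xor trick: after padding the singleton dimension up to NumGroupsToMerge, make_xor_transform maps the (group, pad) index pair so that only the diagonal entries (equal indices, xor = 0) land on real data and every off-diagonal entry falls into padding; requiring a power-of-two group count keeps the xor result inside the dimension without a modulo. A tiny standalone illustration of that index mapping (plain integers, not CK code):

#include <cstdio>

int main()
{
    const int groups = 4; // power of two, as the static_assert above requires

    // For upper indices (g, p) the xor swizzle reads lower index (g, g ^ p);
    // only p == g gives g ^ p == 0, the single real (unpadded) slot.
    for(int g = 0; g < groups; ++g)
        for(int p = 0; p < groups; ++p)
            std::printf("up(%d,%d) -> low(%d,%d)%s\n",
                        g, p, g, g ^ p, (g ^ p) == 0 ? "  <- real data" : "");
    return 0;
}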
include/ck/utility/amd_buffer_addressing.hpp

...
@@ -562,6 +562,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                          dst_wave_addr_offset);
}

+template <typename T, index_t N>
+__device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
+                                           T* addr)
+{
+    static_assert((is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)) ||
+                      (is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)),
+                  "wrong! not implemented");
+
+    if constexpr(is_same<T, half_t>::value)
+    {
+        vector_type<half_t, N> tmp{src_thread_data};
+        static_for<0, N / 2, 1>{}([&](auto i) {
+            __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
+                                                      tmp.template AsType<half2_t>()[i]);
+        });
+    }
+#if defined(__gfx942__)
+    else if constexpr(is_same<T, bhalf_t>::value)
+    {
+        vector_type<bhalf_t, N> tmp{src_thread_data};
+        static_for<0, N / 2, 1>{}([&](auto i) {
+            __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
+                                                       tmp.template AsType<bhalf2_t>()[i]);
+        });
+    }
+#endif
+}
+
template <typename T, index_t N>
__device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
                                           int32x4_t dst_wave_buffer_resource,
...
@@ -907,18 +935,29 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
    using scalar_t = typename scalar_type<vector_t>::type;
    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

+    if constexpr(is_same<T, bhalf_t>::value)
+    {
+        if(dst_thread_element_valid)
+        {
+            amd_global_atomic_add_impl<scalar_t, vector_size>(
+                src_thread_data, p_dst_wave + dst_thread_element_offset);
+        }
+    }
+    else
+    {
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
-    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
-    amd_buffer_atomic_add_impl<scalar_t, vector_size>(
-        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+        uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
+        amd_buffer_atomic_add_impl<scalar_t, vector_size>(
+            src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
-    if(dst_thread_element_valid)
-    {
-        amd_buffer_atomic_add_impl<scalar_t, vector_size>(
-            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
-    }
+        if(dst_thread_element_valid)
+        {
+            amd_buffer_atomic_add_impl<scalar_t, vector_size>(
+                src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+        }
#endif
+    }
}

// buffer_atomic_max requires:
...
include/ck/utility/dynamic_buffer.hpp

...
@@ -358,13 +358,15 @@ struct DynamicBuffer
            bool constexpr use_amd_buffer_addressing =
                is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
                is_same_v<remove_cvref_t<scalar_t>, float> ||
-                (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
+                (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
+                (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
            bool constexpr use_amd_buffer_addressing =
                is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
            bool constexpr use_amd_buffer_addressing =
                is_same_v<remove_cvref_t<scalar_t>, float> ||
-                (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
+                (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
+                (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#else
            bool constexpr use_amd_buffer_addressing = false;
#endif
...
include/ck/utility/f8_utils.hpp

...
@@ -44,7 +44,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
    // convert to bitwise
    using T_bitwise = typename NumericUtils<X>::bitwise_type;
-    T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
+    T_bitwise x_bitwise = bit_cast<T_bitwise>(x);

    // unpack the input, depends on datatype
    head = x_bitwise & NumericUtils<X>::head_mask;
...
@@ -165,7 +165,7 @@ In this case, the fp16 mantissa should be shift left by 1 */
    if(out_exponent > max_exp)
    {
-        if(clip)
+        if constexpr(clip)
        {
            mantissa     = (1 << out_mant) - 1;
            out_exponent = max_exp;
...
@@ -196,18 +196,17 @@ __host__ __device__ Y run_cast_from_f8(X x)
    // prepare the codes
    constexpr X nan_code = 0x80;
-    Y Inf, NegInf, NaN, Neg0;
    using T_bitwise = typename NumericUtils<Y>::bitwise_type;

    constexpr T_bitwise Inf_bitwise    = NumericUtils<Y>::Inf;
    constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
    constexpr T_bitwise NaN_bitwise    = NumericUtils<Y>::NaN;
    constexpr T_bitwise Neg0_bitwise   = NumericUtils<Y>::Neg0;

-    Inf    = *(reinterpret_cast<const Y*>(&Inf_bitwise));
-    NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
-    NaN    = *(reinterpret_cast<const Y*>(&NaN_bitwise));
-    Neg0   = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
+    constexpr Y Inf    = bit_cast<Y>(Inf_bitwise);
+    constexpr Y NegInf = bit_cast<Y>(NegInf_bitwise);
+    constexpr Y NaN    = bit_cast<Y>(NaN_bitwise);
+    constexpr Y Neg0   = bit_cast<Y>(Neg0_bitwise);

    // check if x is 0.0
    if(x == 0)
...
@@ -235,11 +234,12 @@ __host__ __device__ Y run_cast_from_f8(X x)
        return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
    }

-    if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
+    if constexpr((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) &&
+                 !negative_zero_nan)
    {
        retval = x;
        retval <<= 8;
-        return *(reinterpret_cast<const Y*>(&retval));
+        return bit_cast<Y>(retval);
    }

    // subnormal input
...
@@ -263,7 +263,7 @@ __host__ __device__ Y run_cast_from_f8(X x)
    }

    retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
-    return *(reinterpret_cast<const Y*>(&retval));
+    return bit_cast<Y>(retval);
}
} // namespace
...
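The recurring change in this file is mechanical: type punning through *(reinterpret_cast<T*>(&x)) is replaced with bit_cast, which expresses the same reinterpretation without dereferencing an incompatible pointer and stays usable in constant expressions. As a reference point, the standard-library equivalent (C++20 std::bit_cast; ck's helper plays the same role):

#include <bit>
#include <cstdint>
#include <cstdio>

int main()
{
    const float f = 1.0f;

    // Reinterpret the value representation, not a pointer to it.
    const std::uint32_t bits = std::bit_cast<std::uint32_t>(f);

    std::printf("0x%08x\n", bits); // prints 0x3f800000
    return 0;
}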
include/ck/utility/type_convert.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck/utility/data_type.hpp"
 #include "ck/utility/f8_utils.hpp"
 #include "ck/utility/random_gen.hpp"
+#include "ck/utility/array.hpp"

 namespace ck {

 // Define the common macro for gfx94x models
...
@@ -500,6 +501,25 @@ inline __host__ __device__ half_t type_convert<half_t, bf8_t>(bf8_t x)
#endif
}

+template <typename Y, typename X, std::size_t NumElems>
+inline __host__ __device__ void array_convert(std::array<Y, NumElems>& y,
+                                              const std::array<X, NumElems>& x)
+{
+    for(std::size_t i = 0; i < NumElems; i++)
+    {
+        y[i] = type_convert<Y>(x[i]);
+    }
+}
+
+template <typename Y, typename X, index_t NumElems>
+inline __host__ __device__ void array_convert(Array<Y, NumElems>& y, const Array<X, NumElems>& x)
+{
+    for(std::size_t i = 0; i < NumElems; i++)
+    {
+        y[i] = type_convert<Y>(x[i]);
+    }
+}
+
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
...
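The new array_convert helpers simply apply type_convert element-wise over std::array or ck::Array. A minimal usage sketch with plain standard types (the real intent is converting between packed low-precision types such as half_t and f8_t; the stand-in helper below mirrors the calling pattern, it is not the library function):

#include <array>
#include <cstdio>

// Stand-in for ck::array_convert to show the calling pattern.
template <typename Y, typename X, std::size_t NumElems>
void array_convert(std::array<Y, NumElems>& y, const std::array<X, NumElems>& x)
{
    for(std::size_t i = 0; i < NumElems; i++)
        y[i] = static_cast<Y>(x[i]); // ck uses type_convert<Y>(x[i]) here
}

int main()
{
    std::array<float, 4>  src{1.5f, 2.25f, -3.0f, 0.125f};
    std::array<double, 4> dst{};
    array_convert(dst, src);
    std::printf("%f %f %f %f\n", dst[0], dst[1], dst[2], dst[3]);
    return 0;
}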
include/ck_tile/core/algorithm/coordinate_transform.hpp

...
@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
};

// 2D XOR, NOTE: "xor" is a keyword
-template <typename LowLengths, typename RightShift>
+template <typename LowLengths>
struct xor_t : public base_transform<2, 2>
{
    static constexpr auto type_enum = coord_transform_enum::xor_t;
...
@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
    using UpLengths = LowLengths;

    UpLengths up_lengths_;
-    RightShift right_shift_;

-    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
+    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}

-    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths, const RightShift& right_shift)
-        : up_lengths_{low_lengths}, right_shift_{right_shift}
-    {
-    }
+    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
...
@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
        idx_low(number<0>{}) = idx_up[number<0>{}];

-        const auto idx_low_1_tmp = (idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) %
-                                   up_lengths_[number<1>{}];
-
-        const auto idx_low_1 =
-            (idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
-
-        idx_low(number<1>{}) = idx_low_1;
+        idx_low(number<1>{}) =
+            idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
...
@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
-        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
-               ck_tile::is_known_at_compile_time<RightShift>::value;
+        return ck_tile::is_known_at_compile_time<UpLengths>::value;
    }

    // MUST be static function
...
@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
        array<index_t, 2> up_vector_lengths = low_vector_lengths;
        array<index_t, 2> up_vector_strides = low_vector_strides;

-        if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
-        {
-            if(low_vector_lengths[1] != -1)
-            {
-                up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
-            }
-        }
-
        return make_tuple(up_vector_lengths, up_vector_strides);
    }
...
@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
        print(up_lengths_);
        printf(", ");

-        //
-        printf("right_shift_: ");
-        print(right_shift_);
-
        printf("}");
    }
};
...
@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
    return modulo<Modulus, UpLength>{modulus, up_length};
}

-template <typename LowLengths, typename RightShift>
-CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
-                                                      const RightShift& right_shift)
+template <typename LowLengths>
+CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
{
-    return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
+    return xor_t<LowLengths>{low_lengths};
}

template <typename LowLength, typename OffsetLength>
...
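The simplified xor_t drops the configurable RightShift and hard-codes the swizzle low1 = up1 ^ (up0 % len1); because xor is its own inverse, applying the same formula twice returns the original index. A small standalone check of those two properties (plain integers rather than the ck_tile number<> types; assumes len1 is a power of two, which keeps the result in range without a modulo fix-up):

#include <cassert>
#include <cstdio>

int main()
{
    const int len1 = 8; // length of the second (xor-ed) dimension, a power of two

    for(int up0 = 0; up0 < 16; ++up0)
    {
        for(int up1 = 0; up1 < len1; ++up1)
        {
            const int low1 = up1 ^ (up0 % len1);  // the swizzle used by xor_t
            assert(low1 >= 0 && low1 < len1);     // stays inside the dimension
            assert((low1 ^ (up0 % len1)) == up1); // applying it twice is the identity
        }
    }
    std::printf("xor swizzle is a per-row bijection\n");
    return 0;
}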
include/ck_tile/core/numeric/vector_type.hpp

...
@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
 using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
 using int32x64_t = int32_t __attribute__((ext_vector_type(64)));

+// u32
+// using uint32_t = ...
+using uint32x2_t  = uint32_t __attribute__((ext_vector_type(2)));
+using uint32x4_t  = uint32_t __attribute__((ext_vector_type(4)));
+using uint32x8_t  = uint32_t __attribute__((ext_vector_type(8)));
+using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
+using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
+using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
+
 // i16
 // using int16_t = ...
 using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
...
include/ck_tile/core/tensor/tile_distribution.hpp

...
@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
    return make_tuple(
        make_static_tile_distribution(
            tile_distribution_encoding<typename Encoding::RsLengths,
-                                       decltype(sliced_h_lengths), // only need to change the
-                                                                   // h_lengths type
+                                       remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
+                                                                                   // change the
+                                                                                   // h_lengths type
                                       typename Encoding::Ps2RHssMajor,
                                       typename Encoding::Ps2RHssMinor,
                                       typename Encoding::Ys2RHsMajor,
...
include/ck_tile/core/tensor/tile_window.hpp

...
@@ -393,7 +393,10 @@ struct tile_window_with_static_distribution
                bottom_tensor_thread_coord,
                bool_constant<oob_conditional_check>{},
                pre_nop_);
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+            asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
+#endif
            // move thread coordinate
            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
            {
...
include/ck_tile/core/utility/philox_rand.hpp

...
@@ -53,6 +53,39 @@ class philox
        out_tmp[3] = tmp_ph.w;
    }

+    CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
+                                            const unsigned long long subsequence,
+                                            const index_t start_idx) const
+    {
+        uint4 tmp_ph;
+        tmp_ph = get_philox_4x32(subsequence);
+
+        uint32x4_t tmp;
+        tmp[0] = tmp_ph.x;
+        tmp[1] = tmp_ph.y;
+        tmp[2] = tmp_ph.z;
+        tmp[3] = tmp_ph.w;
+
+        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
+        out_tmp[0] = tmp[start_idx];
+        out_tmp[1] = tmp[start_idx + 2];
+    }
+
+    CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
+                                            const unsigned long long subsequence,
+                                            const index_t start_idx) const
+    {
+        uint4 tmp_ph;
+        tmp_ph = get_philox_4x32(subsequence);
+
+        uint32x4_t tmp;
+        tmp[0] = tmp_ph.x;
+        tmp[1] = tmp_ph.y;
+        tmp[2] = tmp_ph.z;
+        tmp[3] = tmp_ph.w;
+
+        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
+        out_tmp[0] = tmp[start_idx];
+    }
+
    private:
    struct ull2
    {
...
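Both new helpers draw one 128-bit Philox block and hand back a slice of it as raw bytes: get_random_4x8 copies one 32-bit word (4 random uint8 values) and get_random_8x8 copies the words at start_idx and start_idx + 2 (8 values). A host-side sketch of only the byte-slicing step, with a fixed fake Philox block standing in for the class's generator:

#include <cstdint>
#include <cstdio>
#include <cstring>

int main()
{
    // Pretend this came from one Philox 4x32 invocation.
    const std::uint32_t block[4] = {0xdeadbeefu, 0x01234567u, 0x89abcdefu, 0xfeedfaceu};

    const int start_idx = 1;

    // get_random_8x8-style slice: words start_idx and start_idx + 2 give 8 random bytes.
    std::uint8_t out[8];
    std::memcpy(out, &block[start_idx], 4);
    std::memcpy(out + 4, &block[start_idx + 2], 4);

    for(int i = 0; i < 8; ++i)
        std::printf("%02x ", out[i]);
    std::printf("\n");
    return 0;
}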
include/ck_tile/ops/fmha.hpp

...
@@ -8,21 +8,16 @@
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
 #include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
-#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
...
include/ck_tile/ops/fmha/block/block_dropout.hpp View file @ f6ceef78
...
@@ -286,11 +286,226 @@ struct BlockDropout
        });
    }

    ck_tile::philox ph;
    const float rp_undrop;
    const uint8_t p_undrop_in_uint8_t;
    const bool is_store_randval;
};

template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd;

template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
{
    static constexpr bool IsDropout      = false;
    static constexpr bool IsStoreRandval = IsStoreRandval_;

    template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
    __host__ __device__ static constexpr auto
    MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                          index_t seqlen_qk_start)
    {
        (void)randval_dram_block_window_tmp;
        (void)seqlen_qk_start;
        return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
    }
};

template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
{
    static constexpr bool IsDropout = true;
    // true: 32*32 warp gemm
    // false: 16*16 warp gemm
    static constexpr bool IsWG32         = IsWG32_;
    static constexpr bool IsStoreRandval = IsStoreRandval_;

    CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch,
                                        index_t i_head,
                                        index_t nheads,
                                        unsigned long long seed,
                                        unsigned long long offset,
                                        float rp_undrop_,
                                        uint8_t p_undrop_in_uint8_t_)
        : ph(seed,
             offset + (i_batch * nheads + i_head) * get_warp_size() +
                 (IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
          rp_undrop(rp_undrop_),
          p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
    {
    }

    template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
    CK_TILE_HOST_DEVICE static constexpr auto
    MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
                          index_t seqlen_qk_start)
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using BlockGemmShape         = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
        using WG                     = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t kMPerBlock = BlockGemmShape::kM;
        constexpr index_t MWarp      = config.template at<1>();
        constexpr index_t NWarp      = config.template at<2>();
        constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
        constexpr index_t kMPerStep           = [&]() {
            if constexpr(MBwdWG16MultiIterCheck) { return MWarp * WG::kM * 2; }
            else { return MWarp * WG::kM; }
        }();
        constexpr index_t kNPerStep = NWarp * WG::kN;

        const auto block_origin  = randval_dram_block_window_tmp.get_window_origin();
        auto randval_dram_window = [&]() {
            if constexpr(IsFwd)
            {
                return make_tile_window(
                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
                    {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
            }
            else
            {
                return make_tile_window(
                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
                    {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
            }
        }();
        return randval_dram_window;
    }

    template <typename BlockGemm>
    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                    = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp     = config.template at<1>();
        constexpr index_t kMPerStep = MWarp * WG::kM;
        constexpr index_t kNPerStep = WG::kN;
        constexpr index_t kN1       = 8;
        constexpr index_t kN0       = kNPerStep / kN1;

        constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
            ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
            ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
            number<kN1>{},
            number<1>{});

        constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
            randval_lds_block_desc_0,
            ck_tile::make_tuple(
                make_pass_through_transform(number<kMPerStep>{}),
                make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
            ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
            ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));

        return randval_lds_block_desc;
    }

    template <typename BlockGemm, bool IsFwd = true>
    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using BlockGemmShape         = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
        constexpr index_t kMPerBlock = BlockGemmShape::kM;
        constexpr index_t MWarp      = config.template at<1>();
        constexpr index_t NWarp      = config.template at<2>();
        constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
        constexpr index_t MIterPerWarp        = [&]() {
            if constexpr(MBwdWG16MultiIterCheck) { return 2; }
            else { return 1; }
        }();
        constexpr index_t NIterPerWarp = 1;

        constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
        // except headdim256.
        constexpr auto randval_block_inner_part_dstr_encoding = []() {
            if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
                         std::is_same_v<typename BlockGemm::BDataType, half_t> &&
                         std::is_same_v<typename BlockGemm::CDataType, float>)
            {
                if constexpr(IsWG32)
                    return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
                else
                    return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
            }
            else
            {
                if constexpr(IsWG32)
                    return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
                else
                    return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
            }
        }();

        constexpr auto randval_block_part_dstr_encode =
            detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
                                                          randval_block_inner_part_dstr_encoding);
        return make_static_tile_distribution(randval_block_part_dstr_encode);
    }

    template <typename BlockGemm>
    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                       = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp        = config.template at<1>();
        constexpr index_t NWarp        = config.template at<2>();
        constexpr index_t MIterPerWarp = 1;
        constexpr index_t NIterPerWarp = 1;

        constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
            sequence<>,
            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
            tuple<sequence<1, 2>>,
            tuple<sequence<1, 1>>,
            sequence<1, 2>,
            sequence<0, 0>>{};

        constexpr auto randval_block_part_dstr_encode =
            detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
                                                          typename WG::CWarpDstrEncoding{});
        return make_static_tile_distribution(randval_block_part_dstr_encode);
    }

    template <typename BlockGemm,
              typename PComputeDataType,
              typename RandValOutputDataType,
              typename PComputeWindow,
              typename RandValDramWindow>
    CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
    CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
                                 const index_t start_m0_idx,
                                 const index_t start_n0_idx,
                                 PComputeWindow& p_compute,
                                 RandValDramWindow& randval_dram_window) const
    {
...
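The constructor above folds (batch, head, lane) into the Philox offset so every lane draws from its own reproducible stream; for the 16x16 warp-gemm layout the lane index is remapped via (lane & 47) + ((warp & 1) << 4). A host-side sketch of just that offset arithmetic (plain integers instead of device intrinsics; a 64-lane wavefront is assumed):

// Sketch only: the Philox offset arithmetic from the BlockDropoutBwd constructor above.
#include <cstdint>
#include <cstdio>

constexpr uint64_t warp_size = 64; // assumption: wavefront of 64 lanes

uint64_t philox_offset(uint64_t base_offset, uint64_t i_batch, uint64_t i_head,
                       uint64_t nheads, uint64_t lane_id, uint64_t warp_id, bool is_wg32)
{
    const uint64_t lane_term =
        is_wg32 ? lane_id                                  // 32x32 warp gemm
                : ((lane_id & 47) + ((warp_id & 1) << 4)); // 16x16 warp gemm lane remap
    return base_offset + (i_batch * nheads + i_head) * warp_size + lane_term;
}

int main()
{
    std::printf("%llu\n",
                (unsigned long long)philox_offset(/*base*/ 0, /*batch*/ 1, /*head*/ 2,
                                                  /*nheads*/ 8, /*lane*/ 17, /*warp*/ 1, false));
    return 0;
}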
@@ -305,30 +520,177 @@ struct BlockDropout
        constexpr index_t kMPerStep = MWarp * WG::kM;
        constexpr index_t kNPerStep = NWarp * WG::kN;

        // randval tile in LDS
        auto randval_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());

        auto randval_lds_window = make_tile_window(
            randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});

        // register distribute
        auto randval =
        auto randval_dist_generated =
            make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
        static_assert(randval.kThreadElementSpaceSize == 16);
        static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);

        const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});

        static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
            static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
                int block_row_start = (start_m0_idx / WG::kM) + i_m0;
                int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
        auto randval_lds_read_window =
            make_tile_window(randval_lds_window.get_bottom_tensor_view(),
                             randval_lds_window.get_window_lengths(),
                             randval_lds_window.get_window_origin(),
                             MakeRandValLdsShuffleTileDistribution<BlockGemm>());

        static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
            static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
                int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
                int block_col_start = (start_n0_idx / WG::kN) + i_n0;
                uint2 rowcol        = make_uint2(block_row_start, block_col_start);

                // generate random number
                uint8_t random_uint8_t[16];
                ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));

                constexpr auto randval_dist_generated_spans =
                    decltype(randval_dist_generated)::get_distributed_spans();
                int i_random_idx = 0;
                sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto i_j_idx          = ck_tile::make_tuple(idx0, idx1);
                        randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
                    });
                });

                // save to LDS
                store_tile(randval_lds_window, randval_dist_generated);
                block_sync_lds();

                // read from LDS to register
                auto randval = load_tile(randval_lds_read_window);

                constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
                int i_random_idx             = 0;
                sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
                        constexpr auto p_idx1 =
                            tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
                        constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
                        constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
                        randval(r_idx)       = random_uint8_t[i_random_idx++];
                        constexpr auto p_idx0 =
                            tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
                        p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
                                               ? p_compute[p_idx] * rp_undrop
                                               : PComputeDataType(0);
                    });
                });

                // save to Global
                if constexpr(IsStoreRandval)
                {
                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
                    store_tile(randval_dram_window, randval_store);
                    move_tile_window(randval_dram_window, {0, kNPerStep});
                }
            });
            if constexpr(IsStoreRandval)
            {
                move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
            }
        });
        if constexpr(IsStoreRandval)
        {
            move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
        }
    }

    template <typename BlockGemm,
              typename RandValOutputDataType,
              typename PComputeWindow,
              typename RandValDramWindow>
    CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
                                 const index_t start_n0_idx,
                                 PComputeWindow& p_compute,
                                 RandValDramWindow& randval_dram_window) const
    {
        constexpr auto config =
            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
        using WG                     = remove_cvref_t<decltype(config.template at<0>())>;
        constexpr index_t MWarp      = config.template at<1>();
        constexpr index_t NWarp      = config.template at<2>();
        using BlockGemmShape         = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
        constexpr index_t kMPerBlock = BlockGemmShape::kM;
        constexpr index_t kNPerBlock = BlockGemmShape::kN;
        constexpr bool MBwdWG16MultiIterCheck  = (!IsWG32) && (kMPerBlock > 16);
        constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
        constexpr index_t kMPerStep            = [&]() {
            if constexpr(MBwdWG16MultiIterCheck) { return MWarp * WG::kM * 2; }
            else { return MWarp * WG::kM; }
        }();
        constexpr index_t kNPerStep = NWarp * WG::kN;

        // register distribute
        auto randval = make_static_distributed_tensor<uint8_t>(
            MakeRandValTileDistribution<BlockGemm, false>());
        if constexpr(IsWG32)
            static_assert(randval.kThreadElementSpaceSize == 16);
        else
            static_assert(randval.kThreadElementSpaceSize == 4 ||
                          randval.kThreadElementSpaceSize == 8);

        static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
            static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
                int block_row_start, block_col_start;
                if constexpr(IsWG32)
                {
                    block_row_start = (start_m0_idx / WG::kM) + i_m0;
                    block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
                }
                else
                {
                    block_row_start = start_m0_idx / 32 + i_m0;
                    block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
                }
                uint2 rowcol = make_uint2(block_row_start, block_col_start);

                // generate random number
                uint8_t* random_uint8_t_;
                if constexpr(MBwdWG16SingleIterCheck)
                {
                    uint8_t random_uint8_t[4];
                    // m0t0 ~m0t15/m0t32~m0t47: 0
                    // m0t16~m0t31/m0t48~m0t63: 1
                    // m1t0 ~m1t15/m1t32~m1t47: 2
                    // m1t16~m1t31/m1t48~m1t63: 3
                    const index_t start_idx =
                        ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
                    ph.get_random_4x8(
                        random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
                    random_uint8_t_ = random_uint8_t;
                }
                else if constexpr(MBwdWG16MultiIterCheck)
                {
                    uint8_t random_uint8_t[8];
                    // t0 ~t15/t32~t47: 0
                    // t16~t31/t48~t63: 1
                    const index_t start_idx = (get_lane_id() >> 4) & 1;
                    ph.get_random_8x8(
                        random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
                    random_uint8_t_ = random_uint8_t;
                }
                else
                {
                    uint8_t random_uint8_t[16];
                    ph.get_random_16x8(random_uint8_t,
                                       reinterpret_cast<unsigned long long&>(rowcol));
                    random_uint8_t_ = random_uint8_t;
                }

                constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
                int i_random_idx             = 0;
                sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
                        constexpr auto r_idx  = ck_tile::make_tuple(idx0, idx1);
                        randval(r_idx)        = random_uint8_t_[i_random_idx++];
                        constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
                                                                       idx0.impl_.at(1),
                                                                       idx0.impl_.at(2)>{};
                        constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
                        constexpr auto p_idx  = ck_tile::make_tuple(p_idx0, p_idx1);
                        p_compute(p_idx)      = randval[r_idx] <= p_undrop_in_uint8_t
...
@@ -337,19 +699,19 @@ struct BlockDropout
                    });
                });
                // save to Global
                if(is_store_randval)
                if constexpr(IsStoreRandval)
                {
                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
                    store_tile(randval_dram_window, randval_store);
                    move_tile_window(randval_dram_window, {kMPerStep, 0});
                }
            });
            if(is_store_randval)
            if constexpr(IsStoreRandval)
            {
                move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
            }
        });
        if(is_store_randval)
        if constexpr(IsStoreRandval)
        {
            move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
        }
...
@@ -358,7 +720,6 @@ struct BlockDropout
    ck_tile::philox ph;
    const float rp_undrop;
    const uint8_t p_undrop_in_uint8_t;
    const bool is_store_randval;
};

} // namespace ck_tile
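The comparison randval <= p_undrop_in_uint8_t together with the rp_undrop multiplier implements inverted dropout on 8-bit random values. A small host-side sketch of that arithmetic under the obvious assumptions (threshold derived from the keep probability, scale = 1/keep; the kernel receives both values precomputed):

// Sketch only: the uint8 dropout test used above, written out on the host.
#include <cstdint>
#include <cstdio>

float apply_dropout(float x, uint8_t rand_u8, float p_drop)
{
    const uint8_t p_undrop_in_uint8_t = static_cast<uint8_t>(255.f * (1.f - p_drop)); // assumption
    const float   rp_undrop           = 1.f / (1.f - p_drop);
    // keep and rescale, or zero out - matches: randval <= threshold ? x * rp_undrop : 0
    return rand_u8 <= p_undrop_in_uint8_t ? x * rp_undrop : 0.f;
}

int main()
{
    std::printf("%f %f\n", apply_dropout(1.0f, 10, 0.2f), apply_dropout(1.0f, 250, 0.2f));
    return 0;
}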
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp View file @ f6ceef78
...
@@ -23,13 +23,9 @@
namespace ck_tile {

template <typename TilePartitioner_,
          typename FmhaPipeline_,
          typename KGradEpiloguePipeline_,
          typename VGradEpiloguePipeline_>
template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
struct FmhaBwdDQDKDVKernel
{
    using TilePartitioner       = ck_tile::remove_cvref_t<TilePartitioner_>;
    using FmhaPipeline          = ck_tile::remove_cvref_t<FmhaPipeline_>;
    using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
    using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
...
@@ -59,9 +55,12 @@ struct FmhaBwdDQDKDVKernel
    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
    static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
    static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
    static constexpr bool kHasDropout  = FmhaPipeline::kHasDropout;
    using FmhaMask                     = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
    static constexpr bool kHasMask     = FmhaMask::IsMasking;
    using FmhaDropout                  = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
    static constexpr bool kHasMask         = FmhaMask::IsMasking;
    static constexpr bool kHasDropout      = FmhaDropout::IsDropout;
    static constexpr bool kIsStoreRandval  = FmhaDropout::IsStoreRandval;
    static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;

    // clang-format off
    template <typename T> struct t2s;
...
@@ -73,9 +72,12 @@ struct FmhaBwdDQDKDVKernel
{
// sync with generate.py
// clang-format off
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
gbr
=
typename
bfs
::
Gemm0BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
gbr0
=
typename
bfs
::
Gemm0BlockWarps
;
using
gbr1
=
typename
bfs
::
Gemm1BlockWarps
;
using
gbr4
=
typename
bfs
::
Gemm4BlockWarps
;
using
gwt0
=
typename
bfs
::
Gemm0WarpTile
;
using
gwt1
=
typename
bfs
::
Gemm1WarpTile
;
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
...
...
@@ -88,13 +90,17 @@ struct FmhaBwdDQDKDVKernel
return
_SS_
(
"fmha_bwd_d"
)
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"x"
+
_TS_
(
bfs
::
kVHeaddim
)
+
"_"
+
"r"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kK2
)
+
"x"
+
_TS_
(
bfs
::
kK3
)
+
"x"
+
_TS_
(
bfs
::
kK4
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"x"
+
_TS_
(
bfs
::
kVHeaddim
)
+
"_"
+
"r"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr0
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr1
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gbr4
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt0
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt1
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
)
+
_SS_
(
FmhaPipeline
::
name
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasBiasGrad
?
"_dbias"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
);
(
kHasBiasGrad
?
"_dbias"
:
""
)
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kHasDropout
?
"_dropout"
:
""
)
+
(
kIsStoreRandval
?
"_storerandval"
:
""
)
+
(
kIsDeterministic
?
"_deterministic"
:
""
);
#undef _SS_
#undef _TS_
// clang-format on
...
...
@@ -117,7 +123,7 @@ struct FmhaBwdDQDKDVKernel
const
void
*
lse_ptr
;
const
void
*
do_ptr
;
const
void
*
d_ptr
;
void
*
dq_ptr
;
void
*
dq_
acc_
ptr
;
void
*
dk_ptr
;
void
*
dv_ptr
;
...
...
@@ -131,14 +137,13 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
num_head_q
;
ck_tile
::
index_t
nhead_ratio_qk
;
float
raw_scale
;
#if CK_TILE_FMHA_FWD_FAST_EXP2
float
scale
;
#endif
ck_tile
::
index_t
stride_q
;
ck_tile
::
index_t
stride_k
;
ck_tile
::
index_t
stride_v
;
ck_tile
::
index_t
stride_do
;
ck_tile
::
index_t
stride_dq_acc
;
ck_tile
::
index_t
stride_dk
;
ck_tile
::
index_t
stride_dv
;
...
...
@@ -147,8 +152,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_v
;
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_lsed
;
ck_tile
::
index_t
batch_stride_lsed
;
ck_tile
::
index_t
nhead_stride_dq_acc
;
ck_tile
::
index_t
nhead_stride_dk
;
ck_tile
::
index_t
nhead_stride_dv
;
};
struct
FmhaBwdCommonBiasKargs
...
...
@@ -206,7 +212,6 @@ struct FmhaBwdDQDKDVKernel
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
bool
is_store_randval
=
false
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
...
...
@@ -218,6 +223,10 @@ struct FmhaBwdDQDKDVKernel
{
ck_tile
::
index_t
batch_stride_randval
=
0
;
};
struct
FmhaBwdDeterministicKargs
{
ck_tile
::
index_t
split_stride_dq_acc
=
0
;
};
struct
FmhaBwdBatchModeKargs
:
FmhaBwdCommonKargs
,
...
...
@@ -228,12 +237,15 @@ struct FmhaBwdDQDKDVKernel
FmhaBwdEmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasBiasGrad
,
FmhaBwdBatchModeBiasGradKargs
,
FmhaBwdEmptyKargs
<
1
>>
,
std
::
conditional_t
<
kHasMask
,
FmhaBwdMaskKargs
,
FmhaBwdEmptyKargs
<
2
>>
,
std
::
conditional_t
<
kHasDropout
,
FmhaBwdBatchModeDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
std
::
conditional_t
<
kHasDropout
,
FmhaBwdBatchModeDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdDeterministicKargs
,
FmhaBwdEmptyKargs
<
4
>>
{
ck_tile
::
index_t
batch_stride_q
;
ck_tile
::
index_t
batch_stride_k
;
ck_tile
::
index_t
batch_stride_v
;
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_lsed
;
ck_tile
::
index_t
batch_stride_dq_acc
;
ck_tile
::
index_t
batch_stride_dk
;
ck_tile
::
index_t
batch_stride_dv
;
};
...
...
@@ -247,7 +259,8 @@ struct FmhaBwdDQDKDVKernel
FmhaBwdEmptyKargs
<
0
>>>
,
std
::
conditional_t
<
kHasBiasGrad
,
FmhaBwdCommonBiasGradKargs
,
FmhaBwdEmptyKargs
<
1
>>
,
std
::
conditional_t
<
kHasMask
,
FmhaBwdMaskKargs
,
FmhaBwdEmptyKargs
<
2
>>
,
std
::
conditional_t
<
kHasDropout
,
FmhaBwdCommonDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
std
::
conditional_t
<
kHasDropout
,
FmhaBwdCommonDropoutKargs
,
FmhaBwdEmptyKargs
<
3
>>
,
std
::
conditional_t
<
kIsDeterministic
,
FmhaBwdDeterministicKargs
,
FmhaBwdEmptyKargs
<
4
>>
{
const
int32_t
*
seqstart_q_ptr
;
const
int32_t
*
seqstart_k_ptr
;
...
...
@@ -266,10 +279,10 @@ struct FmhaBwdDQDKDVKernel
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dq_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
void
*
dq_acc_ptr
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_k
,
ck_tile
::
index_t
hdim_q
,
...
...
@@ -283,6 +296,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
...
...
@@ -293,6 +307,9 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch_stride_q
,
ck_tile
::
index_t
batch_stride_k
,
...
...
@@ -301,14 +318,15 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
batch_stride_randval
,
ck_tile
::
index_t
batch_stride_do
,
ck_tile
::
index_t
batch_stride_lsed
,
ck_tile
::
index_t
batch_stride_dq_acc
,
ck_tile
::
index_t
batch_stride_dk
,
ck_tile
::
index_t
batch_stride_dv
,
ck_tile
::
index_t
batch_stride_dbias
,
ck_tile
::
index_t
split_stride_dq_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
...
...
@@ -317,7 +335,7 @@ struct FmhaBwdDQDKDVKernel
lse_ptr
,
do_ptr
,
d_ptr
,
dq_ptr
,
dq_
acc_
ptr
,
dk_ptr
,
dv_ptr
,
seqlen_q
,
...
...
@@ -327,13 +345,12 @@ struct FmhaBwdDQDKDVKernel
num_head_q
,
nhead_ratio_qk
,
scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast
<
float
>
(
scale
*
ck_tile
::
log2e_v
<>
),
#endif
stride_q
,
stride_k
,
stride_v
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dv
,
nhead_stride_q
,
...
...
@@ -341,15 +358,20 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v
,
nhead_stride_do
,
nhead_stride_lsed
,
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
{},
// placeholder for deterministic
batch_stride_q
,
batch_stride_k
,
batch_stride_v
,
batch_stride_do
,
batch_stride_lsed
,
batch_stride_dq_acc
,
batch_stride_dk
,
batch_stride_dv
};
...
...
@@ -384,11 +406,18 @@ struct FmhaBwdDQDKDVKernel
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
batch_stride_randval
=
batch_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
if
constexpr
(
kIsStoreRandval
)
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
batch_stride_randval
=
batch_stride_randval
;
}
}
if
constexpr
(
kIsDeterministic
)
{
kargs
.
split_stride_dq_acc
=
split_stride_dq_acc
;
}
return
kargs
;
...
...
@@ -404,10 +433,10 @@ struct FmhaBwdDQDKDVKernel
const
void
*
do_ptr
,
const
void
*
d_ptr
,
void
*
rand_val_ptr
,
void
*
dq_ptr
,
void
*
dk_ptr
,
void
*
dv_ptr
,
void
*
dbias_ptr
,
void
*
dq_acc_ptr
,
const
void
*
seqstart_q_ptr
,
const
void
*
seqstart_k_ptr
,
const
void
*
seqlen_k_ptr
,
...
...
@@ -422,6 +451,7 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
stride_bias
,
ck_tile
::
index_t
stride_randval
,
ck_tile
::
index_t
stride_do
,
ck_tile
::
index_t
stride_dq_acc
,
ck_tile
::
index_t
stride_dk
,
ck_tile
::
index_t
stride_dv
,
ck_tile
::
index_t
stride_dbias
,
...
...
@@ -432,13 +462,15 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
nhead_stride_randval
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_lsed
,
ck_tile
::
index_t
nhead_stride_dq_acc
,
ck_tile
::
index_t
nhead_stride_dk
,
ck_tile
::
index_t
nhead_stride_dv
,
ck_tile
::
index_t
nhead_stride_dbias
,
ck_tile
::
index_t
batch
_stride_
lsed
,
ck_tile
::
index_t
split
_stride_
dq_acc
,
ck_tile
::
index_t
window_size_left
,
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
...
...
@@ -447,7 +479,7 @@ struct FmhaBwdDQDKDVKernel
lse_ptr
,
do_ptr
,
d_ptr
,
dq_ptr
,
dq_
acc_
ptr
,
dk_ptr
,
dv_ptr
,
-
1
,
// seqlen will be updated by another pointer
...
...
@@ -457,13 +489,12 @@ struct FmhaBwdDQDKDVKernel
num_head_q
,
nhead_ratio_qk
,
scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
static_cast
<
float
>
(
scale
*
ck_tile
::
log2e_v
<>
),
#endif
stride_q
,
stride_k
,
stride_v
,
stride_do
,
stride_dq_acc
,
stride_dk
,
stride_dv
,
nhead_stride_q
,
...
...
@@ -471,11 +502,14 @@ struct FmhaBwdDQDKDVKernel
nhead_stride_v
,
nhead_stride_do
,
nhead_stride_lsed
,
batch_stride_lsed
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
nhead_stride_dq_acc
,
nhead_stride_dk
,
nhead_stride_dv
},
// args for common karg
{},
// placeholder for bias
{},
// placeholder for dbias
{},
// placeholder for mask
{},
// placeholder for dropout
{},
// placeholder for deterministic
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqstart_k_ptr
),
reinterpret_cast
<
const
int32_t
*>
(
seqlen_k_ptr
)};
...
...
@@ -506,10 +540,16 @@ struct FmhaBwdDQDKDVKernel
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
kargs
.
is_store_randval
=
s_randval
;
if
constexpr
(
kIsStoreRandval
)
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
}
}
if
constexpr
(
kIsDeterministic
)
{
kargs
.
split_stride_dq_acc
=
split_stride_dq_acc
;
}
return
kargs
;
...
...
@@ -518,7 +558,17 @@ struct FmhaBwdDQDKDVKernel
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
    {
        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_);
        return dim3(ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_);
    }

    CK_TILE_DEVICE static constexpr auto GetTileIndex()
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;
        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
...
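With the TilePartitioner call replaced, the grid is built directly: x walks kN0-wide key tiles, y walks heads, z walks batches, and GetTileIndex() simply decodes blockIdx. A sketch of that same mapping outside the kernel class (kernel body and tile width here are illustrative only, not from the repo):

// Illustrative only: the grid mapping used by GridSize()/GetTileIndex() above.
#include <hip/hip_runtime.h>

__global__ void tile_index_demo(int kN0, int* out)
{
    const int i_tile_n = blockIdx.x;     // key tile
    const int i_nhead  = blockIdx.y;     // head
    const int i_batch  = blockIdx.z;     // batch
    const int i_n0     = i_tile_n * kN0; // first key row handled by this workgroup
    if(threadIdx.x == 0)
        out[(i_batch * gridDim.y + i_nhead) * gridDim.x + i_tile_n] = i_n0;
}

// Host-side grid matching GridSize(): ceil(seqlen_k / kN0) x nhead x batch.
inline dim3 make_grid(int batch, int nhead, int seqlen_k, int kN0)
{
    return dim3((seqlen_k + kN0 - 1) / kN0, nhead, batch);
}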
@@ -536,7 +586,7 @@ struct FmhaBwdDQDKDVKernel
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
        const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k);
        const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();

        const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);
...
@@ -547,6 +597,7 @@ struct FmhaBwdDQDKDVKernel
long_index_t
batch_offset_randval
=
0
;
long_index_t
batch_offset_do
=
0
;
long_index_t
batch_offset_lsed
=
0
;
long_index_t
batch_offset_dq_acc
=
0
;
long_index_t
batch_offset_dk
=
0
;
long_index_t
batch_offset_dv
=
0
;
long_index_t
batch_offset_dbias
=
0
;
...
...
@@ -557,13 +608,14 @@ struct FmhaBwdDQDKDVKernel
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
const
long_index_t
key_start
=
kargs
.
seqstart_k_ptr
[
i_batch
];
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_dk
=
key_start
*
kargs
.
stride_dk
;
batch_offset_dv
=
key_start
*
kargs
.
stride_dv
;
batch_offset_q
=
query_start
*
kargs
.
stride_q
;
batch_offset_k
=
key_start
*
kargs
.
stride_k
;
batch_offset_v
=
key_start
*
kargs
.
stride_v
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_lsed
=
query_start
;
batch_offset_dq_acc
=
query_start
*
kargs
.
stride_dq_acc
;
batch_offset_dk
=
key_start
*
kargs
.
stride_dk
;
batch_offset_dv
=
key_start
*
kargs
.
stride_dv
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
batch_offset_bias
=
query_start
*
kargs
.
stride_bias
;
...
...
@@ -576,7 +628,7 @@ struct FmhaBwdDQDKDVKernel
{
batch_offset_dbias
=
key_start
;
}
if
constexpr
(
k
HasDropout
)
if
constexpr
(
k
IsStoreRandval
)
{
batch_offset_randval
=
query_start
*
kargs
.
stride_randval
;
}
...
...
@@ -603,13 +655,14 @@ struct FmhaBwdDQDKDVKernel
}
else
{
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_do
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_do
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_dk
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dk
;
batch_offset_dv
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dv
;
batch_offset_q
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_q
;
batch_offset_k
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_k
;
batch_offset_v
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_v
;
batch_offset_do
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_do
;
batch_offset_lsed
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lsed
;
batch_offset_dq_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dq_acc
;
batch_offset_dk
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dk
;
batch_offset_dv
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dv
;
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
{
batch_offset_bias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_bias
;
...
...
@@ -618,7 +671,7 @@ struct FmhaBwdDQDKDVKernel
{
batch_offset_dbias
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_dbias
;
}
if
constexpr
(
k
HasDropout
)
if
constexpr
(
k
IsStoreRandval
)
{
batch_offset_randval
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_randval
;
...
...
@@ -646,14 +699,11 @@ struct FmhaBwdDQDKDVKernel
const
OGradDataType
*
do_ptr
=
reinterpret_cast
<
const
OGradDataType
*>
(
kargs
.
do_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_do
+
batch_offset_do
;
QGradDataType
*
dq_ptr
=
reinterpret_cast
<
QGradDataType
*>
(
kargs
.
dq_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_q
+
batch_offset_q
;
KGradDataType
*
dk_ptr
=
reinterpret_cast
<
KGradDataType
*>
(
kargs
.
dk_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_k
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_
d
k
+
batch_offset_dk
;
VGradDataType
*
dv_ptr
=
reinterpret_cast
<
VGradDataType
*>
(
kargs
.
dv_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_v
+
static_cast
<
long_index_t
>
(
i_nhead
)
*
kargs
.
nhead_stride_
d
v
+
batch_offset_dv
;
// Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
...
...
@@ -663,45 +713,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_q
,
1
),
number
<
FmhaPipeline
::
kAlignmentQ
>
{},
number
<
1
>
{});
const
auto
q_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
{
return
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
else
{
return
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}
}();
const
auto
qt_dram_naive
=
transform_tensor_view
(
q_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_q
),
make_pass_through_transform
(
kargs
.
seqlen_q
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
qt_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQTLoadOnce
)
{
return
pad_tensor_view
(
qt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenQ
>
{});
}
else
{
return
pad_tensor_view
(
qt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK3
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenQ
>
{});
}
}();
const
auto
q_dram
=
pad_tensor_view
(
q_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
const
auto
k_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
k_ptr
,
...
...
@@ -709,45 +724,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_k
,
1
),
number
<
FmhaPipeline
::
kAlignmentK
>
{},
number
<
1
>
{});
const
auto
k_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
{
return
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
}
else
{
return
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
}
}();
const
auto
kt_dram_naive
=
transform_tensor_view
(
k_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_q
),
make_pass_through_transform
(
kargs
.
seqlen_k
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
kt_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKTLoadOnce
)
{
return
pad_tensor_view
(
kt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kN0
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenK
>
{});
}
else
{
return
pad_tensor_view
(
kt_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK4
>
{}),
sequence
<
kPadHeadDimQ
,
kPadSeqLenK
>
{});
}
}();
const
auto
k_dram
=
pad_tensor_view
(
k_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimQ
>
{});
const
auto
v_dram
=
[
&
]()
{
const
auto
v_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
...
...
@@ -756,20 +736,10 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_v
,
1
),
number
<
FmhaPipeline
::
kAlignmentV
>
{},
number
<
1
>
{});
if
constexpr
(
FmhaPipeline
::
kVLoadOnce
)
{
return
pad_tensor_view
(
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimV
>
{});
}
else
{
return
pad_tensor_view
(
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK2
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimV
>
{});
}
return
pad_tensor_view
(
v_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
sequence
<
kPadSeqLenK
,
kPadHeadDimV
>
{});
}();
const
auto
lse_dram
=
[
&
]()
{
...
...
@@ -792,145 +762,89 @@ struct FmhaBwdDQDKDVKernel
make_tuple
(
kargs
.
stride_do
,
1
),
number
<
FmhaPipeline
::
kAlignmentOGrad
>
{},
number
<
1
>
{});
const
auto
do_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradLoadOnce
)
{
return
pad_tensor_view
(
do_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimV
>
{});
}
else
{
return
pad_tensor_view
(
do_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK2
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimV
>
{});
}
}();
const
auto
dot_dram_naive
=
transform_tensor_view
(
do_dram_naive
,
make_tuple
(
make_pass_through_transform
(
kargs
.
hdim_v
),
make_pass_through_transform
(
kargs
.
seqlen_q
)),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
const
auto
dot_dram
=
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradTLoadOnce
)
{
return
pad_tensor_view
(
dot_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenQ
>
{});
}
else
{
return
pad_tensor_view
(
dot_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kK1
>
{}),
sequence
<
kPadHeadDimV
,
kPadSeqLenQ
>
{});
}
}();
auto
dq_dram
=
[
&
]()
{
const
auto
dq_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
,
memory_operation_enum
::
atomic_add
>
(
dq_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_q
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
dq_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
const
auto
do_dram
=
pad_tensor_view
(
do_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimV
>
{});
auto
q_dram_window
=
make_tile_window
(
q_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
0
,
0
});
auto
qt_dram_window
=
make_tile_window
(
qt_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQTLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK3
>
{});
}(),
{
0
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK0
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
i_n0
,
0
});
auto
kt_dram_window
=
make_tile_window
(
kt_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kKTLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kQKHeaddim
>
{},
number
<
FmhaPipeline
::
kK4
>
{});
}(),
{
0
,
i_n0
});
auto
v_dram_window
=
make_tile_window
(
v_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kVLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kK2
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kN0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
{
i_n0
,
0
});
auto
do_dram_window
=
make_tile_window
(
do_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kK2
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kVHeaddim
>
{}),
{
0
,
0
});
auto
dot_dram_window
=
make_tile_window
(
dot_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kOGradTLoadOnce
)
return
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kM0
>
{});
else
return
make_tuple
(
number
<
FmhaPipeline
::
kVHeaddim
>
{},
number
<
FmhaPipeline
::
kK1
>
{});
}(),
{
0
,
0
});
auto
dq_dram_window
=
make_tile_window
(
dq_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
0
,
0
});
auto
dq_dram_window
=
[
&
,
i_tile_n_
=
i_tile_n
,
i_nhead_
=
i_nhead
]()
{
if
constexpr
(
kIsDeterministic
)
{
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
kargs
.
nhead_stride_dq_acc
+
static_cast
<
long_index_t
>
(
i_tile_n_
)
*
kargs
.
split_stride_dq_acc
+
batch_offset_dq_acc
;
auto
dq_acc_dram
=
[
&
]()
{
const
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dq_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_dq_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
dq_acc_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
return
make_tile_window
(
dq_acc_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
0
,
0
});
}
else
{
AccDataType
*
dq_acc_ptr
=
reinterpret_cast
<
AccDataType
*>
(
kargs
.
dq_acc_ptr
)
+
static_cast
<
long_index_t
>
(
i_nhead_
)
*
kargs
.
nhead_stride_dq_acc
+
batch_offset_dq_acc
;
auto
dq_acc_dram
=
[
&
]()
{
const
auto
dq_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
,
memory_operation_enum
::
atomic_add
>
(
dq_acc_ptr
,
make_tuple
(
kargs
.
seqlen_q
,
kargs
.
hdim_q
),
make_tuple
(
kargs
.
stride_dq_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentQGrad
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
dq_acc_dram_naive
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
sequence
<
kPadSeqLenQ
,
kPadHeadDimQ
>
{});
}();
return
make_tile_window
(
dq_acc_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kQKHeaddim
>
{}),
{
0
,
0
});
}
}();
auto
lse_dram_window
=
make_tile_window
(
lse_dram
,
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{}),
{
0
});
...
...
@@ -1008,9 +922,7 @@ struct FmhaBwdDQDKDVKernel
// TODO: how to use s_read?
AccDataType
slope
=
*
(
reinterpret_cast
<
const
AccDataType
*>
(
kargs
.
alibi_slope_ptr
)
+
i_batch_
*
kargs
.
alibi_slope_stride
+
i_nhead_
);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope
*=
ck_tile
::
log2e_v
<>
;
#endif
if
constexpr
(
kHasMask
)
{
return
make_alibi_from_lr_mask
<
AccDataType
,
false
>
(
slope
,
...
...
@@ -1033,35 +945,34 @@ struct FmhaBwdDQDKDVKernel
}();
// dropout
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint64_t
drop_seed
=
0
;
uint64_t
drop_offset
=
0
;
bool
is_store_randval
=
false
;
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
if
constexpr
(
kHasDropout
)
{
rp_undrop
=
kargs
.
rp_undrop
;
scale_rp_undrop
=
kargs
.
scale_rp_undrop
;
p_undrop_in_uint8_t
=
kargs
.
p_undrop_in_uint8_t
;
drop_seed
=
kargs
.
drop_seed
;
drop_offset
=
kargs
.
drop_offset
;
is_store_randval
=
kargs
.
is_store_randval
;
rp_undrop
=
kargs
.
rp_undrop
;
scale_rp_undrop
=
kargs
.
scale_rp_undrop
;
}
BlockDropout
dropout
(
i_batch
,
i_nhead
,
kargs
.
num_head_q
,
drop_seed
,
drop_offset
,
rp_undrop
,
p_undrop_in_uint8_t
,
is_store_randval
);
auto
dropout
=
[
&
,
i_nhead_
=
i_nhead
,
i_batch_
=
i_batch
]()
{
if
constexpr
(
kHasDropout
)
{
return
FmhaDropout
{
i_batch_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
drop_offset
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
};
}
else
{
return
FmhaDropout
{};
};
}();
auto
randval_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
constexpr
auto
randval_dram_window_lengths
=
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN0
>
{});
if
constexpr
(
k
HasDropout
)
if
constexpr
(
k
IsStoreRandval
)
{
RandValOutputDataType
*
rand_val_ptr
=
reinterpret_cast
<
RandValOutputDataType
*>
(
kargs
.
rand_val_ptr
)
+
...
...
@@ -1103,14 +1014,11 @@ struct FmhaBwdDQDKDVKernel
}();
auto
[
dk_acc_tile
,
dv_acc_tile
]
=
FmhaPipeline
{}(
q_dram_window
,
qt_dram_window
,
k_dram_window
,
kt_dram_window
,
v_dram_window
,
bias_dram_window
,
randval_dram_window
,
do_dram_window
,
dot_dram_window
,
lse_dram_window
,
d_dram_window
,
dq_dram_window
,
...
...
@@ -1118,9 +1026,7 @@ struct FmhaBwdDQDKDVKernel
mask
,
position_encoding
,
kargs
.
raw_scale
,
#if CK_TILE_FMHA_FWD_FAST_EXP2
kargs
.
scale
,
#endif
rp_undrop
,
scale_rp_undrop
,
smem_ptr
,
...
...
@@ -1169,10 +1075,9 @@ struct FmhaBwdDQDKDVKernel
}
};
template
<
typename
TilePartitioner_
,
typename
FmhaBwdOGradDotO_
>
template
<
typename
FmhaBwdOGradDotO_
>
struct
FmhaBwdOGradDotOKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaBwdOGradDotO
=
ck_tile
::
remove_cvref_t
<
FmhaBwdOGradDotO_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaBwdOGradDotO
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaBwdOGradDotO
::
kBlockPerCu
;
...
...
@@ -1234,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel
ck_tile
::
index_t
nhead_stride_do
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
nhead_stride_d
;
ck_tile
::
index_t
batch_stride_d
;
};
struct
FmhaBwdOGradDotOBatchModeKargs
:
FmhaBwdOGradDotOCommonKargs
{
ck_tile
::
index_t
batch_stride_do
;
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_d
;
};
struct
FmhaBwdOGradDotOGroupModeKargs
:
FmhaBwdOGradDotOCommonKargs
...
...
@@ -1278,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel
stride_o
,
nhead_stride_do
,
nhead_stride_o
,
nhead_stride_d
,
batch_stride_d
},
nhead_stride_d
},
batch_stride_do
,
batch_stride_o
};
batch_stride_o
,
batch_stride_d
};
return
kargs
;
}
...
...
@@ -1298,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel
ck_tile
::
index_t
stride_o
,
ck_tile
::
index_t
nhead_stride_do
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
nhead_stride_d
,
ck_tile
::
index_t
batch_stride_d
)
ck_tile
::
index_t
nhead_stride_d
)
{
Kargs
kargs
{{
o_ptr
,
do_ptr
,
...
...
@@ -1311,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel
stride_o
,
nhead_stride_do
,
nhead_stride_o
,
nhead_stride_d
,
batch_stride_d
},
nhead_stride_d
},
reinterpret_cast
<
const
int32_t
*>
(
seqstart_q_ptr
)};
return
kargs
;
...
...
@@ -1321,7 +1224,16 @@ struct FmhaBwdOGradDotOKernel
    CK_TILE_HOST static constexpr auto
    GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
    {
        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_);
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
    }

    CK_TILE_DEVICE static constexpr auto GetTileIndex()
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;
        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
...
@@ -1331,7 +1243,7 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// divide problem
const
auto
[
i_tile_m
,
i_nhead
,
i_batch
]
=
Tile
Partitioner
{}(
kargs
.
seqlen_q
);
const
auto
[
i_tile_m
,
i_nhead
,
i_batch
]
=
Get
Tile
Index
(
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
kM0
);
...
...
@@ -1346,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel
batch_offset_o
=
query_start
*
kargs
.
stride_o
;
batch_offset_do
=
query_start
*
kargs
.
stride_do
;
batch_offset_d
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_d
;
batch_offset_d
=
query_start
;
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
...
...
@@ -1418,4 +1330,315 @@ struct FmhaBwdOGradDotOKernel
}
};
template
<
typename
FmhaBwdConvertQGrad_
>
struct
FmhaBwdConvertQGradKernel
{
using
FmhaBwdConvertQGrad
=
ck_tile
::
remove_cvref_t
<
FmhaBwdConvertQGrad_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaBwdConvertQGrad
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaBwdConvertQGrad
::
kBlockPerCu
;
static
constexpr
ck_tile
::
index_t
kM0
=
FmhaBwdConvertQGrad
::
kM0
;
static
constexpr
ck_tile
::
index_t
kN0
=
FmhaBwdConvertQGrad
::
kN0
;
static
constexpr
ck_tile
::
index_t
kQKHeaddim
=
FmhaBwdConvertQGrad
::
kQKHeaddim
;
using
AccDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaBwdConvertQGrad
::
AccDataType
>
;
using
QGradDataType
=
ck_tile
::
remove_cvref_t
<
typename
FmhaBwdConvertQGrad
::
QGradDataType
>
;
static
constexpr
bool
kIsGroupMode
=
FmhaBwdConvertQGrad
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
FmhaBwdConvertQGrad
::
kPadSeqLenQ
;
static
constexpr
bool
kPadHeadDimQ
=
FmhaBwdConvertQGrad
::
kPadHeadDimQ
;
static
constexpr
bool
kIsDeterministic
=
FmhaBwdConvertQGrad
::
kIsDeterministic
;
// clang-format off
template
<
typename
T
>
struct
t2s
;
template
<
>
struct
t2s
<
ck_tile
::
fp16_t
>
{
static
constexpr
const
char
*
name
=
"fp16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
// clang-format on
CK_TILE_HOST
static
std
::
string
GetName
()
{
// sync with generate.py
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
std
::
string
n
;
if
(
kPadSeqLenQ
)
n
+=
"s"
;
if
(
kPadHeadDimQ
)
n
+=
"d"
;
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_bwd_convert_dq_d"
)
+
_TS_
(
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QGradDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
(
kIsDeterministic
?
"_deterministic"
:
""
)
+
"_"
+
(
"o"
+
_TS_
(
kBlockPerCu
))
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
);
#undef _SS_
#undef _TS_
// clang-format on
}
// to avoid duplicated base class prblem, introduce an template arg
template
<
ck_tile
::
index_t
I
>
struct
FmhaBwdConvertQGradEmptyKargs
{
};
// kargs use aggregate initializer, so no constructor will provided
// use inheritance to minimize karg size
// user need to use MakeKargs() function to create kargs.
    struct FmhaBwdConvertQGradCommonKargs
    {
        const void* dq_acc_ptr;
        void* dq_ptr;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
        ck_tile::index_t hdim_q;

        ck_tile::index_t stride_dq;
        ck_tile::index_t stride_dq_acc;
        ck_tile::index_t nhead_stride_dq;
        ck_tile::index_t nhead_stride_dq_acc;
    };

    struct FmhaBwdConvertQGradDeterministicKargs
    {
        ck_tile::index_t split_stride_dq_acc = 0;
    };

    struct FmhaBwdConvertQGradBatchModeKargs
        : FmhaBwdConvertQGradCommonKargs,
          std::conditional_t<kIsDeterministic,
                             FmhaBwdConvertQGradDeterministicKargs,
                             FmhaBwdConvertQGradEmptyKargs<0>>
    {
        ck_tile::index_t batch_stride_dq;
        ck_tile::index_t batch_stride_dq_acc;
    };

    struct FmhaBwdConvertQGradGroupModeKargs
        : FmhaBwdConvertQGradCommonKargs,
          std::conditional_t<kIsDeterministic,
                             FmhaBwdConvertQGradDeterministicKargs,
                             FmhaBwdConvertQGradEmptyKargs<0>>
    {
        const int32_t* seqstart_q_ptr;
        const int32_t* seqstart_k_ptr;
    };

    using Kargs = std::conditional_t<kIsGroupMode,
                                     FmhaBwdConvertQGradGroupModeKargs,
                                     FmhaBwdConvertQGradBatchModeKargs>;

    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* dq_acc_ptr,
              void* dq_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t stride_dq,
              ck_tile::index_t stride_dq_acc,
              ck_tile::index_t nhead_stride_dq,
              ck_tile::index_t nhead_stride_dq_acc,
              ck_tile::index_t batch_stride_dq,
              ck_tile::index_t batch_stride_dq_acc,
              ck_tile::index_t split_stride_dq_acc)
    {
        Kargs kargs{{dq_acc_ptr,
                     dq_ptr,
                     seqlen_q,
                     seqlen_k,
                     hdim_q,
                     stride_dq,
                     stride_dq_acc,
                     nhead_stride_dq,
                     nhead_stride_dq_acc},
                    {},
                    batch_stride_dq,
                    batch_stride_dq_acc};

        if constexpr(kIsDeterministic)
        {
            kargs.split_stride_dq_acc = split_stride_dq_acc;
        }

        return kargs;
    }

    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* dq_acc_ptr,
              void* dq_ptr,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t stride_dq,
              ck_tile::index_t stride_dq_acc,
              ck_tile::index_t nhead_stride_dq,
              ck_tile::index_t nhead_stride_dq_acc,
              ck_tile::index_t split_stride_dq_acc)
    {
        Kargs kargs{{dq_acc_ptr,
                     dq_ptr,
                     -1, // seqlen will be updated by another pointer
                     -1, //
                     hdim_q,
                     stride_dq,
                     stride_dq_acc,
                     nhead_stride_dq,
                     nhead_stride_dq_acc},
                    {},
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqstart_k_ptr)};

        if constexpr(kIsDeterministic)
        {
            kargs.split_stride_dq_acc = split_stride_dq_acc;
        }

        return kargs;
    }

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_)
    {
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
    }

    CK_TILE_DEVICE static constexpr auto GetTileIndex()
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // divide problem
        const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);

        long_index_t batch_offset_dq     = 0;
        long_index_t batch_offset_dq_acc = 0;
        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            batch_offset_dq                = query_start * kargs.stride_dq;
            batch_offset_dq_acc            = query_start * kargs.stride_dq_acc;

            // get real # queries & # keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
            if constexpr(kIsDeterministic)
            {
                const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
                kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
            }

            // # of required blocks is different in each groups, terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }
        }
        else
        {
            batch_offset_dq     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
            batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
        }

        // for simplicity, batch stride we just modify the pointer
        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
                                batch_offset_dq;

        // dQAcc/dQ DRAM and DRAM window
        const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
            if constexpr(kIsDeterministic)
            {
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                    static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
                    batch_offset_dq_acc;

                const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);

                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
                    make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
                    number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                    number<1>{});
                return pad_tensor_view(
                    dq_acc_dram_naive,
                    make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
                    sequence<false, kPadSeqLenQ, kPadHeadDimQ>{});
            }
            else
            {
                const AccDataType* dq_acc_ptr =
                    reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
                    static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
                    batch_offset_dq_acc;

                auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    dq_acc_ptr,
                    make_tuple(kargs.seqlen_q, kargs.hdim_q),
                    make_tuple(kargs.stride_dq_acc, 1),
                    number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
                    number<1>{});
                return pad_tensor_view(dq_acc_dram_naive,
                                       make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                       sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
        }();

        auto dq_dram = [&]() {
            auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                dq_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_dq, 1),
                number<FmhaBwdConvertQGrad::kAlignmentQGrad>{},
                number<1>{});
            return pad_tensor_view(dq_dram_naive,
                                   make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                   sequence<kPadSeqLenQ, kPadHeadDimQ>{});
        }();

        auto dq_acc_dram_window = [&]() {
            if constexpr(kIsDeterministic)
            {
                return make_tile_window(
                    dq_acc_dram,
                    make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
                    {0, i_m0, 0});
            }
            else
            {
                return make_tile_window(dq_acc_dram,
                                        make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
                                        {i_m0, 0});
            }
        }();

        auto dq_dram_window = make_tile_window(
            dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});

        if constexpr(kIsDeterministic)
        {
            const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
            FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
        }
        else
        {
            FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
        }
    }
};

} // namespace ck_tile
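
For orientation, here is a minimal host-side sketch of how the batch-mode kargs and launch geometry of this kernel could be assembled from the MakeKargs/GridSize/BlockSize helpers above. The pipeline type, buffer pointers, and sizes are made-up placeholders, and the real dispatch in the repository goes through the generated fmha_bwd host API rather than a hand-written launch:

    // Hypothetical instantiation; "ConvertQGradPipeline" stands in for a real pipeline/epilogue type.
    using Kernel = ck_tile::FmhaBwdConvertQGradKernel<ConvertQGradPipeline>;

    const ck_tile::index_t batch = 2, nhead = 8, seqlen_q = 1024, seqlen_k = 1024, hdim_q = 128;

    // dq_acc_dev (fp32 accumulator) and dq_dev (fp16/bf16 output) are assumed device allocations,
    // laid out contiguously as [batch, nhead, seqlen_q, hdim_q] for this example.
    auto kargs = Kernel::MakeKargs(dq_acc_dev,
                                   dq_dev,
                                   seqlen_q,
                                   seqlen_k,
                                   hdim_q,
                                   /*stride_dq=*/hdim_q,
                                   /*stride_dq_acc=*/hdim_q,
                                   /*nhead_stride_dq=*/seqlen_q * hdim_q,
                                   /*nhead_stride_dq_acc=*/seqlen_q * hdim_q,
                                   /*batch_stride_dq=*/nhead * seqlen_q * hdim_q,
                                   /*batch_stride_dq_acc=*/nhead * seqlen_q * hdim_q,
                                   /*split_stride_dq_acc=*/seqlen_q * hdim_q); // only read when kIsDeterministic

    const dim3 grids  = Kernel::GridSize(batch, nhead, seqlen_q); // (ceil(seqlen_q / kM0), nhead, batch)
    const dim3 blocks = Kernel::BlockSize();                      // kBlockSize threads per block
    // Kernel{} is then launched with (grids, blocks) through the ck_tile host-side launch utilities.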
include/ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp
deleted 100644 → 0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_k_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};

template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
    {
        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
    }
};

} // namespace ck_tile
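
As a quick illustration of the ceil-divide grid mapping these partitioners encode (the sizes below are made up):

    // seqlen_k = 1000, kN0 = 128  ->  ck_tile::integer_divide_ceil(1000, 128) == 8
    // FmhaBwdTilePartitioner<Shape>::GridSize(/*batch*/ 4, /*nhead*/ 16, /*seqlen_k*/ 1000) == dim3(8, 16, 4)
    // operator() then recovers (i_block, i_nhead, i_batch) directly from (blockIdx.x, blockIdx.y, blockIdx.z).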
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
...
@@ -86,7 +86,7 @@ struct FmhaFwdKernel
            "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) +
            "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) +
            "_" + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") +
            (pn.empty() ? "" : "_" + pn) +
            (BiasEnum == BlockAttentionBiasEnum::NO_BIAS
                 ? _SS_("")
                 : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "") +
            (kHasDropout ? "_dropout" : "") + (kDoFp8StaticQuant ? "_squant" : "");
        #undef _SS_
        #undef _TS_
...
@@ -387,7 +387,6 @@ struct FmhaFwdKernel
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
...
@@ -448,7 +447,6 @@ struct FmhaFwdKernel
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
            kargs.batch_stride_lse = batch_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
...
@@ -524,7 +522,7 @@ struct FmhaFwdKernel
            }
            if constexpr(kStoreLSE)
            {
                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
                batch_offset_lse = query_start;
            }
            if constexpr(kHasDropout)
            {
...
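
To make the group-mode offset in the last hunk concrete, a small worked example (the seqstart values are invented; the surrounding branch is the kIsGroupMode path where query_start = kargs.seqstart_q_ptr[i_batch], as spelled out in the split-KV kernels below):

    // seqstart_q_ptr = {0, 57, 130}, i_batch = 1
    // query_start      = kargs.seqstart_q_ptr[1] = 57
    // batch_offset_lse = query_start = 57   // LSE rows of all sequences are packed along the query dimension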
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
...
@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) +
            (pn.empty() ? "" : "_" + pn) + (kStoreLSE ? "_lse" : "") +
            (kDoFp8StaticQuant ? "_squant" : "");
        #undef _SS_
        #undef _TS_
...
@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
        ck_tile::index_t nhead_stride_o_acc;
        ck_tile::index_t nhead_stride_o;
        ck_tile::index_t batch_stride_lse_acc;
        ck_tile::index_t batch_stride_o_acc;
        ck_tile::index_t split_stride_lse_acc;
...
@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
    {
        ck_tile::index_t batch_stride_o;
        ck_tile::index_t batch_stride_lse_acc;
    };

    struct GroupModeKargs
...
@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     nhead_stride_o,
                     batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for lse
                    {},                   // placeholder for fp8_static_quant args
                    batch_stride_o};
                    batch_stride_o,
                    batch_stride_lse_acc};

        if constexpr(kStoreLSE)
        {
...
@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_lse_acc,
              ck_tile::index_t batch_stride_o_acc,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc)
    {
...
@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     nhead_stride_o,
                     batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
...
@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
            kargs.batch_stride_lse = batch_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
...
@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
        const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

        const long_index_t batch_offset_lse_acc =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
        const long_index_t batch_offset_o_acc =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
        long_index_t batch_offset_lse = 0;
        long_index_t batch_offset_o   = 0;
        if constexpr(kStoreLSE)
        {
            batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
        }
        long_index_t batch_offset_lse_acc = 0;
        long_index_t batch_offset_lse     = 0;
        long_index_t batch_offset_o       = 0;

        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            batch_offset_o       = query_start * kargs.row_stride_o;
            batch_offset_lse_acc = query_start;
            if constexpr(kStoreLSE)
            {
                batch_offset_lse = query_start;
            }

            // get real # queries & # keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
...
@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel
        }
        else
        {
            batch_offset_o       = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
            batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
            if constexpr(kStoreLSE)
            {
                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
            }
        }

        // for simplicity, batch stride we just modify the pointer
...
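
A compact way to read the offset logic in the two hunks above (the numeric values below are hypothetical, purely to show the arithmetic):

    // group mode:  batch_offset_o       = query_start * kargs.row_stride_o;
    //              batch_offset_lse_acc = query_start;                      // packed along the query dim
    // batch mode:  batch_offset_o       = i_batch * kargs.batch_stride_o;
    //              batch_offset_lse_acc = i_batch * kargs.batch_stride_lse_acc;
    // e.g. i_batch = 3, batch_stride_lse_acc = 16384  ->  batch_offset_lse_acc = 49152 elements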
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
...
@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel
            "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) +
            "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) +
            "_" + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") +
            (pn.empty() ? "" : "_" + pn) +
            (BiasEnum == BlockAttentionBiasEnum::NO_BIAS
                 ? _SS_("")
                 : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
            (kHasMask ? "_" + _SS_(FmhaMask::name) : "") +
            (kHasDropout ? "_dropout" : "") + (kDoFp8StaticQuant ? "_squant" : "");
        #undef _SS_
        #undef _TS_
...
@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel
        ck_tile::index_t nhead_stride_lse_acc;
        ck_tile::index_t nhead_stride_o_acc;
        ck_tile::index_t batch_stride_lse_acc;
        ck_tile::index_t batch_stride_o_acc;
        ck_tile::index_t split_stride_lse_acc;
...
@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel
        ck_tile::index_t batch_stride_q;
        ck_tile::index_t batch_stride_k;
        ck_tile::index_t batch_stride_v;
        ck_tile::index_t batch_stride_lse_acc;
    };

    struct GroupModeKargs
...
@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel
                     nhead_stride_v,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
...
@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel
                    {}, // placeholder for dropout
                    batch_stride_q,
                    batch_stride_k,
                    batch_stride_v};
                    batch_stride_v,
                    batch_stride_lse_acc};

        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
...
@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_lse_acc,
              ck_tile::index_t nhead_stride_o_acc,
              ck_tile::index_t batch_stride_lse_acc,
              ck_tile::index_t batch_stride_o_acc,
              ck_tile::index_t split_stride_lse_acc,
              ck_tile::index_t split_stride_o_acc,
...
@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel
                     nhead_stride_v,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
                     batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
...
@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel
        long_index_t batch_offset_v       = 0;
        long_index_t batch_offset_bias    = 0;
        long_index_t batch_offset_randval = 0;
        const long_index_t batch_offset_lse_acc =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
        long_index_t batch_offset_lse_acc = 0;
        const long_index_t batch_offset_o_acc =
            static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
...
@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            const long_index_t key_start   = kargs.seqstart_k_ptr[i_batch];
            batch_offset_q       = query_start * kargs.stride_q;
            batch_offset_k       = key_start * kargs.stride_k;
            batch_offset_lse_acc = query_start;
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                batch_offset_v = key_start * kargs.stride_v;
...
@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel
        }
        else
        {
            batch_offset_q       = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
            batch_offset_k       = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
            batch_offset_v       = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
            batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
...
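
One detail worth spelling out from the group-mode branch above: the V offset depends on the value layout. A hedged worked example with invented numbers:

    // key_start = 96 (from kargs.seqstart_k_ptr[i_batch]), kargs.stride_v = hdim_v = 128
    // RowMajor V (seqlen_k x hdim_v): batch_offset_v = key_start * kargs.stride_v = 96 * 128 = 12288
    // (the column-major branch is not shown in this hunk, so it is not reproduced here)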