Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4769425e
Commit
4769425e
authored
May 23, 2022
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into gelu
parents
b548c0be
ba58a93f
Changes
75
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1845 additions
and
585 deletions
+1845
-585
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
...sor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
+150
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
...eration/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
+143
-174
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+19
-46
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
...or_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
+974
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+16
-51
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
+16
-52
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+16
-52
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
+19
-51
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
+20
-51
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
+19
-51
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+108
-0
include/ck/utility/common_header.hpp
include/ck/utility/common_header.hpp
+1
-1
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+41
-1
include/ck/utility/generic_memory_space_atomic.hpp
include/ck/utility/generic_memory_space_atomic.hpp
+97
-0
include/ck/utility/get_id.hpp
include/ck/utility/get_id.hpp
+4
-0
include/ck/utility/statically_indexed_array_multi_index.hpp
include/ck/utility/statically_indexed_array_multi_index.hpp
+7
-0
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+3
-0
library/include/ck/library/host_tensor/device.hpp
library/include/ck/library/host_tensor/device.hpp
+23
-7
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp
...e_tensor_operation/cpu/reference_conv_backward_weight.hpp
+168
-47
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt
...operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt
+1
-1
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp
0 → 100644
View file @
4769425e
#pragma once
#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
namespace
ck
{
template
<
typename
GridwiseBinEltwise
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
GridDesc_M0
,
typename
ElementwiseFunctor
>
__global__
void
kernel_binary_elementwise_1d
(
const
ADataType
*
__restrict__
p_a_global
,
const
BDataType
*
__restrict__
p_b_global
,
CDataType
*
__restrict__
p_c_global
,
const
GridDesc_M0
a_grid_desc_m0
,
const
GridDesc_M0
b_grid_desc_m0
,
const
GridDesc_M0
c_grid_desc_m0
,
const
ElementwiseFunctor
functor
)
{
GridwiseBinEltwise
::
Run
(
p_a_global
,
p_b_global
,
p_c_global
,
a_grid_desc_m0
,
b_grid_desc_m0
,
c_grid_desc_m0
,
functor
);
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
ComputeDataType
,
typename
GridDesc_M0
,
typename
ElementwiseFunctor
,
index_t
ScalarPerVector
>
struct
GridwiseBinaryElementwise_1D
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
thread_desc_m0
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
ScalarPerVector
>
{}));
using
PassThrough
=
tensor_operation
::
element_wise
::
PassThrough
;
static
__device__
auto
CalculateElementwiseIndex
()
{
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
return
make_multi_index
(
global_thread_id
*
ScalarPerVector
);
}
__device__
static
void
Run
(
const
ADataType
*
__restrict__
p_a_global
,
const
BDataType
*
__restrict__
p_b_global
,
CDataType
*
__restrict__
p_c_global
,
const
GridDesc_M0
a_grid_desc_m0
,
const
GridDesc_M0
b_grid_desc_m0
,
const
GridDesc_M0
c_grid_desc_m0
,
const
ElementwiseFunctor
functor
)
{
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_global
,
a_grid_desc_m0
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_global
,
b_grid_desc_m0
.
GetElementSpaceSize
());
auto
c_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_global
,
c_grid_desc_m0
.
GetElementSpaceSize
());
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
ScalarPerVector
,
true
>
a_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
ScalarPerVector
,
true
>
b_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
ComputeDataType
,
ScalarPerVector
,
true
>
c_thread_buf
;
const
auto
thread_store_global_offset
=
CalculateElementwiseIndex
();
auto
a_global_load
=
ThreadwiseTensorSliceTransfer_v2
<
ADataType
,
ComputeDataType
,
GridDesc_M0
,
decltype
(
thread_desc_m0
),
Sequence
<
ScalarPerVector
>
,
// SliceLengths
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
ScalarPerVector
,
1
,
// SrcScalarStrideInVector
false
>
{
a_grid_desc_m0
,
thread_store_global_offset
};
auto
b_global_load
=
ThreadwiseTensorSliceTransfer_v2
<
BDataType
,
ComputeDataType
,
GridDesc_M0
,
decltype
(
thread_desc_m0
),
Sequence
<
ScalarPerVector
>
,
// SliceLengths
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
ScalarPerVector
,
1
,
// SrcScalarStrideInVector
false
>
{
b_grid_desc_m0
,
thread_store_global_offset
};
auto
c_global_write
=
ThreadwiseTensorSliceTransfer_v1r3
<
ComputeDataType
,
CDataType
,
decltype
(
thread_desc_m0
),
GridDesc_M0
,
PassThrough
,
Sequence
<
ScalarPerVector
>
,
// SliceLengths
Sequence
<
0
>
,
// DimAccessOrder
0
,
// DstVectorDim
ScalarPerVector
,
InMemoryDataOperationEnum
::
Set
,
1
,
// DstScalarStrideInVector
false
>
{
c_grid_desc_m0
,
thread_store_global_offset
,
PassThrough
{}};
const
index_t
blockSize
=
get_block_size
();
const
index_t
blockPerGrid
=
get_grid_size
();
const
auto
m0
=
c_grid_desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
blockPerGrid
*
blockSize
*
ScalarPerVector
;
const
auto
loop_step_index
=
make_multi_index
(
loop_step
);
index_t
num_iter
=
m0
/
(
loop_step
);
do
{
// read and process ScalarPerVector elements
a_global_load
.
Run
(
a_grid_desc_m0
,
a_global_buf
,
thread_desc_m0
,
make_tuple
(
I0
),
a_thread_buf
);
b_global_load
.
Run
(
b_grid_desc_m0
,
b_global_buf
,
thread_desc_m0
,
make_tuple
(
I0
),
b_thread_buf
);
static_for
<
0
,
ScalarPerVector
,
1
>
{}([
&
](
auto
m
)
{
constexpr
auto
offset
=
thread_desc_m0
.
CalculateOffset
(
make_tuple
(
m
));
functor
(
c_thread_buf
(
Number
<
offset
>
{}),
a_thread_buf
(
Number
<
offset
>
{}),
b_thread_buf
(
Number
<
offset
>
{}));
});
c_global_write
.
Run
(
thread_desc_m0
,
make_tuple
(
I0
),
// SrcSliceOriginIdx
c_thread_buf
,
c_grid_desc_m0
,
c_global_buf
);
a_global_load
.
MoveSrcSliceWindow
(
a_grid_desc_m0
,
loop_step_index
);
b_global_load
.
MoveSrcSliceWindow
(
b_grid_desc_m0
,
loop_step_index
);
c_global_write
.
MoveDstSliceWindow
(
c_grid_desc_m0
,
loop_step_index
);
}
while
(
--
num_iter
);
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
View file @
4769425e
...
...
@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
...
...
@@ -15,11 +16,12 @@ namespace ck {
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatD
,
typename
DPtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
D1ElementwiseOperation
,
typename
DxsInElementwiseOperation
,
typename
DxsOutElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
...
@@ -34,12 +36,12 @@ __global__ void
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatD
*
__restrict__
p_d0_grid
,
FloatD
*
__restrict__
p_d1_grid
,
DPtrsGlobal
p_ds_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
D1ElementwiseOperation
d1_element_op
,
const
DxsInElementwiseOperation
dxs_in_element_op
,
const
DxsOutElementwiseOperation
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
...
@@ -53,13 +55,13 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_d0_grid
,
p_d1_grid
,
p_ds_grid
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
d1_element_op
,
dxs_in_element_op
,
dxs_out_element_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -69,12 +71,12 @@ __global__ void
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
p_d0_grid
;
ignore
=
p_d1_grid
;
ignore
=
p_ds_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
d1_element_op
;
ignore
=
dxs_in_element_op
;
ignore
=
dxs_out_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
...
...
@@ -88,15 +90,15 @@ template <typename FloatAB,
typename
FloatCShuffle
,
typename
FloatC
,
typename
FloatReduceAcc
,
typename
FloatD
,
typename
DPtrsGlobal
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
D
0
ReduceOperation
,
typename
D
1Reduc
eOperation
,
typename
D
1
ElementwiseOperation
,
typename
D
xs
ReduceOperation
,
typename
D
xsInElementwis
eOperation
,
typename
D
xsOut
ElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum
DGlobalMemoryDataOperation
,
typename
DGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
...
...
@@ -217,10 +219,12 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
// static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
// is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
...
...
@@ -248,21 +252,15 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return
false
;
}
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
KPerBlock
;
...
...
@@ -308,40 +306,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
// FIXME: remove
constexpr
auto
M01
=
I1
;
constexpr
auto
N01
=
I1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
return
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
);
}
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
...
...
@@ -357,13 +323,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatD
*
__restrict__
p_d0_grid
,
FloatD
*
__restrict__
p_d1_grid
,
DPtrsGlobal
p_ds_grid
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
D1ElementwiseOperation
&
d1_element_op
,
const
DxsInElementwiseOperation
&
dxs_in_element_op
,
const
DxsOutElementwiseOperation
&
dxs_out_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
...
...
@@ -377,15 +343,19 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
auto
d0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d0_grid
,
d_grid_desc_mblock_mperblock
.
GetElementSpaceSize
());
auto
d1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d1_grid
,
d_grid_desc_mblock_mperblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
@@ -527,7 +497,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_thread_buf
,
num_k_block_main_loop
);
// shuffle C
and
write out
// shuffle C
+ reduction +
write out
{
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
...
...
@@ -666,6 +636,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
1
,
1
,
M2
,
1
,
M4
,
1
>>
{};
// space filling curve for shuffled blockwise C in global mem
constexpr
auto
sfc_c_global
=
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
// TODO: this should be implemented as a blockwise reduction
// LDS c_reduce_block_desc_mperblock_nperblock
constexpr
auto
c_reduce_block_desc_mperblock_nperblock
=
transform_tensor_descriptor
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
...
...
@@ -716,16 +709,9 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr
auto
d_reduce_thread_desc_mblock_mperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
mreduce_per_thread
>
{}));
// TODO: this should be implemented as a blockwise reduction
auto
c_reduce_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
c_reduce_thread_desc_mperblock_nperblock
.
GetElementSpaceSize
());
auto
d0_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
d_reduce_thread_desc_mperblock
.
GetElementSpaceSize
());
auto
d1_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
d_reduce_thread_desc_mperblock
.
GetElementSpaceSize
());
// reduce: threadwise copy from LDS to VGPR
constexpr
auto
c_reduce_thread_cluster_desc
=
make_cluster_descriptor
(
CReduceThreadClusterLengths_MPerBlock_NPerBlock
{},
Sequence
<
1
,
0
>
{});
...
...
@@ -749,47 +735,29 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
1
,
true
>
{
c_reduce_block_desc_mperblock_nperblock
,
c_reduce_thread_data_idx_begin
};
// reduce: copy from VGPR to global
auto
d0_reduce_thread_copy_vgpr_to_global
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatReduceAcc
,
FloatD
,
decltype
(
d_reduce_thread_desc_mblock_mperblock
),
decltype
(
d_grid_desc_mblock_mperblock
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
1
,
mreduce_per_thread
>
,
Sequence
<
0
,
1
>
,
1
,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
DGlobalMemoryDataOperation
,
1
,
false
>
{
d_grid_desc_mblock_mperblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
c_reduce_thread_data_idx_begin
[
I0
]),
// mperblock
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
auto
d1_reduce_thread_copy_vgpr_to_global
=
d0_reduce_thread_copy_vgpr_to_global
;
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
1
,
1
,
M2
,
1
,
M4
,
1
>>
{};
// space filling curve for shuffled blockwise C in global mem
constexpr
auto
sfc_c_global
=
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
auto
dxs_reduce_thread_copy_vgpr_to_global
=
generate_tuple
(
[
&
](
auto
I
)
{
auto
p_d_grid
=
p_ds_grid
[
I
];
auto
d_out_element_op
=
dxs_out_element_op
[
I
];
return
ThreadwiseTensorSliceTransfer_v1r3
<
FloatReduceAcc
,
remove_pointer_t
<
decltype
(
p_d_grid
)
>
,
decltype
(
d_reduce_thread_desc_mblock_mperblock
),
decltype
(
d_grid_desc_mblock_mperblock
),
decltype
(
d_out_element_op
),
Sequence
<
1
,
mreduce_per_thread
>
,
Sequence
<
0
,
1
>
,
1
,
CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock
,
DGlobalMemoryDataOperation
::
At
(
I
),
1
,
false
>
{
d_grid_desc_mblock_mperblock
,
make_multi_index
(
block_work_idx
[
I0
],
// mblock
c_reduce_thread_data_idx_begin
[
I0
]),
// mperblock
d_out_element_op
};
},
Number
<
p_ds_grid
.
Size
()
>
{});
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
...
...
@@ -816,64 +784,73 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
using
ThreadwiseReduce_D0
=
ThreadwiseReduction
<
FloatReduceAcc
,
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
d_reduce_thread_desc_mperblock
),
D0ReduceOperation
,
false
>
;
using
ThreadwiseReduce_D1
=
ThreadwiseReduction
<
FloatReduceAcc
,
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
d_reduce_thread_desc_mperblock
),
D1ReduceOperation
,
false
>
;
const
auto
d0_zeroVal
=
D0ReduceOperation
::
GetReductionZeroVal
();
const
auto
d1_zeroVal
=
D0ReduceOperation
::
GetReductionZeroVal
();
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
[
&
](
auto
I
)
{
d0_thread_buf
(
I
)
=
d0_zeroVal
;
});
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
[
&
](
auto
I
)
{
d1_thread_buf
(
I
)
=
d1_zeroVal
;
});
// reduce
// TODO - extract following into reduction_blockwise
{
// copy from LDS to VGPR
c_reduce_thread_copy_lds_to_vgpr
.
Run
(
c_reduce_block_desc_mperblock_nperblock
,
c_shuffle_block_buf
,
c_reduce_thread_desc_mperblock_nperblock
,
make_tuple
(
I0
,
I0
),
c_reduce_thread_buf
);
// reduce in VGPR
ThreadwiseReduce_D0
::
Reduce
(
c_reduce_thread_buf
,
d0_thread_buf
)
;
static_for
<
0
,
p_ds_grid
.
Size
(),
1
>
{}([
&
](
auto
In
)
{
auto
&
p_d_grid
=
p_ds_grid
[
In
]
;
static_for
<
0
,
mreduce_per_thread
,
1
>
{}([
&
](
auto
im
)
{
static_for
<
0
,
nreduce_per_thread
,
1
>
{}([
&
](
auto
in
)
{
constexpr
auto
offset
=
Number
<
c_reduce_thread_desc_mperblock_nperblock
.
CalculateOffset
(
make_tuple
(
im
,
in
))
>
{};
auto
d_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_d_grid
,
d_grid_desc_mblock_mperblock
.
GetElementSpaceSize
());
d1_element_op
(
c_reduce_thread_buf
(
offset
),
c_reduce_thread_buf
(
offset
));
});
});
auto
d_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatReduceAcc
>
(
d_reduce_thread_desc_mperblock
.
GetElementSpaceSize
());
auto
&
d_in_element_op
=
dxs_in_element_op
[
In
];
auto
&
d_reduce_thread_copy_vgpr_to_global
=
dxs_reduce_thread_copy_vgpr_to_global
(
In
);
ThreadwiseReduce_D1
::
Reduce
(
c_reduce_thread_buf
,
d1_thread_buf
);
using
DReduceOperation
=
remove_cvref_t
<
decltype
(
DxsReduceOperation
{}[
In
])
>
;
using
ThreadwiseReduce
=
ThreadwiseReduction
<
FloatReduceAcc
,
decltype
(
c_reduce_thread_desc_mperblock_nperblock
),
decltype
(
d_reduce_thread_desc_mperblock
),
DReduceOperation
,
false
>
;
// copy from VGPR to Global
d0_reduce_thread_copy_vgpr_to_global
.
Run
(
d_reduce_thread_desc_mblock_mperblock
,
make_tuple
(
I0
,
I0
),
d0_thread_buf
,
d_grid_desc_mblock_mperblock
,
d0_grid_buf
);
// Global write Gemm shuffle + reduction
const
auto
d_zeroVal
=
DReduceOperation
::
GetReductionZeroVal
();
d1_reduce_thread_copy_vgpr_to_global
.
Run
(
d_reduce_thread_desc_mblock_mperblock
,
make_tuple
(
I0
,
I0
),
d1_thread_buf
,
d_grid_desc_mblock_mperblock
,
d1_grid_buf
);
static_for
<
0
,
mreduce_per_thread
,
1
>
{}(
[
&
](
auto
I
)
{
d_thread_buf
(
I
)
=
d_zeroVal
;
});
// reduce in VGPR
static_for
<
0
,
mreduce_per_thread
,
1
>
{}([
&
](
auto
im
)
{
static_for
<
0
,
nreduce_per_thread
,
1
>
{}([
&
](
auto
in
)
{
constexpr
auto
offset
=
Number
<
c_reduce_thread_desc_mperblock_nperblock
.
CalculateOffset
(
make_tuple
(
im
,
in
))
>
{};
d_in_element_op
(
c_reduce_thread_buf
(
offset
),
c_reduce_thread_buf
(
offset
));
});
});
ThreadwiseReduce
::
Reduce
(
c_reduce_thread_buf
,
d_thread_buf
);
// copy from VGPR to Global
d_reduce_thread_copy_vgpr_to_global
.
Run
(
d_reduce_thread_desc_mblock_mperblock
,
make_tuple
(
I0
,
I0
),
d_thread_buf
,
d_grid_desc_mblock_mperblock
,
d_grid_buf
);
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
d_reduce_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
d_grid_desc_mblock_mperblock
,
make_tuple
(
c_global_step
[
I0
],
c_global_step
[
I1
]));
}
});
}
if
constexpr
(
access_id
<
num_access
-
1
)
...
...
@@ -883,18 +860,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// move on C
c_shuffle_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
// move on D0
d0_reduce_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
d_grid_desc_mblock_mperblock
,
make_tuple
(
c_global_step
[
I0
],
c_global_step
[
I1
]));
// move on D1
d1_reduce_thread_copy_vgpr_to_global
.
MoveDstSliceWindow
(
d_grid_desc_mblock_mperblock
,
make_tuple
(
c_global_step
[
I0
],
c_global_step
[
I1
]));
}
});
// Reduction
}
}
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
4769425e
...
...
@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
...
...
@@ -190,10 +191,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
...
...
@@ -217,21 +220,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return
false
;
}
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
KPerBlock
;
...
...
@@ -262,40 +259,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
// FIXME: remove
constexpr
auto
M01
=
I1
;
constexpr
auto
N01
=
I1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
return
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
);
}
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
...
...
@@ -329,6 +294,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
0 → 100644
View file @
4769425e
#pragma once
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
namespace
ck
{
// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
template
<
typename
LowLengths
>
struct
Merge_v4_no_carry
{
static
constexpr
index_t
NDimLow
=
LowLengths
::
Size
();
using
LowerIndex
=
MultiIndex
<
NDimLow
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
using
LowLengthsScan
=
decltype
(
container_reverse_exclusive_scan
(
LowLengths
{},
math
::
multiplies
{},
Number
<
1
>
{}));
using
UpLengths
=
decltype
(
make_tuple
(
container_reduce
(
LowLengths
{},
math
::
multiplies
{},
Number
<
1
>
{})));
LowLengths
low_lengths_
;
LowLengthsScan
low_lengths_scan_
;
UpLengths
up_lengths_
;
__host__
__device__
constexpr
Merge_v4_no_carry
()
=
default
;
__host__
__device__
constexpr
Merge_v4_no_carry
(
const
LowLengths
&
low_lengths
)
:
low_lengths_
{
low_lengths
},
low_lengths_scan_
{
container_reverse_exclusive_scan
(
low_lengths
,
math
::
multiplies
{},
Number
<
1
>
{})},
up_lengths_
{
make_tuple
(
container_reduce
(
low_lengths
,
math
::
multiplies
{},
Number
<
1
>
{}))}
{
static_assert
(
LowerIndex
::
Size
()
==
NDimLow
,
"wrong!"
);
}
__host__
__device__
static
constexpr
index_t
GetNumOfLowerDimension
()
{
return
NDimLow
;
}
__host__
__device__
static
constexpr
index_t
GetNumOfUpperDimension
()
{
return
1
;
}
__host__
__device__
constexpr
const
auto
&
GetUpperLengths
()
const
{
return
up_lengths_
;
}
template
<
typename
LowIdx
,
typename
UpIdx
>
__host__
__device__
constexpr
void
CalculateLowerIndex
(
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
)
const
{
static_assert
(
LowIdx
::
Size
()
==
NDimLow
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
index_t
tmp
=
idx_up
[
Number
<
0
>
{}];
// division and mod
static_for
<
0
,
NDimLow
-
1
,
1
>
{}([
&
](
auto
i
)
{
idx_low
(
i
)
=
tmp
/
this
->
low_lengths_scan_
[
i
];
tmp
%=
this
->
low_lengths_scan_
[
i
];
});
idx_low
(
Number
<
NDimLow
-
1
>
{})
=
tmp
;
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
,
index_t
Hack
>
__host__
__device__
void
UpdateLowerIndex
(
LowIdxDiff
&
idx_diff_low
,
const
UpIdxDiff
&
idx_up_diff
,
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up_new
,
Number
<
Hack
>
)
const
{
static_assert
(
LowIdxDiff
::
Size
()
==
NDimLow
&&
UpIdxDiff
::
Size
()
==
1
&&
LowIdx
::
Size
()
==
NDimLow
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
INm1
=
Number
<
NDimLow
-
1
>
{};
index_t
tmp
=
idx_up_new
[
I0
];
idx_low
(
INm1
)
=
tmp
;
idx_diff_low
(
INm1
)
=
idx_up_diff
[
I0
];
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
false
;
}
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsKnownAtCompileTime
()
{
return
is_known_at_compile_time
<
LowLengths
>::
value
&&
is_known_at_compile_time
<
LowLengthsScan
>::
value
&&
is_known_at_compile_time
<
UpLengths
>::
value
;
}
template
<
typename
UpIdx
>
__host__
__device__
static
constexpr
bool
IsValidUpperIndexMappedToValidLowerIndex
(
const
UpIdx
&
/* idx_up */
)
{
return
true
;
}
__host__
__device__
void
Print
()
const
{
printf
(
"{"
);
printf
(
"Merge_v3_direct_division_mod_wrw, "
);
printf
(
"low_lengths_ "
);
print_multi_index
(
low_lengths_
);
printf
(
"low_lengths_scan_ "
);
print_multi_index
(
low_lengths_scan_
);
printf
(
"up_lengths_ "
);
print_multi_index
(
up_lengths_
);
printf
(
"}"
);
}
};
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_merge_transform_v4_no_carry
(
const
LowLengths
&
low_lengths
)
{
return
Merge_v4_no_carry
<
LowLengths
>
{
low_lengths
};
}
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_B_K0_M_K1
,
typename
BGridDesc_B_K0_N_K1
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CBlockClusterAdaptor
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_xdlops_bwd_weight
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_B_K0_M_K1
a_b_k0_m_k1_grid_desc
,
const
BGridDesc_B_K0_N_K1
b_b_k0_n_k1_grid_desc
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
CBlockClusterAdaptor
c_block_cluster_adaptor
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared_block
,
a_b_k0_m_k1_grid_desc
,
b_b_k0_n_k1_grid_desc
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
a_element_op
,
b_element_op
,
c_element_op
,
c_block_cluster_adaptor
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
a_b_k0_m_k1_grid_desc
;
ignore
=
b_b_k0_n_k1_grid_desc
;
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
c_block_cluster_adaptor
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatC
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AGridDesc_B_K0_M_K1
,
typename
BGridDesc_B_K0_N_K1
,
typename
CMNGridDesc
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
K1Value
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
bool
ABlockLdsExtraM
,
index_t
ABlockLdsM1PerBlock
,
index_t
ABlockLdsM0PerBlock
,
index_t
ABlockLdsM1Padding
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
bool
BBlockLdsExtraN
,
index_t
BBlockLdsN1PerBlock
,
index_t
BBlockLdsN0PerBlock
,
index_t
BBlockLdsN1Padding
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
bool
ABlockLdsExtraM1Wrw
=
false
,
bool
BBlockLdsExtraN1Wrw
=
false
,
index_t
NumGemmKPrefetchStage
=
1
>
struct
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v1
<
NumGemmKPrefetchStage
>
;
// M0/M1/M1Padding
static
constexpr
auto
M1PerBlock
=
Number
<
ABlockLdsM1PerBlock
>
{};
static
constexpr
auto
M0PerBlock
=
Number
<
ABlockLdsM0PerBlock
>
{};
static
constexpr
auto
M1Padding
=
Number
<
ABlockLdsM1Padding
>
{};
// N0/N1/N1Padding
static
constexpr
auto
N1PerBlock
=
Number
<
BBlockLdsN1PerBlock
>
{};
static
constexpr
auto
N0PerBlock
=
Number
<
BBlockLdsN0PerBlock
>
{};
static
constexpr
auto
N1Padding
=
Number
<
BBlockLdsN1Padding
>
{};
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_k0_m_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
if
constexpr
(
ABlockLdsExtraM1Wrw
)
{
constexpr
auto
a_block_desc_k0_m0_m1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
M0PerBlock
>
{}
*
(
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
),
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
,
K1
,
I1
));
constexpr
auto
a_block_desc_k0_m_k1_tmp
=
transform_tensor_descriptor
(
a_block_desc_k0_m0_m1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
a_block_desc_k0_m_k1_tmp
;
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
return
a_block_desc_k0_m_k1
;
}
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_b_k0_m_k1
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
if
constexpr
(
ABlockLdsExtraM1Wrw
)
{
constexpr
auto
a_block_desc_b_k0_m0_m1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
M0PerBlock
>
{}
*
(
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
),
Number
<
M0PerBlock
>
{}
*
(
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
),
Number
<
M1PerBlock
>
{}
*
K1
+
M1Padding
,
K1
,
I1
));
constexpr
auto
a_block_desc_b_k0_m_k1_tmp
=
transform_tensor_descriptor
(
a_block_desc_b_k0_m0_m1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
1
>
{}),
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v4_no_carry
(
make_tuple
(
Number
<
M0PerBlock
>
{},
Number
<
M1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
return
a_block_desc_b_k0_m_k1_tmp
;
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
MPerBlock
+
1
>
{}
*
K1
,
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
return
a_block_desc_b_k0_m_k1
;
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
if
constexpr
(
BBlockLdsExtraN1Wrw
)
{
constexpr
auto
b_block_desc_k0_n0_n1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
N0PerBlock
>
{}
*
(
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
),
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
,
K1
,
I1
));
constexpr
auto
b_block_desc_k0_n_k1_tmp
=
transform_tensor_descriptor
(
b_block_desc_k0_n0_n1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
return
b_block_desc_k0_n_k1_tmp
;
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
return
b_block_desc_k0_n_k1
;
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1
()
{
constexpr
auto
max_lds_align
=
K1
;
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_b_k0_n_k1
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
if
constexpr
(
BBlockLdsExtraN1Wrw
)
{
constexpr
auto
b_block_desc_b_k0_n0_n1_k1
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
N0PerBlock
>
{}
*
(
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
),
Number
<
N0PerBlock
>
{}
*
(
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
),
Number
<
N1PerBlock
>
{}
*
K1
+
N1Padding
,
K1
,
I1
));
constexpr
auto
b_block_desc_b_k0_n_k1_tmp
=
transform_tensor_descriptor
(
b_block_desc_b_k0_n0_n1_k1
,
make_tuple
(
make_pass_through_transform
(
Number
<
1
>
{}),
make_pass_through_transform
(
Number
<
K0PerBlock
>
{}),
make_merge_transform_v4_no_carry
(
make_tuple
(
Number
<
N0PerBlock
>
{},
Number
<
N1PerBlock
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
return
b_block_desc_b_k0_n_k1_tmp
;
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
NPerBlock
+
1
>
{}
*
K1
,
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
return
b_block_desc_b_k0_n_k1
;
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_b_k0_m_k1_block_desc
=
GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1
();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_b_k0_n_k1_block_desc
=
GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1
();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_b_k0_m_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size
=
math
::
integer_least_multiple
(
b_b_k0_n_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
c_block_size
=
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
().
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
FloatAB
),
c_block_size
*
sizeof
(
FloatC
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_B_K0_M_K1
&
a_b_k0_m_k1_grid_desc
,
const
BGridDesc_B_K0_N_K1
&
b_b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
static_assert
((
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
)
&&
(
NPerBlock
%
(
NRepeat
*
NPerXDL
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I2
);
const
auto
N
=
b_b_k0_n_k1_grid_desc
.
GetLength
(
I2
);
const
auto
K0
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
);
const
auto
KBatch
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I0
);
// check gridwise gemm pipeline
const
auto
num_k_loop
=
K0
/
K0PerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
return
false
;
}
if
(
!
(
M
==
c_m_n_grid_desc
.
GetLength
(
I0
)
&&
N
==
c_m_n_grid_desc
.
GetLength
(
I1
)
&&
K0
==
b_b_k0_n_k1_grid_desc
.
GetLength
(
I1
)
&&
K1
==
a_b_k0_m_k1_grid_desc
.
GetLength
(
I3
)
&&
K1
==
b_b_k0_n_k1_grid_desc
.
GetLength
(
I3
)
&&
KBatch
==
b_b_k0_n_k1_grid_desc
.
GetLength
(
I0
)))
return
false
;
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
return
false
;
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_m_n_grid_desc
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
// const bool has_main_k0_block_loop = K0 > K0PerBlock;
const
index_t
num_loop
=
K0
/
K0PerBlock
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
// return has_main_k0_block_loop;
}
__host__
__device__
static
constexpr
auto
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CMNGridDesc
&
c_m_n_grid_desc
)
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
return
transform_tensor_descriptor
(
c_m_n_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
,
index_t
KBatch
)
{
return
BlockToCTileMap_KSplit_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CMNGridDesc
>
(
c_m_n_grid_desc
,
M01
,
N01
,
KBatch
);
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMRepeatPerShuffle
*
MWave
*
MPerXDL
>
{},
I1
,
Number
<
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
{}));
}
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
CMNGridDesc
{}));
using
CBlockClusterAdaptor
=
decltype
(
MakeCBlockClusterAdaptor
(
CMNGridDesc
{},
1
,
1
,
1
));
template
<
bool
HasMainKBlockLoop
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGridDesc_B_K0_M_K1
&
a_b_k0_m_k1_grid_desc
,
const
BGridDesc_B_K0_N_K1
&
b_b_k0_n_k1_grid_desc
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
&
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
CBlockClusterAdaptor
&
c_block_cluster_adaptor
)
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_b_k0_m_k1_grid_desc
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_b_k0_n_k1_grid_desc
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
const
auto
K0
=
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
);
// divide block work by [M, N]
const
auto
block_work_idx
=
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
const
index_t
k_batch_id
=
block_work_idx
[
I0
];
if
(
!
c_block_cluster_adaptor
.
ValidCTileIndex
(
make_tuple
(
block_work_idx
[
I1
],
block_work_idx
[
I2
]),
make_tuple
(
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I2
]
*
NPerBlock
);
// lds max alignment
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_k0_m_k1_block_desc
=
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
();
constexpr
auto
a_b_k0_m_k1_block_desc
=
GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1
();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_k0_n_k1_block_desc
=
GetBBlockDescriptor_K0PerBlock_NPerBlock_K1
();
constexpr
auto
b_b_k0_n_k1_block_desc
=
GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1
();
// A matrix blockwise copy
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
1
,
K0PerBlock
,
MPerBlock
,
K1
>
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
a_b_k0_m_k1_grid_desc
),
decltype
(
a_b_k0_m_k1_block_desc
),
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
2
,
1
,
3
>
,
ABlockTransferSrcVectorDim
,
3
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
a_b_k0_m_k1_grid_desc
,
make_multi_index
(
k_batch_id
,
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_b_k0_m_k1_block_desc
,
make_multi_index
(
0
,
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// B matrix blockwise copy
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
1
,
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
decltype
(
b_b_k0_n_k1_grid_desc
),
decltype
(
b_b_k0_n_k1_block_desc
),
BBlockTransferSrcAccessOrder
,
Sequence
<
0
,
2
,
1
,
3
>
,
BBlockTransferSrcVectorDim
,
3
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_b_k0_n_k1_grid_desc
,
make_multi_index
(
k_batch_id
,
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_b_k0_n_k1_block_desc
,
make_multi_index
(
0
,
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
K1
,
MfmaSelector
<
FloatAB
,
MPerXDL
,
NPerXDL
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_k0_m_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
FloatAB
*
p_a_block
=
p_shared_block
;
FloatAB
*
p_b_block
=
p_shared_block
+
a_block_space_size
;
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_a_block
,
a_k0_m_k1_block_desc
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_b_block
,
b_k0_n_k1_block_desc
.
GetElementSpaceSize
());
// gridwise GEMM pipeline
const
index_t
K0BlockMainLoop
=
__builtin_amdgcn_readfirstlane
(
K0
/
K0PerBlock
);
GridwiseGemmPipe
::
template
Run
<
HasMainKBlockLoop
>(
a_b_k0_m_k1_grid_desc
,
a_b_k0_m_k1_block_desc
,
a_blockwise_copy
,
a_grid_buf
,
a_block_buf
,
a_block_slice_copy_step
,
b_b_k0_n_k1_grid_desc
,
b_b_k0_n_k1_block_desc
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buf
,
b_block_slice_copy_step
,
blockwise_gemm
,
c_thread_buf
,
K0BlockMainLoop
);
// output: register to global memory
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I7
);
constexpr
auto
c_block_desc_mblock_mperblock_nblock_nperblock
=
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatC
*>
(
p_shared_block
),
c_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
static_assert
(
M1
==
MWave
,
""
);
static_assert
(
N1
==
NWave
,
""
);
static_assert
(
M2
*
M3
*
M4
==
MPerXDL
,
""
);
static_assert
(
N2
==
NPerXDL
,
""
);
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c_block_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
make_freeze_transform
(
I0
),
// freeze mblock
make_unmerge_transform
(
make_tuple
(
CShuffleMRepeatPerShuffle
,
M1
,
M2
,
M3
,
M4
)),
// M1 = MWave, M2 * M3 * M4 = MPerXDL
make_freeze_transform
(
I0
),
// freeze nblock
make_unmerge_transform
(
make_tuple
(
CShuffleNRepeatPerShuffle
,
N1
,
N2
))),
// M1 = MWave, M2 * M3 * M4 = MPerXDL
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<>
{},
Sequence
<
1
,
3
,
7
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
// VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
// LDS to global
auto
c_block_copy_lds_to_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
ThisThreadBlock
,
// index_t BlockSize,
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
CShuffleMRepeatPerShuffle
*
MWave
*
MPerXDL
,
1
,
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
,
// BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatC
,
// typename SrcData,
FloatC
,
// typename DstData,
decltype
(
c_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun
{
c_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_work_idx
[
I1
],
0
,
block_work_idx
[
I2
],
0
),
c_element_op
};
constexpr
auto
mxdlperwave_forward_step
=
make_multi_index
(
0
,
CShuffleMRepeatPerShuffle
*
MWave
*
MPerXDL
,
0
,
0
);
constexpr
auto
nxdlperwave_forward_step
=
make_multi_index
(
0
,
0
,
0
,
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
);
constexpr
auto
nxdlperwave_backward_step
=
make_multi_index
(
0
,
0
,
0
,
-
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
);
static_for
<
0
,
MRepeat
,
CShuffleMRepeatPerShuffle
>
{}([
&
](
auto
mxdlperwave_iter
)
{
constexpr
auto
mxdlperwave
=
mxdlperwave_iter
;
static_for
<
0
,
NRepeat
,
CShuffleNRepeatPerShuffle
>
{}([
&
](
auto
nxdlperwave_iter
)
{
constexpr
bool
nxdlperwave_forward_sweep
=
(
mxdlperwave
%
(
2
*
CShuffleMRepeatPerShuffle
)
==
0
);
constexpr
index_t
nxdlperwave_value
=
nxdlperwave_forward_sweep
?
nxdlperwave_iter
:
(
NRepeat
-
nxdlperwave_iter
-
CShuffleNRepeatPerShuffle
);
constexpr
auto
nxdlperwave
=
Number
<
nxdlperwave_value
>
{};
// make sure it's safe to do ds_write
block_sync_lds
();
// VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple
(
mxdlperwave
,
nxdlperwave
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_block_buf
);
// make sure it's safe to do ds_read
block_sync_lds
();
// LDS to global
c_block_copy_lds_to_global
.
Run
(
c_block_desc_mblock_mperblock_nblock_nperblock
,
c_block_buf
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
// move on nxdlperwave dimension
if
constexpr
(
nxdlperwave_forward_sweep
&&
(
nxdlperwave
<
NRepeat
-
CShuffleNRepeatPerShuffle
))
{
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
nxdlperwave_forward_step
);
}
else
if
constexpr
((
!
nxdlperwave_forward_sweep
)
&&
(
nxdlperwave
>
0
))
{
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
nxdlperwave_backward_step
);
}
});
// move on mxdlperwave dimension
if
constexpr
(
mxdlperwave
<
MRepeat
-
CShuffleMRepeatPerShuffle
)
{
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
mxdlperwave_forward_step
);
}
});
}
}
};
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
4769425e
...
...
@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
...
...
@@ -185,12 +186,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
...
...
@@ -219,31 +220,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return
false
;
}
// check M01, N01
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
...
@@ -305,36 +290,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
return
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
,
M01
,
N01
);
}
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
...
...
@@ -368,6 +325,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I0
),
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetLength
(
I1
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
View file @
4769425e
...
...
@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
...
...
@@ -167,12 +168,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
ABK0MK1GridDesc
&
a_b_k0_m_k1_grid_desc
,
const
BBK0NK1GridDesc
&
b_b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
...
...
@@ -196,31 +197,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
return
false
;
// check M01, N01
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_m_n_grid_desc
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
KBatch
)
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
)
*
KBatch
;
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
const
bool
has_main_k0_block_loop
=
K0
>
K0PerBlock
;
...
...
@@ -282,37 +267,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
,
index_t
KBatch
)
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_pass_through_transform
(
KBatch
),
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{}));
const
auto
cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
KBatch
,
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_kbatch_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_kbatch_m0_n0_block_cluster_adaptor
;
return
BlockToCTileMap_KSplit_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CMNGridDesc
>
(
c_m_n_grid_desc
,
M01
,
N01
,
KBatch
);
}
using
CM0N0M1N1M2M3M4N2GridDesc
=
decltype
(
MakeCM0N0M1N1M2M3M4N2GridDescriptor
(
CMNGridDesc
{}));
...
...
@@ -344,6 +300,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
const
auto
block_work_idx
=
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
c_block_cluster_adaptor
.
ValidCTileIndex
(
make_tuple
(
block_work_idx
[
I1
],
block_work_idx
[
I2
]),
make_tuple
(
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetLength
(
I0
),
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc
.
GetLength
(
I1
))))
{
return
;
}
const
index_t
k_batch_id
=
block_work_idx
[
I0
];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
4769425e
...
...
@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
...
...
@@ -174,12 +175,12 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_B_K0_M_K1
&
a_b_k0_m_k1_grid_desc
,
const
BGridDesc_B_K0_N_K1
&
b_b_k0_n_k1_grid_desc
,
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
...
...
@@ -203,31 +204,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K0
%
K0PerBlock
==
0
))
return
false
;
// check M01, N01
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_m_n_grid_desc
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
KBatch
)
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
)
*
KBatch
;
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
{
const
bool
has_main_k0_block_loop
=
K0
>
K0PerBlock
;
...
...
@@ -256,37 +241,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
__host__
__device__
static
constexpr
auto
MakeCBlockClusterAdaptor
(
const
CMNGridDesc
&
c_m_n_grid_desc
,
index_t
M01
,
index_t
N01
,
index_t
KBatch
)
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_pass_through_transform
(
KBatch
),
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
>
{},
Sequence
<
2
,
4
>
{}));
const
auto
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
KBatch
,
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
c_blockid_to_kbatch_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor
);
return
c_blockid_to_kbatch_m0_n0_block_cluster_adaptor
;
return
BlockToCTileMap_KSplit_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CMNGridDesc
>
(
c_m_n_grid_desc
,
M01
,
N01
,
KBatch
);
}
__host__
__device__
static
constexpr
auto
...
...
@@ -333,6 +289,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
const
auto
block_work_idx
=
c_block_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
c_block_cluster_adaptor
.
ValidCTileIndex
(
make_tuple
(
block_work_idx
[
I1
],
block_work_idx
[
I2
]),
make_tuple
(
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
const
index_t
k_batch_id
=
block_work_idx
[
I0
];
// HACK: this force m/n_block_data_idx_on_grid into SGPR
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
View file @
4769425e
...
...
@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r1.hpp"
...
...
@@ -223,12 +224,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
// static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
// is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
...
...
@@ -256,31 +257,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
return
false
;
}
// check M01, N01
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
KPerBlock
;
...
...
@@ -318,36 +303,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
return
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
,
M01
,
N01
);
}
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
...
...
@@ -385,6 +342,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I0
),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I3
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
View file @
4769425e
...
...
@@ -5,6 +5,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r2.hpp"
...
...
@@ -230,12 +231,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
...
...
@@ -264,31 +265,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
return
false
;
}
// check M01, N01
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
...
@@ -327,37 +312,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
return
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
,
M01
,
N01
);
}
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
...
...
@@ -408,6 +366,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I0
),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I3
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
View file @
4769425e
...
...
@@ -3,6 +3,7 @@
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "thread_group_tensor_slice_transfer_v6r3.hpp"
...
...
@@ -237,12 +238,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2CTileMap
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
const
Block2CTileMap
&
block_2_ctile_map
)
{
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
decltype
(
K1
)
>>::
value
,
"wrong! K1 need to be known at compile-time"
);
...
...
@@ -271,31 +272,15 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
return
false
;
}
// check M01, N01
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
if
(
!
(
M0
%
M01
==
0
&&
N0
%
N01
==
0
))
if
(
!
block_2_ctile_map
.
CheckValidity
(
c_grid_desc_m_n
))
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
__device__
static
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
index_t
grid_size
=
(
M
/
MPerBlock
)
*
(
N
/
NPerBlock
);
return
grid_size
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
(
K0PerBlock
*
K1
);
...
...
@@ -334,36 +319,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
,
index_t
N01
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
constexpr
auto
M1
=
Number
<
MPerBlock
>
{};
constexpr
auto
N1
=
Number
<
NPerBlock
>
{};
const
auto
M0
=
M
/
M1
;
const
auto
N0
=
N
/
N1
;
const
auto
M00
=
M0
/
M01
;
const
auto
N00
=
N0
/
N01
;
const
auto
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M00
,
M01
)),
make_unmerge_transform
(
make_tuple
(
N00
,
N01
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}));
const
auto
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M00
,
N00
,
M01
,
N01
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
cblockid_to_m0_n0_block_cluster_adaptor
=
chain_tensor_adaptors
(
m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor
,
cblockid_to_m00_m01_n00_n01_block_cluster_adaptor
);
return
cblockid_to_m0_n0_block_cluster_adaptor
;
return
BlockToCTileMap_M00_N00_M01_N01
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
,
M01
,
N01
);
}
using
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
=
remove_cvref_t
<
decltype
(
...
...
@@ -427,6 +384,17 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I0
),
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
.
GetLength
(
I3
))))
{
return
;
}
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
4769425e
...
...
@@ -258,6 +258,14 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
index_t
soffset
,
index_t
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fadd.f32"
);
// buffer atomic-add fp32
__device__
double
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
double
vdata
,
int32x4_t
rsrc
,
// dst_wave_buffer_resource
int
voffset
,
// dst_thread_addr_offset
int
soffset
,
// dst_wave_addr_offset
int
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fmax.f64"
);
template
<
typename
T
,
index_t
N
>
__device__
typename
vector_type
<
T
,
N
>::
type
amd_buffer_load_impl
(
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
...
...
@@ -915,6 +923,71 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::typ
}
}
template
<
typename
T
,
index_t
N
>
__device__
void
amd_buffer_atomic_max_impl
(
const
typename
vector_type
<
T
,
N
>::
type
src_thread_data
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
)
{
static_assert
((
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
)),
"wrong! not implemented"
);
if
constexpr
(
is_same
<
T
,
double
>::
value
)
{
if
constexpr
(
N
==
1
)
{
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
}
else
if
constexpr
(
N
==
2
)
{
vector_type
<
double
,
2
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
double
),
0
);
}
else
if
constexpr
(
N
==
4
)
{
vector_type
<
double
,
4
>
tmp
{
src_thread_data
};
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
0
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
sizeof
(
double
),
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
2
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
2
*
sizeof
(
double
),
0
);
llvm_amdgcn_raw_buffer_atomic_max_fp64
(
tmp
.
AsType
<
double
>
()[
Number
<
3
>
{}],
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
3
*
sizeof
(
double
),
0
);
}
}
}
// buffer_load requires:
// 1) p_src_wave must point to global memory space
// 2) p_src_wave must be a wavewise pointer.
...
...
@@ -1046,4 +1119,39 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
#endif
}
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is user's responsibility to make sure that is true.
template
<
typename
T
,
index_t
N
>
__device__
void
amd_buffer_atomic_max
(
const
typename
vector_type_maker
<
T
,
N
>::
type
::
type
src_thread_data
,
T
*
p_dst_wave
,
const
index_t
dst_thread_element_offset
,
const
bool
dst_thread_element_valid
,
const
index_t
dst_element_space_size
)
{
const
int32x4_t
dst_wave_buffer_resource
=
make_wave_buffer_resource
(
p_dst_wave
,
dst_element_space_size
);
index_t
dst_thread_addr_offset
=
dst_thread_element_offset
*
sizeof
(
T
);
using
vector_t
=
typename
vector_type_maker
<
T
,
N
>::
type
::
type
;
using
scalar_t
=
typename
scalar_type
<
vector_t
>::
type
;
constexpr
index_t
vector_size
=
scalar_type
<
vector_t
>::
vector_size
;
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
uint32_t
dst_addr_shift
=
dst_thread_element_valid
?
0
:
0x7fffffff
;
amd_buffer_atomic_max_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_addr_shift
+
dst_thread_addr_offset
,
0
);
#else
if
(
dst_thread_element_valid
)
{
amd_buffer_atomic_max_impl
<
scalar_t
,
vector_size
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
0
);
}
#endif
}
}
// namespace ck
include/ck/utility/common_header.hpp
View file @
4769425e
...
...
@@ -32,7 +32,7 @@
#include "debug.hpp"
#include "amd_buffer_addressing.hpp"
#include "generic_memory_space_atomic
_add
.hpp"
#include "generic_memory_space_atomic.hpp"
#include "get_id.hpp"
#include "synchronization.hpp"
#include "amd_address_space.hpp"
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
4769425e
...
...
@@ -3,7 +3,7 @@
#include "enable_if.hpp"
#include "c_style_pointer_cast.hpp"
#include "amd_buffer_addressing.hpp"
#include "generic_memory_space_atomic
_add
.hpp"
#include "generic_memory_space_atomic.hpp"
namespace
ck
{
...
...
@@ -125,6 +125,10 @@ struct DynamicBuffer
{
this
->
template
AtomicAdd
<
X
>(
i
,
is_valid_element
,
x
);
}
else
if
constexpr
(
Op
==
InMemoryDataOperationEnum
::
AtomicMax
)
{
this
->
template
AtomicMax
<
X
>(
i
,
is_valid_element
,
x
);
}
else
if
constexpr
(
Op
==
InMemoryDataOperationEnum
::
Add
)
{
auto
tmp
=
this
->
template
Get
<
X
>(
i
,
is_valid_element
);
...
...
@@ -326,6 +330,42 @@ struct DynamicBuffer
}
}
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
void
AtomicMax
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
scalar_type
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X should contain multiple T"
);
static_assert
(
GetAddressSpace
()
==
AddressSpaceEnum
::
Global
,
"only support global mem"
);
#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
using
scalar_t
=
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
;
bool
constexpr
use_amd_buffer_addressing
=
is_same_v
<
remove_cvref_t
<
scalar_t
>
,
double
>
;
#else
bool
constexpr
use_amd_buffer_addressing
=
false
;
#endif
if
constexpr
(
use_amd_buffer_addressing
)
{
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_atomic_max
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
element_space_size_
);
}
else
if
(
is_valid_element
)
{
atomic_max
<
X
>
(
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
]),
x
);
}
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
false
;
}
__host__
__device__
static
constexpr
bool
IsDynamicBuffer
()
{
return
true
;
}
...
...
include/ck/utility/generic_memory_space_atomic
_add
.hpp
→
include/ck/utility/generic_memory_space_atomic.hpp
View file @
4769425e
...
...
@@ -3,6 +3,10 @@
namespace
ck
{
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_add explicit for
// each datatype.
template
<
typename
X
>
__device__
X
atomic_add
(
X
*
p_dst
,
const
X
&
x
);
...
...
@@ -41,4 +45,53 @@ __device__ float2_t atomic_add<float2_t>(float2_t* p_dst, const float2_t& x)
return
vy
.
template
AsType
<
float2_t
>()[
I0
];
}
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to make the implementation of atomic_max explicit for
// each datatype.
template
<
typename
X
>
__device__
X
atomic_max
(
X
*
p_dst
,
const
X
&
x
);
template
<
>
__device__
int32_t
atomic_max
<
int32_t
>
(
int32_t
*
p_dst
,
const
int32_t
&
x
)
{
return
atomicMax
(
p_dst
,
x
);
}
template
<
>
__device__
uint32_t
atomic_max
<
uint32_t
>
(
uint32_t
*
p_dst
,
const
uint32_t
&
x
)
{
return
atomicMax
(
p_dst
,
x
);
}
template
<
>
__device__
float
atomic_max
<
float
>
(
float
*
p_dst
,
const
float
&
x
)
{
return
atomicMax
(
p_dst
,
x
);
}
template
<
>
__device__
double
atomic_max
<
double
>
(
double
*
p_dst
,
const
double
&
x
)
{
return
atomicMax
(
p_dst
,
x
);
}
template
<
>
__device__
float2_t
atomic_max
<
float2_t
>
(
float2_t
*
p_dst
,
const
float2_t
&
x
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
const
vector_type
<
float
,
2
>
vx
{
x
};
vector_type
<
float
,
2
>
vy
{
0
};
vy
.
template
AsType
<
float
>()(
I0
)
=
atomicMax
(
c_style_pointer_cast
<
float
*>
(
p_dst
),
vx
.
template
AsType
<
float
>()[
I0
]);
vy
.
template
AsType
<
float
>()(
I1
)
=
atomicMax
(
c_style_pointer_cast
<
float
*>
(
p_dst
)
+
1
,
vx
.
template
AsType
<
float
>()[
I1
]);
return
vy
.
template
AsType
<
float2_t
>()[
I0
];
}
}
// namespace ck
include/ck/utility/get_id.hpp
View file @
4769425e
...
...
@@ -11,10 +11,14 @@ __host__ __device__ constexpr index_t get_warp_size()
__device__
index_t
get_thread_local_1d_id
()
{
return
threadIdx
.
x
;
}
__device__
index_t
get_thread_global_1d_id
()
{
return
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
}
__device__
index_t
get_warp_local_1d_id
()
{
return
threadIdx
.
x
/
get_warp_size
();
}
__device__
index_t
get_block_1d_id
()
{
return
blockIdx
.
x
;
}
__device__
index_t
get_grid_size
()
{
return
gridDim
.
x
;
}
__device__
index_t
get_block_size
()
{
return
blockDim
.
x
;
}
}
// namespace ck
include/ck/utility/statically_indexed_array_multi_index.hpp
View file @
4769425e
...
...
@@ -93,6 +93,13 @@ __host__ __device__ constexpr auto operator*(index_t a, const Tuple<Xs...>& x)
return
r
;
}
// MultiIndex = MultiIndex * index_t
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
operator
*
(
const
Tuple
<
Xs
...
>&
x
,
index_t
a
)
{
return
a
*
x
;
}
template
<
typename
...
Xs
>
__host__
__device__
void
print_multi_index
(
const
Tuple
<
Xs
...
>&
x
)
{
...
...
include/ck/utility/type.hpp
View file @
4769425e
...
...
@@ -29,6 +29,9 @@ using remove_cv_t = typename std::remove_cv<T>::type;
template
<
typename
T
>
using
remove_cvref_t
=
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
;
template
<
typename
T
>
using
remove_pointer_t
=
typename
std
::
remove_pointer
<
T
>::
type
;
template
<
typename
T
>
inline
constexpr
bool
is_pointer_v
=
std
::
is_pointer
<
T
>::
value
;
...
...
library/include/ck/library/host_tensor/device.hpp
View file @
4769425e
...
...
@@ -10,6 +10,15 @@
#include "stream_config.hpp"
#include "ck/options.hpp"
template
<
typename
T
>
__global__
void
set_buffer_value
(
T
*
p
,
T
x
,
uint64_t
buffer_element_size
)
{
for
(
uint64_t
i
=
threadIdx
.
x
;
i
<
buffer_element_size
;
i
+=
blockDim
.
x
)
{
p
[
i
]
=
x
;
}
}
inline
void
hip_check_error
(
hipError_t
x
)
{
if
(
x
!=
hipSuccess
)
...
...
@@ -30,6 +39,16 @@ struct DeviceMem
void
ToDevice
(
const
void
*
p
);
void
FromDevice
(
void
*
p
);
void
SetZero
();
template
<
typename
T
>
void
SetValue
(
T
x
)
{
if
(
mMemSize
%
sizeof
(
T
)
!=
0
)
{
throw
std
::
runtime_error
(
"wrong! not entire DeviceMem will be set"
);
}
set_buffer_value
<
T
><<<
1
,
1024
>>>
(
static_cast
<
T
*>
(
mpDeviceBuf
),
x
,
mMemSize
/
sizeof
(
T
));
}
~
DeviceMem
();
void
*
mpDeviceBuf
;
...
...
@@ -74,8 +93,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
printf
(
"Warm up 1 time
\n
"
);
// warm up
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
,
args
...);
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
...
...
@@ -84,8 +102,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
for
(
int
i
=
0
;
i
<
nrepeat
;
++
i
)
{
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
,
args
...);
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
}
timer
.
End
();
...
...
@@ -94,13 +111,12 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
}
else
{
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
,
args
...);
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
return
0
;
}
#else
hipLaunchKernelGGL
(
kernel
,
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
,
args
...);
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
return
0
;
#endif
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_backward_weight.hpp
View file @
4769425e
#ifndef REFERENCE_CONV_WRW_HPP
#define REFERENCE_CONV_WRW_HPP
#pragma once
#include <iostream>
#include <sstream>
...
...
@@ -16,7 +15,9 @@ template <typename InDataType,
typename
OutDataType
,
typename
InElementwiseOperation
,
typename
WeiElementwiseOperation
,
typename
OutElementwiseOperation
>
typename
OutElementwiseOperation
,
ck
::
index_t
NumDimSpatial
=
2
,
typename
ck
::
enable_if
<
NumDimSpatial
>
=
1
&&
NumDimSpatial
<=
3
,
bool
>::
type
=
false
>
struct
ReferenceConvBwdWeight
:
public
device
::
BaseOperator
{
// Argument
...
...
@@ -32,9 +33,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
:
in
_n_c_hi_wi
_
{
in_n_c_hi_wi
},
wei
_k_c_y_x
_
{
wei_k_c_y_x
},
out
_n_k_ho_wo
_
{
out_n_k_ho_wo
},
:
in
put
_
{
in_n_c_hi_wi
},
wei
ght
_
{
wei_k_c_y_x
},
out
put
_
{
out_n_k_ho_wo
},
conv_strides_
{
conv_filter_strides
},
conv_dilations_
{
conv_filter_dilations
},
in_left_pads_
{
input_left_pads
},
...
...
@@ -45,9 +46,9 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
{
}
const
Tensor
<
InDataType
>&
in
_n_c_hi_wi
_
;
Tensor
<
WeiDataType
>&
wei
_k_c_y_x
_
;
const
Tensor
<
OutDataType
>&
out
_n_k_ho_wo
_
;
const
Tensor
<
InDataType
>&
in
put
_
;
Tensor
<
WeiDataType
>&
wei
ght
_
;
const
Tensor
<
OutDataType
>&
out
put
_
;
std
::
vector
<
index_t
>
conv_strides_
;
std
::
vector
<
index_t
>
conv_dilations_
;
...
...
@@ -66,59 +67,180 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
float
Run
(
const
Argument
&
arg
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
auto
f_kcyx
=
[
&
](
auto
k
,
auto
c
,
auto
y
,
auto
x
)
{
float
v_acc
=
0
;
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
{
for
(
std
::
size_t
ho
=
0
;
ho
<
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
2
];
++
ho
)
if
constexpr
(
NumDimSpatial
==
1
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
auto
f_kcx
=
[
&
](
auto
k
,
auto
c
,
auto
x
)
{
float
v_acc
=
0
;
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
output_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
{
auto
hi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
I0
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
I0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I0
]);
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
out_n_k_ho_wo_
.
mDesc
.
GetLengths
()[
3
];
++
wo
)
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
++
wo
)
{
auto
wi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
I1
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
I1
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I1
]);
if
(
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
in_n_c_hi_wi_
.
mDesc
.
GetLengths
()[
3
])
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
I0
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
I0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I0
]);
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
])
{
float
v_out
;
float
v_in
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
out_n_k_ho_wo_
(
n
,
k
,
ho
,
wo
)));
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
in_n_c_hi_wi_
(
n
,
c
,
hi
,
wi
)));
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
k
,
wo
)));
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
n
,
c
,
wi
)));
v_acc
+=
v_out
*
v_in
;
}
}
}
}
float
v_wei
;
float
v_wei
;
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
wei
_k_c_y_x
_
(
k
,
c
,
y
,
x
)
=
ck
::
type_convert
<
Out
DataType
>
(
v_wei
);
};
arg
.
wei
ght
_
(
k
,
c
,
x
)
=
ck
::
type_convert
<
Wei
DataType
>
(
v_wei
);
};
make_ParallelTensorFunctor
(
f_kcyx
,
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
0
],
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
1
],
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
2
],
arg
.
wei_k_c_y_x_
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f_kcx
,
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
}
else
if
constexpr
(
NumDimSpatial
==
2
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
auto
f_kcyx
=
[
&
](
auto
k
,
auto
c
,
auto
y
,
auto
x
)
{
float
v_acc
=
0
;
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
output_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
{
for
(
std
::
size_t
ho
=
0
;
ho
<
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
++
ho
)
{
auto
hi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
I0
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
I0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I0
]);
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
output_
.
mDesc
.
GetLengths
()[
3
];
++
wo
)
{
auto
wi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
I1
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
I1
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I1
]);
if
(
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
])
{
float
v_out
;
float
v_in
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
k
,
ho
,
wo
)));
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
n
,
c
,
hi
,
wi
)));
v_acc
+=
v_out
*
v_in
;
}
}
}
}
float
v_wei
;
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
weight_
(
k
,
c
,
y
,
x
)
=
ck
::
type_convert
<
WeiDataType
>
(
v_wei
);
};
make_ParallelTensorFunctor
(
f_kcyx
,
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
else
if
constexpr
(
NumDimSpatial
==
3
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
auto
f_kczyx
=
[
&
](
auto
k
,
auto
c
,
auto
z
,
auto
y
,
auto
x
)
{
float
v_acc
=
0
;
for
(
std
::
size_t
n
=
0
;
n
<
arg
.
output_
.
mDesc
.
GetLengths
()[
0
];
++
n
)
{
for
(
std
::
size_t
do_
=
0
;
do_
<
arg
.
output_
.
mDesc
.
GetLengths
()[
2
];
++
do_
)
{
auto
di
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
do_
*
arg
.
conv_strides_
[
I0
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
I0
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I0
]);
for
(
std
::
size_t
ho
=
0
;
ho
<
arg
.
output_
.
mDesc
.
GetLengths
()[
3
];
++
ho
)
{
auto
hi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
I1
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
I1
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I1
]);
for
(
std
::
size_t
wo
=
0
;
wo
<
arg
.
output_
.
mDesc
.
GetLengths
()[
4
];
++
wo
)
{
auto
wi
=
ck
::
type_convert
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
I2
])
+
ck
::
type_convert
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
I2
])
-
ck
::
type_convert
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
I2
]);
if
(
di
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
di
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
2
]
&&
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
3
]
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
mDesc
.
GetLengths
()[
4
])
{
float
v_out
;
float
v_in
;
arg
.
out_element_op_
(
v_out
,
ck
::
type_convert
<
float
>
(
arg
.
output_
(
n
,
k
,
do_
,
ho
,
wo
)));
arg
.
in_element_op_
(
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
n
,
c
,
di
,
hi
,
wi
)));
v_acc
+=
v_out
*
v_in
;
}
}
}
}
}
float
v_wei
;
arg
.
wei_element_op_
(
v_wei
,
v_acc
);
arg
.
weight_
(
k
,
c
,
z
,
y
,
x
)
=
ck
::
type_convert
<
WeiDataType
>
(
v_wei
);
};
make_ParallelTensorFunctor
(
f_kczyx
,
arg
.
weight_
.
mDesc
.
GetLengths
()[
0
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
1
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
2
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
3
],
arg
.
weight_
.
mDesc
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
return
0
;
}
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
...
...
@@ -182,4 +304,3 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
#endif
library/src/tensor_operation_instance/gpu/batched_gemm_reduce/CMakeLists.txt
View file @
4769425e
set
(
DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE
set
(
DEVICE_BATCHED_GEMM_REDUCE_INSTANCE_SOURCE
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instance.cpp
device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instance.cpp
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment