gaoqiong / composable_kernel · Commits

Commit b134b7d6, authored May 16, 2022 by carlushuang

    Merge remote-tracking branch 'origin/develop' into cpu_avx2

Parents: 090ba885, 9f71ff48
Changes: 211. Showing 20 changed files with 1010 additions and 1080 deletions (+1010 −1080).
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp               +168 −129
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp    +144 −152
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp           +102 −127
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp               +88  −199
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp               +45  −43
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp             +48  −46
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp               +76  −105
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp               +78  −104
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp               +89  −117
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp      +9   −10
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp +7   −2
include/ck/utility/amd_xdlops.hpp                                                +4   −4
include/ck/utility/common_header.hpp                                             +1   −0
include/ck/utility/get_id.hpp                                                    +6   −2
include/ck/utility/number.hpp                                                    +3   −0
include/ck/utility/static_buffer.hpp                                             +6   −0
include/ck/utility/thread_group.hpp                                              +18  −0
include/ck/utility/tuple.hpp                                                     +5   −6
library/include/ck/library/host/host_interface.hpp                               +54  −0
library/include/ck/library/host_tensor/device.hpp                                +59  −34
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

-#ifndef CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
-#define CK_GRIDWISE_GEMM_PIPELINE_V1_HPP
+#pragma once

 #include "common_header.hpp"
+#include "tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"

 namespace ck {

-template <typename AGridDesc, typename ABlockDesc, typename ABlockTransfer, typename AGridBuffer,
-          typename ABlockBuffer, typename ABlockTransferStep, typename BGridDesc, typename BBlockDesc,
-          typename BBlockTransfer, typename BGridBuffer, typename BBlockBuffer,
-          typename BBlockTransferStep, typename BlockwiseGemm, typename CThreadBuffer,
-          index_t NumPrefetch, bool HasMainLoop>
+template <index_t NumPrefetch>
 struct GridwiseGemmPipeline_v1;

 // 1-stage prefetch
-template <typename AGridDesc, typename ABlockDesc, typename ABlockTransfer, typename AGridBuffer,
-          typename ABlockBuffer, typename ABlockTransferStep, typename BGridDesc, typename BBlockDesc,
-          typename BBlockTransfer, typename BGridBuffer, typename BBlockBuffer,
-          typename BBlockTransferStep, typename BlockwiseGemm, typename CThreadBuffer, bool HasMainLoop>
-struct GridwiseGemmPipeline_v1<AGridDesc, ABlockDesc, ABlockTransfer, AGridBuffer, ABlockBuffer,
-                               ABlockTransferStep, BGridDesc, BBlockDesc, BBlockTransfer, BGridBuffer,
-                               BBlockBuffer, BBlockTransferStep, BlockwiseGemm, CThreadBuffer,
-                               1, HasMainLoop>
+template <>
+struct GridwiseGemmPipeline_v1<1>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop, typename AGridDesc, typename ABlockDesc, typename ABlockTransfer,
+              typename AGridBuffer, typename ABlockBuffer, typename ABlockTransferStep,
+              typename BGridDesc, typename BBlockDesc, typename BBlockTransfer, typename BGridBuffer,
+              typename BBlockBuffer, typename BBlockTransferStep, typename BlockwiseGemm,
+              typename CThreadBuffer>
-    static __device__ void Run(const AGridDesc& a_grid_desc,
+    __device__ static void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
                                ABlockTransfer& a_blockwise_copy,
                                const AGridBuffer& a_grid_buf,
...
@@ -75,51 +52,6 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
                               CThreadBuffer& c_thread_buf,
                               index_t num_loop)
     {
-#if 0
-        // preload data into LDS
-        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
-        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
-
-        // Initialize C
-        c_thread_buf.Clear();
-
-        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
-        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
-
-        // main body
-        if constexpr(HasMainLoop)
-        {
-            index_t i = 0;
-            do
-            {
-                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
-                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
-
-                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
-
-                block_sync_lds();
-
-                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
-
-                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-
-                block_sync_lds();
-
-                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
-                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
-
-                ++i;
-            } while(i < (num_loop - 1));
-        }
-
-        // tail
-        {
-            block_sync_lds();
-            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
-        }
-#else
         // preload data into LDS
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
...
@@ -166,46 +98,42 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        }
-#endif
     }
 };

 // 2-stage prefetch
-template <typename AGridDesc, typename ABlockDesc, typename ABlockTransfer, typename AGridBuffer,
-          typename ABlockBuffer, typename ABlockTransferStep, typename BGridDesc, typename BBlockDesc,
-          typename BBlockTransfer, typename BGridBuffer, typename BBlockBuffer,
-          typename BBlockTransferStep, typename BlockwiseGemm, typename CThreadBuffer, bool HasMainLoop>
-struct GridwiseGemmPipeline_v1<AGridDesc, ABlockDesc, ABlockTransfer, AGridBuffer, ABlockBuffer,
-                               ABlockTransferStep, BGridDesc, BBlockDesc, BBlockTransfer, BGridBuffer,
-                               BBlockBuffer, BBlockTransferStep, BlockwiseGemm, CThreadBuffer,
-                               2, HasMainLoop>
+template <>
+struct GridwiseGemmPipeline_v1<2>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

+    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
+    {
+        // TODO: improve applicability
+        return num_loop % 2 == 0;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return (num_loop / 2) > 1;
+    }
+
+    template <bool HasMainLoop, typename AGridDesc, typename ABlockDesc, typename ABlockTransfer,
+              typename AGridBuffer, typename ABlockBuffer, typename ABlockTransferStep,
+              typename BGridDesc, typename BBlockDesc, typename BBlockTransfer, typename BGridBuffer,
+              typename BBlockBuffer, typename BBlockTransferStep, typename BlockwiseGemm,
+              typename CThreadBuffer>
     static __device__ void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
                                ABlockTransfer& a_blockwise_copy,
...
@@ -321,5 +249,116 @@ struct GridwiseGemmPipeline_v1<AGridDesc,
     }
 };

+template <index_t NumPrefetch>
+struct GridwiseGemmPipelineInterwave_v1;
+
+template <>
+struct GridwiseGemmPipelineInterwave_v1<1>
+{
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop, typename AGridDesc, typename ABlockDesc, typename ABlockTransfer,
+              typename AGridBuffer, typename ABlockBuffer, typename ABlockTransferStep,
+              typename BGridDesc, typename BBlockDesc, typename BBlockTransfer, typename BGridBuffer,
+              typename BBlockBuffer, typename BBlockTransferStep, typename BlockwiseGemm,
+              typename CThreadBuffer>
+    static __device__ void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+        // preload data into LDS
+        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+
+            do
+            {
+                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
+
+                block_sync_lds();
+
+                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+
+                // block_sync_lds(); // moved into blockwise_gemm
+
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+                ++i;
+            } while(i < (num_loop - 1));
+        }
+
+        // tail
+        {
+            block_sync_lds();
+
+            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+        }
+    }
+};
+
+// Note: 2 stage prefetch not optimized for inter-wave loop scheduler
+template <>
+struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
+{
+};
+
+template <index_t NumPrefetch, LoopScheduler LoopSched>
+constexpr auto GridwiseGemmPipeline_v1_Selector()
+{
+    if constexpr(LoopSched == LoopScheduler::Default)
+    {
+        return GridwiseGemmPipeline_v1<NumPrefetch>{};
+    }
+    else if constexpr(LoopSched == LoopScheduler::Interwave)
+    {
+        return GridwiseGemmPipelineInterwave_v1<NumPrefetch>{};
+    }
+}
+
 } // namespace ck
-#endif
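The new GridwiseGemmPipeline_v1_Selector lets each gridwise GEMM pick a pipeline type at compile time from the loop scheduler instead of hard-wiring one specialization. The standalone sketch below (plain host C++, not the CK headers; PipelineDefault, PipelineInterwave and their policy functions are simplified placeholders, and only the if-constexpr dispatch mirrors the code above) shows how such a selector can return differently typed pipeline objects that expose the same IsSupported / CalculateHasMainLoop queries.

#include <cstdio>

// Simplified stand-ins for ck::index_t and ck::LoopScheduler.
using index_t = int;
enum class LoopScheduler { Default, Interwave };

// Minimal model of a default pipeline's host-side queries (illustrative policy only).
template <index_t NumPrefetch>
struct PipelineDefault
{
    static constexpr bool IsSupported(index_t /*num_loop*/) { return true; }
    static constexpr bool CalculateHasMainLoop(index_t num_loop) { return num_loop > 1; }
};

// Minimal model of an inter-wave variant with the same query interface
// (the real difference would be in how Run schedules the hot loop).
template <index_t NumPrefetch>
struct PipelineInterwave
{
    static constexpr bool IsSupported(index_t num_loop) { return num_loop % NumPrefetch == 0; }
    static constexpr bool CalculateHasMainLoop(index_t num_loop) { return num_loop / NumPrefetch > 1; }
};

// Compile-time selector: the returned object's type depends on LoopSched,
// mirroring the if-constexpr dispatch in GridwiseGemmPipeline_v1_Selector.
template <index_t NumPrefetch, LoopScheduler LoopSched>
constexpr auto SelectPipeline()
{
    if constexpr(LoopSched == LoopScheduler::Default)
        return PipelineDefault<NumPrefetch>{};
    else
        return PipelineInterwave<NumPrefetch>{};
}

int main()
{
    constexpr auto pipe     = SelectPipeline<2, LoopScheduler::Interwave>();
    const index_t num_loop  = 8; // e.g. K / KPerBlock
    std::printf("supported=%d has_main_loop=%d\n",
                pipe.IsSupported(num_loop), pipe.CalculateHasMainLoop(num_loop));
}

Returning a value of deduced type keeps the call sites generic: the gridwise GEMM only needs `.template Run<HasMainKBlockLoop>(...)` and the two constexpr queries, whichever pipeline the selector hands back.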
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp

...
@@ -4,10 +4,11 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
+#include "reduction_functions_threadwise.hpp"

 namespace ck {
...
@@ -18,14 +19,13 @@ template <typename GridwiseGemm,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename D0ReduceOperation,
           typename D1ElementwiseOperation,
           typename D1ReduceOperation,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename DGridDescriptor_MBlock_MPerBlock,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...
@@ -39,8 +39,7 @@ __global__ void
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const CElementwiseOperation c_element_op,
         const D0ReduceOperation d0_reduce_op,
         const D1ElementwiseOperation d1_element_op,
         const D1ReduceOperation d1_reduce_op,
         const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
         const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...
@@ -51,22 +50,21 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                    p_b_grid,
                                                    p_c_grid,
                                                    p_d0_grid,
                                                    p_d1_grid,
                                                    p_shared,
                                                    a_element_op,
                                                    b_element_op,
                                                    c_element_op,
                                                    d0_reduce_op,
                                                    d1_element_op,
                                                    d1_reduce_op,
                                                    a_grid_desc_ak0_m_ak1,
                                                    b_grid_desc_bk0_n_bk1,
                                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                    d_grid_desc_mblock_mperblock,
                                                    block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
...
@@ -76,8 +74,7 @@ __global__ void
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
     ignore = d0_reduce_op;
     ignore = d1_element_op;
     ignore = d1_reduce_op;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
...
@@ -97,6 +94,7 @@ template <typename FloatAB,
           typename CElementwiseOperation,
           typename D0ReduceOperation,
           typename D1ReduceOperation,
+          typename D1ElementwiseOperation,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           InMemoryDataOperationEnum DGlobalMemoryDataOperation,
           typename AGridDesc_AK0_M_AK1,
...
@@ -136,7 +134,8 @@ template <typename FloatAB,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
           typename CReduceThreadClusterLengths_MPerBlock_NPerBlock,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
-          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock>
+          index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
+          LoopScheduler LoopSched>
 struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -154,6 +153,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};

+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
...
@@ -237,21 +240,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
             return false;

-        // check NumGemmKPrefetchStage
-        if constexpr(NumGemmKPrefetchStage == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumGemmKPrefetchStage == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K / KPerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K / KPerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
...
@@ -271,12 +263,11 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         return grid_size;
     }

-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    // TODO move this function into GEMM-pipeline class
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const bool has_main_k0_block_loop =
-            ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;
+        const index_t num_loop = K / KPerBlock;

-        return has_main_k0_block_loop;
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

     __host__ __device__ static constexpr auto
...
@@ -362,7 +353,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
...
@@ -372,8 +363,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const D0ReduceOperation& d0_reduce_op,
                               const D1ElementwiseOperation& d1_element_op,
                               const D1ReduceOperation& d1_reduce_op,
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                               const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
...
@@ -414,28 +404,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<AK0, MPerBlock, AK1>,
                                                 ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                 ABlockTransferThreadClusterArrangeOrder,
                                                 FloatAB, FloatAB,
                                                 decltype(a_grid_desc_ak0_m_ak1),
                                                 decltype(a_block_desc_ak0_m_ak1),
                                                 ABlockTransferSrcAccessOrder,
                                                 Sequence<1, 0, 2>,
                                                 ABlockTransferSrcVectorDim,
                                                 2,
                                                 ABlockTransferSrcScalarPerVector,
                                                 ABlockTransferDstScalarPerVector_AK1,
                                                 1, 1,
                                                 AThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
                                                 NumGemmKPrefetchStage>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
...
@@ -445,28 +435,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<BK0, NPerBlock, BK1>,
                                                 BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                 BBlockTransferThreadClusterArrangeOrder,
                                                 FloatAB, FloatAB,
                                                 decltype(b_grid_desc_bk0_n_bk1),
                                                 decltype(b_block_desc_bk0_n_bk1),
                                                 BBlockTransferSrcAccessOrder,
                                                 Sequence<1, 0, 2>,
                                                 BBlockTransferSrcVectorDim,
                                                 2,
                                                 BBlockTransferSrcScalarPerVector,
                                                 BBlockTransferDstScalarPerVector_BK1,
                                                 1, 1,
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
                                                 NumGemmKPrefetchStage>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
...
@@ -484,17 +474,18 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         constexpr index_t KPack = math::max(
             math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                FloatAB,
-                                                                FloatGemmAcc,
-                                                                decltype(a_block_desc_ak0_m_ak1),
-                                                                decltype(b_block_desc_bk0_n_bk1),
-                                                                MPerXdl, NPerXdl,
-                                                                MXdlPerWave, NXdlPerWave,
-                                                                KPack>{};
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<BlockSize,
+                                                                      FloatAB,
+                                                                      FloatGemmAcc,
+                                                                      decltype(a_block_desc_ak0_m_ak1),
+                                                                      decltype(b_block_desc_bk0_n_bk1),
+                                                                      MPerXdl, NPerXdl,
+                                                                      MXdlPerWave, NXdlPerWave,
+                                                                      KPack, LoopSched>();

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...
@@ -514,42 +505,27 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumGemmKPrefetchStage,
-                                    HasMainK0BlockLoop>{};
+            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);

-        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
+        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                    a_block_desc_ak0_m_ak1,
                                    a_blockwise_copy,
                                    a_grid_buf,
                                    a_block_buf,
                                    a_block_slice_copy_step,
                                    b_grid_desc_bk0_n_bk1,
                                    b_block_desc_bk0_n_bk1,
                                    b_blockwise_copy,
                                    b_grid_buf,
                                    b_block_buf,
                                    b_block_slice_copy_step,
                                    blockwise_gemm,
                                    c_thread_buf,
                                    num_k_block_main_loop);

         // shuffle C and write out
         {
...
@@ -665,8 +641,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
              ck::tensor_operation::element_wise::PassThrough{}};

         // shuffle: blockwise copy C from LDS to global
         auto c_shuffle_block_copy_lds_to_global =
-            BlockwiseTensorSliceTransfer_v6r1<BlockSize,          // index_t BlockSize,
+            ThreadGroupTensorSliceTransfer_v6r1<ThisThreadBlock,  // ThreadGroup
                                               CElementwiseOperation,      // ElementwiseOperation,
                                               CGlobalMemoryDataOperation, // DstInMemOp,
                                               Sequence<1,
...
@@ -741,13 +717,13 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             make_naive_tensor_descriptor_packed(make_tuple(I1, Number<mreduce_per_thread>{}));

         // TODO: this should be implemented as a blockwise reduction
-        auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
+        auto c_reduce_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
             c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize());

-        auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
+        auto d0_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
             d_reduce_thread_desc_mperblock.GetElementSpaceSize());

-        auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatCShuffle>(
+        auto d1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatReduceAcc>(
             d_reduce_thread_desc_mperblock.GetElementSpaceSize());

         // reduce: threadwise copy from LDS to VGPR
...
@@ -763,7 +739,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2<
             FloatCShuffle,
-            FloatCShuffle,
+            FloatReduceAcc,
             decltype(c_reduce_block_desc_mperblock_nperblock),
             decltype(c_reduce_thread_desc_mperblock_nperblock),
             decltype(c_reduce_thread_lengths_mperblock_nperblock),
...
@@ -775,7 +751,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // reduce: copy from VGPR to global
         auto d0_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3<
-            FloatCShuffle,
+            FloatReduceAcc,
             FloatD,
             decltype(d_reduce_thread_desc_mblock_mperblock),
             decltype(d_grid_desc_mblock_mperblock),
...
@@ -840,6 +816,28 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 c_grid_buf);

+            using ThreadwiseReduce_D0 =
+                ThreadwiseReduction<FloatReduceAcc,
+                                    decltype(c_reduce_thread_desc_mperblock_nperblock),
+                                    decltype(d_reduce_thread_desc_mperblock),
+                                    D0ReduceOperation,
+                                    false>;
+
+            using ThreadwiseReduce_D1 =
+                ThreadwiseReduction<FloatReduceAcc,
+                                    decltype(c_reduce_thread_desc_mperblock_nperblock),
+                                    decltype(d_reduce_thread_desc_mperblock),
+                                    D1ReduceOperation,
+                                    false>;
+
+            const auto d0_zeroVal = D0ReduceOperation::GetReductionZeroVal();
+            const auto d1_zeroVal = D0ReduceOperation::GetReductionZeroVal();
+
+            static_for<0, mreduce_per_thread, 1>{}([&](auto I) { d0_thread_buf(I) = d0_zeroVal; });
+            static_for<0, mreduce_per_thread, 1>{}([&](auto I) { d1_thread_buf(I) = d1_zeroVal; });
+
             // reduce
             {
                 // copy from LDS to VGPR
...
@@ -850,26 +848,20 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                     c_reduce_thread_buf);

                 // reduce in VGPR
-                static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
-                    FloatReduceAcc d0_acc = d0_reduce_op.GetReduceZeroValue();
-                    FloatReduceAcc d1_acc = d1_reduce_op.GetReduceZeroValue();
-
-                    static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
-                        constexpr auto offset =
-                            Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
-                                make_tuple(im, in))>{};
-
-                        d0_reduce_op.Reduce(d0_acc, c_reduce_thread_buf[offset]);
-                        d1_reduce_op.Reduce(d1_acc, c_reduce_thread_buf[offset]);
-                    });
-
-                    constexpr index_t out_offset =
-                        d_reduce_thread_desc_mperblock.CalculateOffset(make_tuple(im));
-
-                    d0_thread_buf(Number<out_offset>{}) = d0_acc;
-                    d1_thread_buf(Number<out_offset>{}) = d1_acc;
-                });
+                ThreadwiseReduce_D0::Reduce(c_reduce_thread_buf, d0_thread_buf);
+
+                static_for<0, mreduce_per_thread, 1>{}([&](auto im) {
+                    static_for<0, nreduce_per_thread, 1>{}([&](auto in) {
+                        constexpr auto offset =
+                            Number<c_reduce_thread_desc_mperblock_nperblock.CalculateOffset(
+                                make_tuple(im, in))>{};
+
+                        d1_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset));
+                    });
+                });
+
+                ThreadwiseReduce_D1::Reduce(c_reduce_thread_buf, d1_thread_buf);

                 // copy from VGPR to Global
                 d0_reduce_thread_copy_vgpr_to_global.Run(d_reduce_thread_desc_mblock_mperblock,
                                                          make_tuple(I0, I0),
...
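The rewritten reduce section swaps the hand-rolled per-row accumulators for ThreadwiseReduction: the d0/d1 buffers are initialised with the reduction identity, each thread's M-by-N fragment of C is folded along N, and d1_element_op is applied in place before the second reduction. Below is a rough standalone model of that per-thread pattern (plain host C++; AddReduce, threadwise_reduce and GetIdentityValue are illustrative stand-ins, not CK's ThreadwiseReduction API).

#include <array>
#include <cstdio>

// Reduce operation with an identity value, playing the role of D0/D1ReduceOperation.
struct AddReduce
{
    static constexpr float GetIdentityValue() { return 0.0f; }
    static void Reduce(float& acc, float v) { acc += v; }
};

// Per-thread tile reduction along the N dimension: each of the M rows owned by a
// thread is folded into one accumulator, which is what the threadwise reduction does.
template <int MPerThread, int NPerThread, typename ReduceOp>
void threadwise_reduce(const std::array<float, MPerThread * NPerThread>& tile,
                       std::array<float, MPerThread>& out)
{
    for(int m = 0; m < MPerThread; ++m)
        for(int n = 0; n < NPerThread; ++n)
            ReduceOp::Reduce(out[m], tile[m * NPerThread + n]);
}

int main()
{
    constexpr int M = 2, N = 4;
    std::array<float, M * N> c_tile{1, 2, 3, 4, 5, 6, 7, 8}; // per-thread C fragment
    std::array<float, M> d0{};
    std::array<float, M> d1{};

    // initialize both outputs with the reduction identity, as the kernel does with zeroVal
    d0.fill(AddReduce::GetIdentityValue());
    d1.fill(AddReduce::GetIdentityValue());

    // first reduction: plain sum per row
    threadwise_reduce<M, N, AddReduce>(c_tile, d0);

    // elementwise op applied in place before the second reduction (e.g. squaring),
    // mirroring d1_element_op(c_reduce_thread_buf(offset), c_reduce_thread_buf(offset))
    for(auto& v : c_tile)
        v = v * v;

    threadwise_reduce<M, N, AddReduce>(c_tile, d1);

    for(int m = 0; m < M; ++m)
        std::printf("row %d: sum=%.1f sum_sq=%.1f\n", m, d0[m], d1[m]);
}

With sum and sum-of-squares produced per row, the remaining cross-thread combination is left to the blockwise/global reduction stages that follow in the kernel.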
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp

...
@@ -4,8 +4,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
...
@@ -21,7 +21,7 @@ template <typename GridwiseGemm,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...
@@ -41,17 +41,17 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                    p_b_grid,
                                                    p_c_grid,
                                                    p_shared,
                                                    a_element_op,
                                                    b_element_op,
                                                    c_element_op,
                                                    a_grid_desc_ak0_m_ak1,
                                                    b_grid_desc_bk0_n_bk1,
                                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                    block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
...
@@ -107,7 +107,8 @@ template <typename FloatAB,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-          index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
+          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
+          LoopScheduler LoopSched>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -125,6 +126,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};

+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
...
@@ -190,10 +195,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
         const CGridDesc_M_N& c_grid_desc_m_n)
     {
-        // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
-        //                   is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
-        //               "wrong! K1 need to be known at compile-time");
-
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                       "Invalid tuning param!");
...
@@ -208,21 +209,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
             return false;

-        // check NumGemmKPrefetchStage
-        if constexpr(NumGemmKPrefetchStage == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumGemmKPrefetchStage == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K / KPerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K / KPerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
...
@@ -242,12 +232,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         return grid_size;
     }

-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    // TODO move this function into GEMM-pipeline class
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const bool has_main_k0_block_loop =
-            ((K0 * AK1) / (NumGemmKPrefetchStage * KPerBlock)) > 1;
+        const index_t num_loop = K / KPerBlock;

-        return has_main_k0_block_loop;
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }

     __host__ __device__ static constexpr auto
...
@@ -315,7 +304,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;

-    template <bool HasMainK0BlockLoop, typename Block2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
...
@@ -358,28 +347,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<AK0, MPerBlock, AK1>,
                                                 ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                 ABlockTransferThreadClusterArrangeOrder,
                                                 FloatAB, FloatAB,
                                                 decltype(a_grid_desc_ak0_m_ak1),
                                                 decltype(a_block_desc_ak0_m_ak1),
                                                 ABlockTransferSrcAccessOrder,
                                                 Sequence<1, 0, 2>,
                                                 ABlockTransferSrcVectorDim,
                                                 2,
                                                 ABlockTransferSrcScalarPerVector,
                                                 ABlockTransferDstScalarPerVector_AK1,
                                                 1, 1,
                                                 AThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
                                                 NumGemmKPrefetchStage>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
...
@@ -389,28 +378,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<BK0, NPerBlock, BK1>,
                                                 BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                 BBlockTransferThreadClusterArrangeOrder,
                                                 FloatAB, FloatAB,
                                                 decltype(b_grid_desc_bk0_n_bk1),
                                                 decltype(b_block_desc_bk0_n_bk1),
                                                 BBlockTransferSrcAccessOrder,
                                                 Sequence<1, 0, 2>,
                                                 BBlockTransferSrcVectorDim,
                                                 2,
                                                 BBlockTransferSrcScalarPerVector,
                                                 BBlockTransferDstScalarPerVector_BK1,
                                                 1, 1,
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
                                                 NumGemmKPrefetchStage>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
...
@@ -428,17 +417,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         constexpr index_t KPack = math::max(
             math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm =
-            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                FloatAB,
-                                                                FloatGemmAcc,
-                                                                decltype(a_block_desc_ak0_m_ak1),
-                                                                decltype(b_block_desc_bk0_n_bk1),
-                                                                MPerXdl, NPerXdl,
-                                                                MXdlPerWave, NXdlPerWave,
-                                                                KPack>{};
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<BlockSize,
+                                                                      FloatAB,
+                                                                      FloatGemmAcc,
+                                                                      decltype(a_block_desc_ak0_m_ak1),
+                                                                      decltype(b_block_desc_bk0_n_bk1),
+                                                                      MPerXdl, NPerXdl,
+                                                                      MXdlPerWave, NXdlPerWave,
+                                                                      KPack, LoopSched>();

         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
...
@@ -458,42 +448,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // gridwise GEMM pipeline
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_block_desc_bk0_n_bk1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumGemmKPrefetchStage,
-                                    HasMainK0BlockLoop>{};
+            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);

-        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
+        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                    a_block_desc_ak0_m_ak1,
                                    a_blockwise_copy,
                                    a_grid_buf,
                                    a_block_buf,
                                    a_block_slice_copy_step,
                                    b_grid_desc_bk0_n_bk1,
                                    b_block_desc_bk0_n_bk1,
                                    b_blockwise_copy,
                                    b_grid_buf,
                                    b_block_buf,
                                    b_block_slice_copy_step,
                                    blockwise_gemm,
                                    c_thread_buf,
                                    num_k_block_main_loop);

         // shuffle C and write out
         {
...
@@ -609,8 +584,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
              ck::tensor_operation::element_wise::PassThrough{}};

         // shuffle: blockwise copy C from LDS to global
         auto c_shuffle_block_copy_lds_to_global =
-            BlockwiseTensorSliceTransfer_v6r1<BlockSize,          // index_t BlockSize,
+            ThreadGroupTensorSliceTransfer_v6r1<ThisThreadBlock,  // ThreadGroup
                                               CElementwiseOperation,      // ElementwiseOperation,
                                               CGlobalMemoryDataOperation, // DstInMemOp,
                                               Sequence<1,
...
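Both cshuffle GEMMs now derive the main-loop flag the same way: the loop count is K / KPerBlock and the decision is delegated to the pipeline's CalculateHasMainLoop, so the host side only has to choose between two kernel instantiations. A minimal host-side sketch of that dispatch follows, assuming simplified stand-ins for the pipeline and the kernel rather than the actual CK device ops.

#include <cstdio>

using index_t = int;

// Simplified stand-in for the gridwise pipeline's host-side query.
struct GridwiseGemmPipe
{
    static constexpr bool CalculateHasMainLoop(index_t num_loop) { return num_loop > 1; }
};

// Stand-in for GridwiseGemm::CalculateHasMainKBlockLoop after the refactor:
// the loop count is K / KPerBlock and the decision is delegated to the pipeline.
template <index_t KPerBlock>
constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
    const index_t num_loop = K / KPerBlock;
    return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}

// The "kernel" is a template over HasMainKBlockLoop; the host picks one of the two
// instantiations at launch time, as the device-op invokers do with the real kernels.
template <bool HasMainKBlockLoop>
void run_gemm_kernel()
{
    std::printf("launch kernel with HasMainKBlockLoop=%d\n", HasMainKBlockLoop);
}

int main()
{
    constexpr index_t KPerBlock = 32;
    const index_t K = 256;

    if(CalculateHasMainKBlockLoop<KPerBlock>(K))
        run_gemm_kernel<true>();
    else
        run_gemm_kernel<false>();
}

Keeping the branch on the host keeps the device code free of a runtime main-loop check, while the pipeline type remains the single owner of the "how many prefetch stages per iteration" arithmetic.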
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
b134b7d6
#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
#pragma once
#define CK_GRIDWISE_GEMM_XDLOPS_V2R3_HPP
#include "common_header.hpp"
#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "
blockwise
_tensor_slice_transfer_v4r1.hpp"
#include "
thread_group
_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"
...
@@ -22,7 +20,7 @@ template <typename GridwiseGemm,
...
@@ -22,7 +20,7 @@ template <typename GridwiseGemm,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
bool
HasMainK
0
BlockLoop
>
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
@@ -42,17 +40,17 @@ __global__ void
...
@@ -42,17 +40,17 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainK
0
BlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared
,
p_shared
,
a_grid_desc_k0_m_k1
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
block_2_ctile_map
);
block_2_ctile_map
);
#else
#else
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
...
@@ -67,88 +65,6 @@ __global__ void
...
@@ -67,88 +65,6 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
GemmDesc
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
bool
HasMainK0BlockLoop
,
index_t
MaxGroupCount
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_gemm_xdlops_v2r3
(
const
StaticallyIndexedArray
<
GemmDesc
,
MaxGroupCount
>
gemm_desc_
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
block_id
=
get_block_1d_id
();
#if 1
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
block_id
>=
gemm_desc_
[
i
].
BlockStart_
&&
block_id
<
gemm_desc_
[
i
].
BlockEnd_
&&
i
<
group_count
)
{
auto
group_id
=
i
;
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
gemm_desc_
[
group_id
].
a_ptr
,
gemm_desc_
[
group_id
].
b_ptr
,
gemm_desc_
[
group_id
].
c_ptr
,
p_shared
,
gemm_desc_
[
group_id
].
a_grid_desc_k0_m_k1_
,
gemm_desc_
[
group_id
].
b_grid_desc_k0_n_k1_
,
gemm_desc_
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
a_element_op
,
b_element_op
,
c_element_op
,
gemm_desc_
[
group_id
].
grouped_gemm_block_2_ctile_map_
);
}
});
#else
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmDesc
*>
(
&
gemm_desc_
);
index_t
group_id
=
0
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
group_id
=
(
block_id
>=
gemm_desc_
[
i
].
BlockStart
&&
block_id
<
gemm_desc_
[
i
].
BlockEnd
&&
i
<
group_count
)
?
i
:
group_id
;
});
const
index_t
block_id_grp
=
block_id
-
gemm_desc_ptr
[
group_id
].
BlockStart
;
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
gemm_desc_ptr
[
group_id
].
a_ptr
,
gemm_desc_ptr
[
group_id
].
b_ptr
,
gemm_desc_ptr
[
group_id
].
c_ptr
,
p_shared
,
gemm_desc_ptr
[
group_id
].
a_grid_desc_k0_m_k1_
,
gemm_desc_ptr
[
group_id
].
b_grid_desc_k0_n_k1_
,
gemm_desc_ptr
[
group_id
].
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
a_element_op
,
b_element_op
,
c_element_op
,
gemm_desc_ptr
[
group_id
].
block_2_ctile_map_
,
block_id_grp
);
#endif
#else
ignore
=
gemm_desc_
;
ignore
=
group_count
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatAcc
,
typename
FloatAcc
,
...
@@ -187,7 +103,7 @@ template <index_t BlockSize,
...
@@ -187,7 +103,7 @@ template <index_t BlockSize,
typename
CThreadTransferSrcDstAccessOrder
,
typename
CThreadTransferSrcDstAccessOrder
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferSrcDstVectorDim
,
index_t
CThreadTransferDstScalarPerVector
,
index_t
CThreadTransferDstScalarPerVector
,
index_t
NumPrefetch
=
1
>
index_t
Num
GemmK
Prefetch
Stage
=
1
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -202,6 +118,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
...
@@ -202,6 +118,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// K1 should be Number<...>
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v1
<
NumGemmKPrefetchStage
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_K0PerBlock_MPerBlock_K1
()
{
{
constexpr
auto
max_lds_align
=
K1
;
constexpr
auto
max_lds_align
=
K1
;
...
@@ -291,21 +211,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;
 
-        // check NumPrefetch
-        if constexpr(NumPrefetch == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumPrefetch == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K0 / K0PerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K0 / K0PerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
...
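The per-variant `if constexpr` checks on the prefetch depth are replaced by asking the selected pipeline whether it can run a given number of K loops. A rough constexpr sketch of what such a supportability query could look like; the predicate shown here, requiring the loop count to be a multiple of the prefetch depth, is an illustrative assumption, not the library's actual rule:

#include <cstdint>

// Illustrative pipeline descriptor: NumPrefetchStage is a compile-time knob,
// playing the same role as NumGemmKPrefetchStage in the gridwise GEMM above.
template <int32_t NumPrefetchStage>
struct ToyGemmPipeline
{
    // A 1-stage pipeline accepts any loop count; deeper pipelines are assumed
    // (for this sketch only) to need the K-loop count to divide evenly.
    static constexpr bool IsSupported(int32_t num_k_loop)
    {
        return NumPrefetchStage == 1 ? true : (num_k_loop % NumPrefetchStage == 0);
    }
};

static_assert(ToyGemmPipeline<1>::IsSupported(7), "1-stage accepts any K-loop count");
static_assert(ToyGemmPipeline<2>::IsSupported(8), "2-stage accepts even K-loop counts");
static_assert(!ToyGemmPipeline<2>::IsSupported(9), "2-stage rejects odd K-loop counts");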
@@ -335,12 +244,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         return grid_size;
     }
 
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    // TODO move this function into GEMM-pipeline class
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
+        const index_t num_loop = K / (K0PerBlock * K1);
 
-        return has_main_k0_block_loop;
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
 
     __host__ __device__ static constexpr auto
...
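CalculateHasMainKBlockLoop now works on the full K extent and defers the decision to the pipeline. The arithmetic itself is simple: the number of K iterations is K divided by the K tile consumed per iteration (K0PerBlock * K1), and a main loop is needed when more than one iteration remains. A small sketch under that assumption about CalculateHasMainLoop's behaviour for a 1-stage pipeline:

#include <cstdint>

// K is the full reduction length; K0PerBlock * K1 is the K extent consumed per
// main-loop iteration (one block tile of A and B).
constexpr bool has_main_k_block_loop(int32_t K, int32_t K0PerBlock, int32_t K1)
{
    const int32_t num_loop = K / (K0PerBlock * K1);
    // Assumed behaviour for a 1-stage pipeline: a main loop exists only when
    // more than one iteration is left after the prologue.
    return num_loop > 1;
}

static_assert(has_main_k_block_loop(256, 4, 8) == true, "256 / 32 = 8 iterations");
static_assert(has_main_k_block_loop(32, 4, 8) == false, "single iteration, tail only");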
@@ -433,7 +341,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
 
     using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
 
-    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
...
@@ -478,28 +386,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<K0PerBlock, MPerBlock, K1>,
                                                 ABlockTransferThreadClusterLengths_K0_M_K1,
                                                 ABlockTransferThreadClusterArrangeOrder,
                                                 FloatAB,
                                                 FloatAB,
                                                 decltype(a_grid_desc_k0_m_k1),
                                                 decltype(a_block_desc_k0_m_k1),
                                                 ABlockTransferSrcAccessOrder,
                                                 Sequence<1, 0, 2>,
                                                 ABlockTransferSrcVectorDim,
                                                 2,
                                                 ABlockTransferSrcScalarPerVector,
                                                 ABlockTransferDstScalarPerVector_K1,
                                                 1,
                                                 1,
                                                 AThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                              NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 a_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
...
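The copy is now parameterized on a thread-group type (ThisThreadBlock) instead of a raw BlockSize integer, so the same transfer can in principle be instantiated for other cooperating groups. The minimum such a type needs to expose is the group size and a per-thread rank; a hypothetical sketch of that shape, with illustrative names that are not the library's thread_group.hpp:

#include <cstdint>

// Hypothetical minimal "thread group" interface a transfer could be templated on.
template <int32_t NThread>
struct ToyThreadBlock
{
    static constexpr int32_t GetNumOfThread() { return NThread; }

    // On the device this would wrap threadIdx.x; a placeholder here.
    static int32_t GetThreadId() { return 0; }
};

// A copy utility only needs the group's size at compile time to decide how a
// slice is split, and the rank at run time to pick each thread's piece.
template <typename ThreadGroup>
constexpr int32_t work_items_per_thread(int32_t total_work)
{
    return (total_work + ThreadGroup::GetNumOfThread() - 1) / ThreadGroup::GetNumOfThread();
}

static_assert(work_items_per_thread<ToyThreadBlock<256>>(1024) == 4, "4 items per thread");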
@@ -509,28 +417,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<K0PerBlock, NPerBlock, K1>,
                                                 BBlockTransferThreadClusterLengths_K0_N_K1,
                                                 BBlockTransferThreadClusterArrangeOrder,
                                                 FloatAB,
                                                 FloatAB,
                                                 decltype(b_grid_desc_k0_n_k1),
                                                 decltype(b_block_desc_k0_n_k1),
                                                 BBlockTransferSrcAccessOrder,
                                                 Sequence<1, 0, 2>,
                                                 BBlockTransferSrcVectorDim,
                                                 2,
                                                 BBlockTransferSrcScalarPerVector,
                                                 BBlockTransferDstScalarPerVector_K1,
                                                 1,
                                                 1,
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                              NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 b_grid_desc_k0_n_k1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,
...
@@ -575,41 +483,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
 
         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_blockwise_copy)>,
-                                    remove_cvref_t<decltype(a_grid_buf)>,
-                                    remove_cvref_t<decltype(a_block_buf)>,
-                                    remove_cvref_t<decltype(a_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(b_grid_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_block_desc_k0_n_k1)>,
-                                    remove_cvref_t<decltype(b_blockwise_copy)>,
-                                    remove_cvref_t<decltype(b_grid_buf)>,
-                                    remove_cvref_t<decltype(b_block_buf)>,
-                                    remove_cvref_t<decltype(b_block_slice_copy_step)>,
-                                    remove_cvref_t<decltype(blockwise_gemm)>,
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumPrefetch,
-                                    HasMainK0BlockLoop>{};
-
-        const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
-
-        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
-                                   a_block_desc_k0_m_k1,
-                                   a_blockwise_copy,
-                                   a_grid_buf,
-                                   a_block_buf,
-                                   a_block_slice_copy_step,
-                                   b_grid_desc_k0_n_k1,
-                                   b_block_desc_k0_n_k1,
-                                   b_blockwise_copy,
-                                   b_grid_buf,
-                                   b_block_buf,
-                                   b_block_slice_copy_step,
-                                   blockwise_gemm,
-                                   c_thread_buf,
-                                   K0BlockMainLoop);
+        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
+
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+                                                          a_block_desc_k0_m_k1,
+                                                          a_blockwise_copy,
+                                                          a_grid_buf,
+                                                          a_block_buf,
+                                                          a_block_slice_copy_step,
+                                                          b_grid_desc_k0_n_k1,
+                                                          b_block_desc_k0_n_k1,
+                                                          b_blockwise_copy,
+                                                          b_grid_buf,
+                                                          b_block_buf,
+                                                          b_block_slice_copy_step,
+                                                          blockwise_gemm,
+                                                          c_thread_buf,
+                                                          num_k_block_main_loop);
 
         // output: register to global memory
         {
...
@@ -692,4 +582,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
 };
 
 } // namespace ck
-#endif
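With the pipeline hoisted into GridwiseGemmPipe, the gridwise kernel only hands over descriptors, copies, buffers and the loop count; the prefetch/compute interleaving lives in one place. A highly simplified, host-side sketch of the control flow a single-stage pipeline of this kind typically follows, assumed structure only and not the library's implementation:

#include <cstdint>
#include <functional>

// Sketch of a single-stage "load ahead, then compute" loop. The callables stand
// in for the global-to-LDS copies and the blockwise GEMM used above.
inline void run_pipeline_1stage(int32_t num_loop,
                                const std::function<void()>& load_a_and_b_tile,
                                const std::function<void()>& compute_tile)
{
    // Prologue: preload the first tile so compute always has data ready.
    load_a_and_b_tile();

    for(int32_t i = 0; i + 1 < num_loop; ++i)
    {
        // In the real kernel the next load overlaps the current compute;
        // it is serialized here only to keep the sketch readable.
        compute_tile();
        load_a_and_b_tile();
    }

    // Epilogue: consume the last tile that was already loaded.
    compute_tile();
}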
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
...
@@ -6,7 +6,7 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 
 namespace ck {
...
@@ -120,6 +120,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};
 
+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         constexpr auto max_lds_align = K1;
...
@@ -420,27 +422,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         }();
 
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<1, K0PerBlock, MPerBlock, K1>,
                                                 ABlockTransferThreadClusterLengths_K0_M_K1,
                                                 ABlockTransferThreadClusterArrangeOrder,
                                                 FloatAB,
                                                 FloatAB,
                                                 decltype(a_b_k0_m_k1_grid_desc),
                                                 decltype(a_b_k0_m_k1_block_desc),
                                                 ABlockTransferSrcAccessOrder,
                                                 Sequence<0, 2, 1, 3>,
                                                 ABlockTransferSrcVectorDim,
                                                 3,
                                                 ABlockTransferSrcScalarPerVector,
                                                 ABlockTransferDstScalarPerVector_K1,
                                                 1,
                                                 1,
                                                 AThreadTransferSrcResetCoordinateAfterRun,
                                                 true>(
                 a_b_k0_m_k1_grid_desc,
                 make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
                 a_element_op,
...
@@ -450,27 +452,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<1, K0PerBlock, NPerBlock, K1>,
                                                 BBlockTransferThreadClusterLengths_K0_N_K1,
                                                 ...
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true>(
                 b_b_k0_n_k1_grid_desc,
                 make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
                 b_element_op,
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...
@@ -6,8 +6,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 
 namespace ck {
...
@@ -123,6 +123,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};
 
+    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
+
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         constexpr auto max_lds_align = K1;
...
@@ -409,27 +411,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         }();
 
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<1, K0PerBlock, MPerBlock, K1>,
                                                 ABlockTransferThreadClusterLengths_K0_M_K1,
                                                 ...
                                                 AThreadTransferSrcResetCoordinateAfterRun,
                                                 true>(
                 a_b_k0_m_k1_grid_desc,
                 make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0),
                 a_element_op,
...
@@ -439,27 +441,27 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<1, K0PerBlock, NPerBlock, K1>,
                                                 BBlockTransferThreadClusterLengths_K0_N_K1,
                                                 ...
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true>(
                 b_b_k0_n_k1_grid_desc,
                 make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0),
                 b_element_op,
...
@@ -660,8 +662,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
             ck::tensor_operation::element_wise::PassThrough{}};
 
         // LDS to global
         auto c_block_copy_lds_to_global =
-            BlockwiseTensorSliceTransfer_v6r1<BlockSize,              // index_t BlockSize,
+            ThreadGroupTensorSliceTransfer_v6r1<ThisThreadBlock,      // index_t BlockSize,
                                                 CElementwiseOperation,      // ElementwiseOperation,
                                                 CGlobalMemoryDataOperation, // DstInMemOp,
                                                 Sequence<1,
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
-#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R1_HPP
-#define CK_GRIDWISE_GEMM_XDLOPS_V3R1_HPP
+#pragma once
 
 #include "common_header.hpp"
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r1.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r1.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
 #include "tensor_space_filling_curve.hpp"
...
@@ -113,7 +111,7 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumPrefetch = 1>
+          index_t NumGemmKPrefetchStage = 1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -131,6 +129,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};
 
+    using ThisThreadBlock  = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         constexpr auto max_lds_align = AK1;
...
@@ -246,21 +248,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
             return false;
 
-        // check NumPrefetch
-        if constexpr(NumPrefetch == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumPrefetch == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K / KPerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K / KPerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
...
@@ -290,12 +281,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         return grid_size;
     }
 
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    // TODO move this function into GEMM-pipeline class
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const bool has_main_k0_block_loop = ((K0 * AK1) / (NumPrefetch * KPerBlock)) > 1;
+        const index_t num_loop = K / KPerBlock;
 
-        return has_main_k0_block_loop;
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
 
     __host__ __device__ static constexpr auto
...
@@ -413,28 +403,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<AK0, MPerBlock, AK1>,
                                                 ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                 ...
                                                 AThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                              NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 a_grid_desc_ak0_m_ak1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
...
@@ -444,28 +434,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<BK0, NPerBlock, BK1>,
                                                 BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                 ...
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true,
-                                              NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 b_grid_desc_bk0_n_bk1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,
...
@@ -512,43 +502,25 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
 
         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_ak0_m_ak1)>,
-                                    remove_cvref_t<decltype(a_block_desc_ak0_m_ak1)>,
                                     ...
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumPrefetch,
-                                    HasMainK0BlockLoop>{};
-
         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);
 
-        gridwise_gemm_pipeline.Run(a_grid_desc_ak0_m_ak1,
-                                   a_block_desc_ak0_m_ak1,
-                                   ...
-                                   c_thread_buf,
-                                   num_k_block_main_loop);
+        GridwiseGemmPipe::template Run<HasMainK0BlockLoop>(a_grid_desc_ak0_m_ak1,
+                                                           a_block_desc_ak0_m_ak1,
+                                                           a_blockwise_copy,
+                                                           a_grid_buf,
+                                                           a_block_buf,
+                                                           a_block_slice_copy_step,
+                                                           b_grid_desc_bk0_n_bk1,
+                                                           b_block_desc_bk0_n_bk1,
+                                                           b_blockwise_copy,
+                                                           b_grid_buf,
+                                                           b_block_buf,
+                                                           b_block_slice_copy_step,
+                                                           blockwise_gemm,
+                                                           c_thread_buf,
+                                                           num_k_block_main_loop);
 
         // shuffle C and write out
         {
...
@@ -672,8 +644,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
             ck::tensor_operation::element_wise::PassThrough{}};
 
         // LDS to global
         auto c_block_copy_lds_to_global =
-            BlockwiseTensorSliceTransfer_v6r1<BlockSize,              // index_t BlockSize,
+            ThreadGroupTensorSliceTransfer_v6r1<ThisThreadBlock,      // ThreadGroup
                                                 CElementwiseOperation,      // ElementwiseOperation,
                                                 CGlobalMemoryDataOperation, // DstInMemOp,
                                                 Sequence<1,
...
@@ -774,4 +746,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
 };
 
 } // namespace ck
-#endif
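In v3r1 the main-loop count is derived from the A descriptor as (AK0 * AK1) / KPerBlock and wrapped in __builtin_amdgcn_readfirstlane, which broadcasts lane 0's value so the compiler can treat the count as wave-uniform (kept in a scalar register). A tiny sketch of the arithmetic only; the builtin itself is device-side and omitted here:

#include <cstdint>

// a_k0 and a_k1 are the K0 and K1 extents of the A grid descriptor
// (GetLength(I0) and GetLength(I2) above); k_per_block is the K tile per iteration.
constexpr int32_t num_k_block_main_loop(int32_t a_k0, int32_t a_k1, int32_t k_per_block)
{
    return (a_k0 * a_k1) / k_per_block; // total K divided by K consumed per loop
}

static_assert(num_k_block_main_loop(64, 8, 32) == 16, "K = 512, 32 per iteration");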
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
...
@@ -6,8 +6,8 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r2.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r2.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
...
@@ -24,7 +24,7 @@ template <typename GridwiseGemm,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...
@@ -48,7 +48,7 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                    p_b_grid,
                                                    p_c_grid,
...
@@ -119,7 +119,7 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumPrefetch = 1>
+          index_t NumGemmKPrefetchStage = 1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -134,6 +134,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};
 
+    using ThisThreadBlock  = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;
...
@@ -252,21 +256,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;
 
-        // check NumPrefetch
-        if constexpr(NumPrefetch == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumPrefetch == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K0 / K0PerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K0 / K0PerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
...
@@ -296,12 +289,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         return grid_size;
     }
 
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    // TODO move this function into GEMM-pipeline class
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
-        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
+        const index_t num_loop = K / (K0PerBlock * K1);
 
-        return has_main_k0_block_loop;
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
 
     template <typename CGridDesc_M_N_>
...
@@ -379,7 +371,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
 
-    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
...
@@ -434,28 +426,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<K0PerBlock, MPerBlock, K1>,
                                                 ABlockTransferThreadClusterLengths_K0_M_K1,
                                                 ...
                                                 true,
-                                              NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 a_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
...
@@ -465,28 +457,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<K0PerBlock, NPerBlock, K1>,
                                                 BBlockTransferThreadClusterLengths_K0_N_K1,
                                                 ...
                                                 true,
-                                              NumPrefetch>(
+                                                NumGemmKPrefetchStage>(
                 b_grid_desc_k0_n_k1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,
...
@@ -531,41 +523,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
 
         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
                                     ...
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumPrefetch,
-                                    HasMainK0BlockLoop>{};
-
         const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
 
-        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
-                                   a_block_desc_k0_m_k1,
-                                   ...
-                                   c_thread_buf,
-                                   K0BlockMainLoop);
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+                                                          a_block_desc_k0_m_k1,
+                                                          a_blockwise_copy,
+                                                          a_grid_buf,
+                                                          a_block_buf,
+                                                          a_block_slice_copy_step,
+                                                          b_grid_desc_k0_n_k1,
+                                                          b_block_desc_k0_n_k1,
+                                                          b_blockwise_copy,
+                                                          b_grid_buf,
+                                                          b_block_buf,
+                                                          b_block_slice_copy_step,
+                                                          blockwise_gemm,
+                                                          c_thread_buf,
+                                                          K0BlockMainLoop);
 
         // shuffle C and write out
         {
...
@@ -690,8 +664,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
                 n_thread_data_on_block_idx[I2]),
             ck::tensor_operation::element_wise::PassThrough{}};
 
         auto c_block_copy_lds_to_global =
-            BlockwiseTensorSliceTransfer_v6r2<BlockSize,              // index_t BlockSize,
+            ThreadGroupTensorSliceTransfer_v6r2<ThisThreadBlock,      // index_t BlockSize,
                                                 CElementwiseOperation,      // ElementwiseOperation,
                                                 CGlobalMemoryDataOperation, // DstInMemOp,
                                                 Sequence<1,
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp
-#ifndef CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP
-#define CK_GRIDWISE_GEMM_XDLOPS_V3R3_HPP
+#pragma once
 
 #include "common_header.hpp"
 #include "multi_index_transform_helper.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "blockwise_gemm_xdlops.hpp"
-#include "blockwise_tensor_slice_transfer_v4r1.hpp"
-#include "blockwise_tensor_slice_transfer_v6r3.hpp"
+#include "thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "thread_group_tensor_slice_transfer_v6r3.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "gridwise_gemm_pipeline_v1.hpp"
...
@@ -25,7 +23,7 @@ template <typename GridwiseGemm,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
           typename Block2CTileMap,
-          bool HasMainK0BlockLoop>
+          bool HasMainKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
...
@@ -52,7 +50,7 @@ __global__ void
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
-    GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid,
+    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                    p_b_grid,
                                                    p_c_grid,
...
@@ -128,7 +126,7 @@ template <
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
           index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
-          index_t NumPrefetch = 1>
+          index_t NumGemmKPrefetchStage = 1>
 struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 {
     static constexpr auto I0 = Number<0>{};
...
@@ -143,6 +141,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     // K1 should be Number<...>
     static constexpr auto K1 = Number<K1Value>{};
 
+    using ThisThreadBlock  = ThisThreadBlock<BlockSize>;
+
+    using GridwiseGemmPipe = GridwiseGemmPipeline_v1<NumGemmKPrefetchStage>;
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;
...
@@ -261,21 +263,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
             return false;
 
-        // check NumPrefetch
-        if constexpr(NumPrefetch == 1)
-        {
-            // 1-stage prefetch always supported
-        }
-        else if constexpr(NumPrefetch == 2)
-        {
-            // 2-stage prefetch currently only support even number of K0 loop
-            // TODO: add support for odd number of K0 loop
-            if(!((K0 / K0PerBlock) % 2 == 0))
-            {
-                return false;
-            }
-        }
-        else
+        // check gridwise gemm pipeline
+        const auto num_k_loop = K0 / K0PerBlock;
+
+        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
         {
             return false;
         }
...
@@ -305,12 +296,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         return grid_size;
     }
 
-    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
+    // TODO move this function into GEMM-pipeline class
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
-        const bool has_main_k0_block_loop = (K0 / (NumPrefetch * K0PerBlock)) > 1;
+        const index_t num_loop = K / (K0PerBlock * K1);
 
-        return has_main_k0_block_loop;
+        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
     }
 
     template <typename CGridDesc_M_N_>
...
@@ -393,7 +383,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
     using DefaultBlock2CTileMap =
         remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
 
-    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
+    template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void
     Run(const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
...
@@ -455,27 +445,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         // A matrix blockwise copy
         auto a_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 AElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<K0PerBlock, MPerBlock, K1>,
                                                 ABlockTransferThreadClusterLengths_K0_M_K1,
                                                 ...
                                                 AThreadTransferSrcResetCoordinateAfterRun,
                                                 true>(
                 a_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
...
@@ -485,27 +475,27 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         // B matrix blockwise copy
         auto b_blockwise_copy =
-            BlockwiseTensorSliceTransfer_v4r1<BlockSize,
+            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                 BElementwiseOperation,
                                                 ck::tensor_operation::element_wise::PassThrough,
                                                 InMemoryDataOperationEnum::Set,
                                                 Sequence<K0PerBlock, NPerBlock, K1>,
                                                 BBlockTransferThreadClusterLengths_K0_N_K1,
                                                 ...
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true>(
                 b_grid_desc_k0_n_k1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,
...
@@ -550,41 +540,23 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
 
         // gridwise GEMM pipeline
-        const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1<remove_cvref_t<decltype(a_grid_desc_k0_m_k1)>,
-                                    remove_cvref_t<decltype(a_block_desc_k0_m_k1)>,
                                     ...
-                                    remove_cvref_t<decltype(c_thread_buf)>,
-                                    NumPrefetch,
-                                    HasMainK0BlockLoop>{};
-
         const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
 
-        gridwise_gemm_pipeline.Run(a_grid_desc_k0_m_k1,
-                                   a_block_desc_k0_m_k1,
-                                   ...
-                                   c_thread_buf,
-                                   K0BlockMainLoop);
+        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+                                                          a_block_desc_k0_m_k1,
+                                                          a_blockwise_copy,
+                                                          a_grid_buf,
+                                                          a_block_buf,
+                                                          a_block_slice_copy_step,
+                                                          b_grid_desc_k0_n_k1,
+                                                          b_block_desc_k0_n_k1,
+                                                          b_blockwise_copy,
+                                                          b_grid_buf,
+                                                          b_block_buf,
+                                                          b_block_slice_copy_step,
+                                                          blockwise_gemm,
+                                                          c_thread_buf,
+                                                          K0BlockMainLoop);
 
         // shuffle C and write out
         {
...
@@ -623,17 +595,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
             constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                 c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
                 make_tuple(
                     make_freeze_transform(I0), // freeze mblock
                     make_pass_through_transform(
-                        Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per
-                                                                  // shuffle
+                        Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
                     make_unmerge_transform(
                         make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
                     make_freeze_transform(I0),       // freeze nblock
                     make_pass_through_transform(
-                        Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
+                        Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per
+                                                                  // shuffle
                     make_unmerge_transform(
                         make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl
                 make_tuple(Sequence<0>{},
                            Sequence<1>{},
                            Sequence<2>{},
...
@@ -709,8 +682,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
                 n_thread_data_on_block_idx[I2]),
             ck::tensor_operation::element_wise::PassThrough{}};
 
         auto c_block_copy_lds_to_global =
-            BlockwiseTensorSliceTransfer_v6r3<BlockSize,              // index_t BlockSize,
+            ThreadGroupTensorSliceTransfer_v6r3<ThisThreadBlock,      // ThreadGroup
                                                 CElementwiseOperation,      // ElementwiseOperation,
                                                 CGlobalMemoryDataOperation, // DstInMemOp,
                                                 Sequence<1,
...
@@ -851,4 +824,3 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
 };
 
 } // namespace ck
-#endif
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...
@@ -51,7 +51,7 @@ template <typename SrcData,
           typename DstData,
           typename SrcDesc,
           typename DstDesc,
-          typename DstElementwiseOperation,
+          typename ElementwiseOperation,
           typename SliceLengths,
           typename DimAccessOrder,
           index_t DstVectorDim,
...
@@ -70,12 +70,11 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
 
-    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(
-        const DstDesc& dst_desc,
-        const Index& dst_slice_origin_idx,
-        const DstElementwiseOperation& dst_element_op)
+    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
+                                                            const Index& dst_slice_origin_idx,
+                                                            const ElementwiseOperation& element_op)
         : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
-          dst_element_op_{dst_element_op}
+          element_op_{element_op}
     {
         static_assert(SrcDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc need to known at compile-time");
...
@@ -136,13 +135,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                 constexpr index_t src_offset = src_desc.CalculateOffset(
                     src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
 
-                SrcData dst_v;
+                SrcData v;
 
                 // apply element-wise operation
-                dst_element_op_(dst_v, src_buf[Number<src_offset>{}]);
+                element_op_(v, src_buf[Number<src_offset>{}]);
 
                 // apply type convert
-                dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(dst_v);
+                dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
             });
 
             const bool is_dst_valid =
...
@@ -213,7 +212,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
     private:
     DstCoord dst_coord_;
-    const DstElementwiseOperation dst_element_op_;
+    const ElementwiseOperation element_op_;
 }; // namespace ThreadwiseTensorSliceTransfer_v1r3
 
 // Assume:
...
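The rewritten store path keeps the element-wise operation in the source data type and only converts to the destination type at the end, instead of letting the operator write directly into the destination element. A small standalone sketch of the same two-step pattern, with ck's type_convert approximated by a static_cast:

#include <cstdint>
#include <vector>

// Two-step pattern used above: apply the element-wise op in SrcData precision,
// then convert once to DstData when writing out.
template <typename SrcData, typename DstData, typename ElementwiseOp>
void copy_with_op(const std::vector<SrcData>& src, std::vector<DstData>& dst, ElementwiseOp op)
{
    for(std::size_t i = 0; i < src.size(); ++i)
    {
        SrcData v;
        op(v, src[i]);                    // element-wise op, still in SrcData
        dst[i] = static_cast<DstData>(v); // explicit convert (ck uses type_convert)
    }
}

// Example: scale in float, store as int16_t.
// std::vector<float> src{1.5f, 2.5f};
// std::vector<int16_t> dst(2);
// copy_with_op(src, dst, [](float& y, const float& x) { y = 2.0f * x; });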
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp
View file @ b134b7d6

@@ -102,8 +102,13 @@ struct ThreadwiseTensorSliceTransfer_v6r1
             // apply pointwise operation
             static_for<0, ScalarPerVector, 1>{}([&](auto i) {
-                element_op_(dst_vector_container.template AsType<DstData>()(i),
-                            src_vector_container.template AsType<SrcData>()[i]);
+                SrcData v;
+
+                // apply element-wise operation
+                element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
+
+                // apply type convert
+                dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
             });

             const bool is_dst_valid =
...
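Both transfer classes now follow the same two-step pattern: run the element-wise operation in the source data type first, then type-convert into the destination type, instead of handing the operation a destination-typed reference. A standalone sketch of that ordering, with a hypothetical scaling functor standing in for CK's element-wise operators (the functor, types, and values below are illustrative, not part of the commit):

    #include <cstdint>
    #include <cstdio>

    // Hypothetical element-wise op with the same call shape as CK's functors: op(y, x).
    struct Scale
    {
        float scale_;
        void operator()(float& y, const float& x) const { y = scale_ * x; }
    };

    // Mirrors the new transfer code path: op in SrcData precision, then convert to DstData.
    template <typename DstData, typename SrcData, typename ElementwiseOp>
    DstData apply_then_convert(const SrcData& x, const ElementwiseOp& op)
    {
        SrcData v;
        op(v, x);                       // apply element-wise operation
        return static_cast<DstData>(v); // apply type convert (type_convert in CK)
    }

    int main()
    {
        const float x       = 3.25f;
        const std::int8_t y = apply_then_convert<std::int8_t>(x, Scale{2.0f});
        std::printf("%d\n", static_cast<int>(y)); // prints 6: scaled in float, then narrowed
    }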
include/ck/utility/amd_xdlops.hpp
View file @ b134b7d6

@@ -266,8 +266,8 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
     __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
     {
         reg_c.template AsType<int32x16_t>()(Number<0>{}) =
-            __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a),
-                                                bit_cast<int>(reg_b),
+            __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
+                                                bit_cast<int32_t>(reg_b),
                                                 reg_c.template AsType<int32x16_t>()[Number<0>{}],
                                                 0,
                                                 0,
@@ -285,8 +285,8 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
     __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
     {
         reg_c.template AsType<int32x4_t>()(Number<0>{}) =
-            __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a),
-                                                 bit_cast<int>(reg_b),
+            __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
+                                                 bit_cast<int32_t>(reg_b),
                                                  reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                  0,
                                                  0,
...
include/ck/utility/common_header.hpp
View file @ b134b7d6

@@ -28,6 +28,7 @@
 #include "transpose_vectors.hpp"
 #include "inner_product.hpp"
 // #include "element_wise_operation.hpp"
+#include "thread_group.hpp"
 #include "debug.hpp"
 #include "amd_buffer_addressing.hpp"
...
include/ck/utility/get_id.hpp
View file @ b134b7d6

@@ -3,11 +3,15 @@
 namespace ck {

-__device__ constexpr index_t get_wave_size() { return CK_GPU_WAVE_SIZE; }
+__host__ __device__ constexpr index_t get_warp_size()
+{
+    // warpSize is defined by HIP
+    return warpSize;
+}

 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

-__device__ index_t get_wave_local_1d_id() { return threadIdx.x / get_wave_size(); }
+__device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); }

 __device__ index_t get_block_1d_id() { return blockIdx.x; }
...
include/ck/utility/number.hpp
View file @ b134b7d6

@@ -8,5 +8,8 @@ namespace ck {
 template <index_t N>
 using Number = integral_constant<index_t, N>;

+template <index_t N>
+using LongNumber = integral_constant<long_index_t, N>;
+
 } // namespace ck

 #endif
include/ck/utility/static_buffer.hpp
View file @ b134b7d6

@@ -158,5 +158,11 @@ __host__ __device__ constexpr auto make_static_buffer(Number<N>)
     return StaticBuffer<AddressSpace, T, N, true>{};
 }

+template <AddressSpaceEnum AddressSpace, typename T, long_index_t N>
+__host__ __device__ constexpr auto make_static_buffer(LongNumber<N>)
+{
+    return StaticBuffer<AddressSpace, T, N, true>{};
+}
+
 } // namespace ck

 #endif
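The new LongNumber alias in number.hpp pairs with this make_static_buffer overload so that buffer sizes expressed as long_index_t compile-time constants can be passed directly. A minimal sketch of the call shapes, assuming CK's utility headers are on the include path (the address space, element type, and sizes below are illustrative):

    // Sketch only; assumes CK's utility headers are available as in common_header.hpp.
    #include "number.hpp"
    #include "static_buffer.hpp"

    namespace ck {

    __device__ void static_buffer_example()
    {
        // existing overload: size given as Number<> (index_t)
        auto a = make_static_buffer<AddressSpaceEnum::Vgpr, float>(Number<8>{});

        // new overload: size given as LongNumber<> (long_index_t)
        auto b = make_static_buffer<AddressSpaceEnum::Vgpr, float>(LongNumber<8>{});

        (void)a;
        (void)b;
    }

    } // namespace ck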
include/ck/utility/thread_group.hpp
0 → 100644
View file @ b134b7d6

#pragma once

#include "get_id.hpp"

namespace ck {

template <index_t ThreadPerBlock>
struct ThisThreadBlock
{
    static constexpr index_t kNumThread_ = ThreadPerBlock;

    __device__ static constexpr index_t GetNumOfThread() { return kNumThread_; }

    __device__ static constexpr bool IsBelong() { return true; }

    __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); }
};

} // namespace ck
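ThisThreadBlock is the ThreadGroup type that the renamed ThreadGroupTensorSliceTransfer_v6r3 (see the gridwise GEMM change above) is parameterized with, replacing a raw BlockSize template argument. A minimal sketch of how a gridwise kernel is expected to wire it in, assuming a 256-thread block (everything here other than ThisThreadBlock itself is illustrative):

    // Sketch only; BlockSize and the surrounding function are illustrative.
    #include "thread_group.hpp"

    namespace ck {

    static constexpr index_t BlockSize = 256;

    // A gridwise kernel aliases the thread group of its launching block...
    using ThisThreadBlock_t = ThisThreadBlock<BlockSize>;

    __device__ void thread_group_example()
    {
        // ...and ThreadGroup-based transfers query the group instead of a raw BlockSize:
        const index_t num_threads = ThisThreadBlock_t::GetNumOfThread(); // == 256
        const index_t tid         = ThisThreadBlock_t::GetThreadId();    // threadIdx.x

        (void)num_threads;
        (void)tid;
    }

    } // namespace ck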
include/ck/utility/tuple.hpp
View file @ b134b7d6

@@ -21,9 +21,9 @@ struct TupleElement
 {
     __host__ __device__ constexpr TupleElement() = default;

     template <typename T,
-              typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
-                                 bool>::type = false>
+              typename enable_if<!is_same<remove_cvref_t<T>, TupleElement>::value, bool>::type =
+                  false>
     __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
     {
     }
@@ -60,7 +60,7 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
     template <typename Y,
               typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
-                                     !is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
+                                     !is_same<remove_cvref_t<Y>, TupleImpl>::value,
                                  bool>::type = false>
     __host__ __device__ constexpr TupleImpl(Y&& y)
         : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
@@ -101,8 +101,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
     __host__ __device__ constexpr Tuple() = default;

     template <typename Y,
-              typename enable_if<sizeof...(Xs) == 1 &&
-                                     !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
+              typename enable_if<sizeof...(Xs) == 1 && !is_same<remove_cvref_t<Y>, Tuple>::value,
                                  bool>::type = false>
     __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
     {
...
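The tuple changes swap the spelled-out remove_reference_t<remove_cv_t<T>> for the remove_cvref_t alias, which strips the reference first and then the cv-qualifiers, so a deduced const Tuple& collapses all the way to Tuple and the converting constructor is excluded as intended. A small standalone check of that behaviour, written against the standard library rather than CK's own trait (the alias below is defined locally for illustration):

    #include <type_traits>

    // Same transformation as ck::remove_cvref_t, shown with std traits for illustration.
    template <typename T>
    using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

    static_assert(std::is_same<remove_cvref_t<const int&>, int>::value, "cv and ref stripped");
    static_assert(std::is_same<remove_cvref_t<int&&>, int>::value, "rvalue ref stripped");

    // The older spelling removes cv before the reference, which can leave const behind:
    static_assert(
        std::is_same<std::remove_reference_t<std::remove_cv_t<const int&>>, const int>::value,
        "cv removal on a reference type is a no-op");

    int main() { return 0; }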
library/include/ck/library/host/host_interface.hpp
0 → 100644
View file @ b134b7d6

#pragma once

#include <memory>
#include <string>
#include "stream_config.hpp"
#include "config.hpp"
#include "device_base.hpp"

struct DeviceConvFwdPtr_t
{
    using BaseArgument = ck::tensor_operation::device::BaseArgument;
    using BaseInvoker  = ck::tensor_operation::device::BaseInvoker;

    struct DeviceConvFwdPtrImpl;
    std::unique_ptr<DeviceConvFwdPtrImpl> pImpl;

    DeviceConvFwdPtr_t();
    ~DeviceConvFwdPtr_t();
    DeviceConvFwdPtr_t(DeviceConvFwdPtr_t&&);
    DeviceConvFwdPtr_t(DeviceConvFwdPtrImpl&);
    DeviceConvFwdPtr_t& operator=(DeviceConvFwdPtr_t&)       = delete;
    DeviceConvFwdPtr_t& operator=(const DeviceConvFwdPtr_t&) = delete;

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(void* in_ptr,
                        void* wei_ptr,
                        void* out_ptr,
                        size_t N,
                        size_t K,
                        size_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads) const;
    // in,wei and out element ops are ignored for now since even if we change them, they
    // cant be linked

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() const; // requires including BaseInvoker headers

    std::string GetTypeString();

    bool IsSupportedArgument(const BaseArgument* arg_ptr);
};

void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances_t(std::vector<DeviceConvFwdPtr_t>& instances);
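A minimal host-side usage sketch of this interface, assuming an NHWC conv2d problem with pre-allocated device buffers (the problem sizes, buffer pointers, and chosen instance list below are illustrative, and error handling is omitted; the invoker Run call follows the BaseInvoker interface declared in device_base.hpp):

    // Sketch only; sizes, pointers and instance selection are illustrative.
    #include <vector>
    #include "host_interface.hpp"

    void run_first_supported_conv(void* in_dev, void* wei_dev, void* out_dev)
    {
        std::vector<DeviceConvFwdPtr_t> instances;
        add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances_t(instances);

        // Illustrative problem: N=4, K=64, C=32, 28x28 input, 3x3 filter, stride 1, pad 1.
        const size_t N = 4, K = 64, C = 32;
        std::vector<ck::index_t> in_lengths{28, 28}, filt_lengths{3, 3}, out_lengths{28, 28};
        std::vector<ck::index_t> strides{1, 1}, dilations{1, 1}, left_pads{1, 1}, right_pads{1, 1};

        for(auto& inst : instances)
        {
            auto arg = inst.MakeArgumentPointer(in_dev, wei_dev, out_dev,
                                                N, K, C,
                                                in_lengths, filt_lengths, out_lengths,
                                                strides, dilations, left_pads, right_pads);

            if(inst.IsSupportedArgument(arg.get()))
            {
                auto invoker = inst.MakeInvokerPointer();
                invoker->Run(arg.get()); // BaseInvoker::Run, see device_base.hpp
                break;
            }
        }
    }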
library/include/ck/library/host_tensor/device.hpp
View file @ b134b7d6

-#ifndef DEVICE_HPP
-#define DEVICE_HPP
+#pragma once

 #include <memory>
 #include <functional>
 #include <thread>
 #include <chrono>
-#include "hip/hip_runtime.h"
-#include "hip/hip_fp16.h"
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
+#include "stream_config.hpp"
+#include "ck/options.hpp"
+
+inline void hip_check_error(hipError_t x)
+{
+    if(x != hipSuccess)
+    {
+        std::ostringstream ss;
+        ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": " << __LINE__
+           << "in function: " << __func__;
+        throw std::runtime_error(ss.str());
+    }
+}

 struct DeviceMem
 {
...
@@ -68,47 +81,60 @@ struct WallTimer
 using device_stream_t = hipStream_t;

 template <typename... Args, typename F>
-void launch_kernel(F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
-{
-    hipStream_t stream_id = nullptr;
-
-    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
-}
-
-template <typename... Args, typename F>
-float launch_and_time_kernel(
-    F kernel, int nrepeat, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
-{
-    KernelTimer timer;
-
-    printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
-           __func__,
-           grid_dim.x,
-           grid_dim.y,
-           grid_dim.z,
-           block_dim.x,
-           block_dim.y,
-           block_dim.z);
-
-    printf("Warm up \n");
-
-    hipStream_t stream_id = nullptr;
-
-    // warm up
-    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
-
-    printf("Start running %d times...\n", nrepeat);
-
-    timer.Start();
-
-    for(int i = 0; i < nrepeat; ++i)
-    {
-        hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_id, args...);
-    }
-
-    timer.End();
-
-    return timer.GetElapsedTime() / nrepeat;
-}
+float launch_and_time_kernel(const StreamConfig& stream_config,
+                             F kernel,
+                             dim3 grid_dim,
+                             dim3 block_dim,
+                             std::size_t lds_byte,
+                             Args... args)
+{
+#if CK_TIME_KERNEL
+    if(stream_config.time_kernel_)
+    {
+        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
+               __func__,
+               grid_dim.x,
+               grid_dim.y,
+               grid_dim.z,
+               block_dim.x,
+               block_dim.y,
+               block_dim.z);
+
+        const int nrepeat = 10;
+
+        printf("Warm up 1 time \n");
+
+        // warm up
+        hipLaunchKernelGGL(
+            kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
+
+        printf("Start running %d times...\n", nrepeat);
+
+        KernelTimer timer;
+        timer.Start();
+
+        for(int i = 0; i < nrepeat; ++i)
+        {
+            hipLaunchKernelGGL(
+                kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
+        }
+
+        timer.End();
+
+        return timer.GetElapsedTime() / nrepeat;
+    }
+    else
+    {
+        hipLaunchKernelGGL(
+            kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);
+
+        return 0;
+    }
+#else
+    hipLaunchKernelGGL(kernel, grid_dim, block_dim, lds_byte, stream_config.stream_id_, args...);

+    return 0;
+#endif
+}

 template <typename... Args, typename F>
...
@@ -137,4 +163,3 @@ float launch_and_time_cpu_kernel(F kernel, int nrepeat, Args... args)
     return timer.GetElapsedTime() / nrepeat;
 }
-#endif
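With this change, timing is opted into through the StreamConfig argument rather than a separate launch_kernel / launch_and_time_kernel pair. A short sketch of how a caller drives the new signature, assuming a trivial fill kernel and launch shape (all of which are illustrative):

    // Sketch only; the kernel and launch configuration are illustrative.
    #include <hip/hip_runtime.h>
    #include "device.hpp"

    __global__ void fill_kernel(float* p, int n, float value)
    {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
        {
            p[i] = value;
        }
    }

    void fill(float* p_dev, int n)
    {
        StreamConfig stream_config{};
        stream_config.time_kernel_ = true; // warm-up + repeated launches under CK_TIME_KERNEL

        const float avg_ms = launch_and_time_kernel(stream_config,
                                                    fill_kernel,
                                                    dim3((n + 255) / 256),
                                                    dim3(256),
                                                    0,
                                                    p_dev,
                                                    n,
                                                    1.0f);
        (void)avg_ms;
    }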