Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
cbcc844e
Commit
cbcc844e
authored
Feb 08, 2024
by
illsilin
Browse files
merge from public repo
parents
29deceb6
1f306024
Changes
393
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2333 additions
and
47 deletions
+2333
-47
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp
.../grid/gridwise_elementwise_layernorm_welford_variance.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
...pu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp
+1
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+4
-7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
...gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
+23
-14
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp
...ration/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp
+142
-5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
...eration/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+2
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
+1153
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
...tion/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
...or_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
...operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp
.../gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp
+962
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
...ensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+2
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
...tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
+31
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
+1
-1
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp
View file @
cbcc844e
...
...
@@ -119,7 +119,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
const
InDataTypePointerTuple
p_in_global_tuple
,
XDataType
*
const
__restrict__
p_x_lds
,
XDataType
*
const
__restrict__
p_x_lds
_
,
const
GammaDataType
*
const
__restrict__
p_gamma_global
,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
,
...
...
@@ -149,7 +149,7 @@ struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
p_y_global
,
y_grid_desc_m_k
.
GetElementSpaceSize
());
auto
x_lds_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_x_lds
,
x_grid_desc_m_k
.
GetElementSpaceSize
()
/
grid_size
);
p_x_lds
_
,
x_grid_desc_m_k
.
GetElementSpaceSize
()
/
grid_size
);
auto
in_thread_buf_tuple
=
generate_tuple
(
[
&
](
auto
)
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
View file @
cbcc844e
...
...
@@ -67,7 +67,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp
View file @
cbcc844e
...
...
@@ -28,8 +28,7 @@ __global__ void
#endif
kernel_gemm_dpp
(
const
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \
defined(__gfx1101__) || defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
auto
a_grid_desc_ak0_m_ak1
=
amd_wave_read_first_lane
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
cbcc844e
...
...
@@ -54,8 +54,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
...
@@ -148,8 +147,7 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2CTileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
// printf("entry kernel launch");
__shared__
char
p_shared
[
GridwiseOp
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -244,8 +242,7 @@ __global__ void
const
CDEElementwiseOperation
cde_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
__shared__
char
p_shared
[
GridwiseOp
::
GetSharedMemoryNumberOfByte
()];
GridwiseOp
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
@@ -274,7 +271,7 @@ __global__ void
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx11
00
__ ))
#endif // end of if (defined(__gfx11__ ))
}
template
<
// DataType Family
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
View file @
cbcc844e
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_lds.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
...
...
@@ -55,8 +56,7 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
@@ -236,9 +236,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr
auto
c_block_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
return
math
::
max
(
a_block_space_size_aligned
*
sizeof
(
AComputeDataType
)
+
b_block_space_size_aligned
*
sizeof
(
BComputeDataType
),
c_block_size
*
sizeof
(
CShuffleDataType
));
return
math
::
max
(
NumGemmKPrefetchStage
*
a_block_space_size_aligned
*
sizeof
(
AComputeDataType
)
+
NumGemmKPrefetchStage
*
b_block_space_size_aligned
*
sizeof
(
BComputeDataType
),
c_block_size
*
sizeof
(
CShuffleDataType
));
}
__host__
__device__
static
auto
...
...
@@ -624,12 +625,20 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
AComputeDataType
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
BComputeDataType
*>
(
p_shared
)
+
a_block_space_size_aligned
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
const
auto
a_buffers_offset
=
0
;
auto
a_block_buffers
=
ck
::
lds_utils
::
AllocateLdsBuffers
<
AComputeDataType
,
NumGemmKPrefetchStage
>
(
p_shared
,
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
a_buffers_offset
,
max_lds_align
);
const
auto
b_buffers_offset
=
a_block_space_size_aligned
*
NumGemmKPrefetchStage
;
auto
b_block_buffers
=
ck
::
lds_utils
::
AllocateLdsBuffers
<
BComputeDataType
,
NumGemmKPrefetchStage
>
(
p_shared
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
b_buffers_offset
,
max_lds_align
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1
,
0
,
0
);
...
...
@@ -645,13 +654,13 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_buf
,
a_block_buf
fers
,
a_block_slice_copy_step
,
b_grid_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buf
,
b_block_buf
fers
,
b_block_slice_copy_step
,
blockwise_gemm
,
c_thread_buf
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp
View file @
cbcc844e
...
...
@@ -7,6 +7,20 @@
#include "ck/utility/loop_scheduler.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace
lds_direct_load
{
__device__
void
sched_barrier
()
{
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
// When direct loads and `waitcnt` instructions are submitted using inline asm, the usage of
// `sched_barrier` is necessary to make sure no instructions that use the loaded memory
// are scheduled by the compiler before the `waitcnt` instruction.
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
}
// namespace lds_direct_load
namespace
ck
{
template
<
index_t
NumPrefetch
>
...
...
@@ -17,7 +31,6 @@ template <>
struct
GridwiseGemmPipeline_v4
<
1
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
/* num_loop */
)
{
return
true
;
}
...
...
@@ -31,13 +44,13 @@ struct GridwiseGemmPipeline_v4<1>
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffer
,
typename
ABlockBuffer
s
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffer
,
typename
BBlockBuffer
s
,
typename
BBlockTransferStep
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
...
...
@@ -45,18 +58,22 @@ struct GridwiseGemmPipeline_v4<1>
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffer
&
a_block_buf
,
ABlockBuffer
s
&
a_block_buf
s
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffer
&
b_block_buf
,
BBlockBuffer
s
&
b_block_buf
s
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
static_assert
(
ABlockBuffers
::
Size
()
==
1
&&
BBlockBuffers
::
Size
()
==
1
);
auto
&
a_block_buf
=
a_block_bufs
.
At
(
I0
);
auto
&
b_block_buf
=
b_block_bufs
.
At
(
I0
);
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf
);
...
...
@@ -74,10 +91,12 @@ struct GridwiseGemmPipeline_v4<1>
do
{
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf
);
...
...
@@ -92,10 +111,128 @@ struct GridwiseGemmPipeline_v4<1>
// tail
{
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
blockwise_gemm
.
Run
(
a_block_buf
,
b_block_buf
,
c_thread_buf
);
}
}
};
// 2-stages prefetch
template
<
>
struct
GridwiseGemmPipeline_v4
<
2
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
__host__
__device__
static
constexpr
bool
IsSupported
(
index_t
num_loop
)
{
return
num_loop
%
2
==
0
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainLoop
(
index_t
num_loop
)
{
return
(
num_loop
/
2
)
>
1
;
}
template
<
bool
HasMainLoop
,
typename
AGridDesc
,
typename
ABlockDesc
,
typename
ABlockTransfer
,
typename
AGridBuffer
,
typename
ABlockBuffers
,
typename
ABlockTransferStep
,
typename
BGridDesc
,
typename
BBlockDesc
,
typename
BBlockTransfer
,
typename
BGridBuffer
,
typename
BBlockBuffers
,
typename
BBlockTransferStep
,
typename
BlockwiseGemm
,
typename
CThreadBuffer
>
__device__
static
void
Run
(
const
AGridDesc
&
a_grid_desc
,
const
ABlockDesc
&
a_block_desc
,
ABlockTransfer
&
a_blockwise_copy
,
const
AGridBuffer
&
a_grid_buf
,
ABlockBuffers
&
a_block_bufs
,
const
ABlockTransferStep
&
a_block_copy_step
,
const
BGridDesc
&
b_grid_desc
,
const
BBlockDesc
&
b_block_desc
,
BBlockTransfer
&
b_blockwise_copy
,
const
BGridBuffer
&
b_grid_buf
,
BBlockBuffers
&
b_block_bufs
,
const
BBlockTransferStep
&
b_block_copy_step
,
const
BlockwiseGemm
&
blockwise_gemm
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
{
static_assert
(
ABlockBuffers
::
Size
()
==
2
&&
BBlockBuffers
::
Size
()
==
2
);
auto
&
a_block_buf1
=
a_block_bufs
.
At
(
I0
);
auto
&
a_block_buf2
=
a_block_bufs
.
At
(
I1
);
auto
&
b_block_buf1
=
b_block_bufs
.
At
(
I0
);
auto
&
b_block_buf2
=
b_block_bufs
.
At
(
I1
);
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf1
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf1
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Initialize C
c_thread_buf
.
Clear
();
// main body
if
constexpr
(
HasMainLoop
)
{
index_t
i
=
0
;
do
{
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf2
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf2
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
blockwise_gemm
.
Run
(
a_block_buf1
,
b_block_buf1
,
c_thread_buf
);
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf1
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf1
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
blockwise_gemm
.
Run
(
a_block_buf2
,
b_block_buf2
,
c_thread_buf
);
i
+=
2
;
}
while
(
i
<
(
num_loop
-
2
));
}
// tail
{
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
a_blockwise_copy
.
Run
(
a_grid_desc
,
a_grid_buf
,
a_block_desc
,
a_block_buf2
);
b_blockwise_copy
.
Run
(
b_grid_desc
,
b_grid_buf
,
b_block_desc
,
b_block_buf2
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
blockwise_gemm
.
Run
(
a_block_buf1
,
b_block_buf1
,
c_thread_buf
);
block_sync_lds_direct_load
();
lds_direct_load
::
sched_barrier
();
blockwise_gemm
.
Run
(
a_block_buf2
,
b_block_buf2
,
c_thread_buf
);
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
View file @
cbcc844e
...
...
@@ -55,7 +55,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
cbcc844e
...
...
@@ -49,8 +49,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
defined(__gfx1102__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
@@ -75,7 +74,7 @@ __global__ void
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx11
00
__))
#endif // end of if (defined(__gfx11__))
}
template
<
index_t
BlockSize
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
cbcc844e
...
...
@@ -25,7 +25,7 @@ __global__ void
kernel_gemm_xdl_cshuffle_v1
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
...
...
@@ -50,7 +50,7 @@ __global__ void
typename
GridwiseGemm
::
Problem
problem
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared
,
problem
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp
0 → 100644
View file @
cbcc844e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace
ck
{
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
,
index_t
TailNum
=
3
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
1
)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v2
(
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
TailNum
>(
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
p_shared_0
,
p_shared_1
,
karg
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template
<
typename
GridwiseGemm
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
1
)
#endif
kernel_gemm_xdl_cshuffle_v2
(
const
FloatA
*
p_a_grid
,
const
FloatB
*
p_b_grid
,
FloatC
*
p_c_grid
,
typename
GridwiseGemm
::
Problem
problem
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared_0
,
p_shared_1
,
problem
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
problem
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
FloatA
,
typename
FloatB
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatC
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1Value
,
index_t
BK1Value
,
index_t
MPerXdl
,
index_t
NPerXdl
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
bool
AThreadTransferSrcResetCoordinateAfterRun
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
bool
BThreadTransferSrcResetCoordinateAfterRun
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
typename
ComputeTypeA
=
FloatC
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
GridwiseGemm_xdl_cshuffle_v2
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
static
constexpr
auto
AK0Number
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0Number
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
return
std
::
make_tuple
(
Block2CTileMap
::
CalculateGridSize
(
M
,
N
),
1
,
1
);
}
__host__
static
auto
CalculateMPadded
(
index_t
M
)
{
return
math
::
integer_divide_ceil
(
M
,
MPerBlock
)
*
MPerBlock
;
}
__host__
static
auto
CalculateNPadded
(
index_t
N
)
{
return
math
::
integer_divide_ceil
(
N
,
NPerBlock
)
*
NPerBlock
;
}
__host__
static
auto
CalculateKPadded
(
index_t
K
)
{
return
math
::
integer_divide_ceil
(
K
,
KPerBlock
)
*
KPerBlock
;
}
__host__
static
auto
CalculateAK0
(
index_t
K
)
{
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
||
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
return
CalculateKPadded
(
K
)
/
AK1Value
;
}
else
{
return
K
/
AK1Value
;
}
}
__host__
static
auto
CalculateBK0
(
index_t
K
)
{
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
||
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
return
CalculateKPadded
(
K
)
/
BK1Value
;
}
else
{
return
K
/
BK1Value
;
}
}
__host__
static
auto
CalculateMBlock
(
index_t
M
)
{
return
math
::
integer_divide_floor
(
M
,
MPerBlock
);
}
__host__
static
auto
CalculateNBlock
(
index_t
N
)
{
return
math
::
integer_divide_floor
(
N
,
NPerBlock
);
}
template
<
index_t
MNXdlPerWave
,
index_t
MNWaves
,
index_t
MNPerXdl
,
typename
TileDesc_K0_MN_K1
>
__host__
__device__
static
constexpr
auto
MakeGemmMmaTileDescriptor
(
const
TileDesc_K0_MN_K1
&
)
{
constexpr
index_t
K0
=
TileDesc_K0_MN_K1
{}.
GetLength
(
Number
<
0
>
{});
constexpr
index_t
K1
=
TileDesc_K0_MN_K1
{}.
GetLength
(
Number
<
2
>
{});
return
transform_tensor_descriptor
(
TileDesc_K0_MN_K1
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
K0
>
{},
Number
<
K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MNXdlPerWave
>
{},
Number
<
MNWaves
>
{},
Number
<
MNPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
KPad
,
index_t
StrideA
,
index_t
AK0
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_pass_through_transform
(
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
K
,
index_t
KPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
,
index_t
BK0
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
StrideB
,
I1
));
}
}();
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
N
,
NPad
-
N
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_pass_through_transform
(
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
N
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
}
template
<
typename
ABlockDesc_AK0_M_AK1
>
__host__
__device__
static
constexpr
auto
MakeAMmaTileDescriptor_M0_M1_M2_K
(
const
ABlockDesc_AK0_M_AK1
&
)
{
constexpr
index_t
MWaves
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
return
MakeGemmMmaTileDescriptor
<
MXdlPerWave
,
MWaves
,
MPerXdl
>
(
ABlockDesc_AK0_M_AK1
{});
}
template
<
typename
BBlockDesc_BK0_N_BK1
>
__host__
__device__
static
constexpr
auto
MakeBMmaTileDescriptor_N0_N1_N2_K
(
const
BBlockDesc_BK0_N_BK1
&
)
{
constexpr
index_t
NWaves
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
return
MakeGemmMmaTileDescriptor
<
NXdlPerWave
,
NWaves
,
NPerXdl
>
(
BBlockDesc_BK0_N_BK1
{});
}
__host__
__device__
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
MPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideC
)
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}();
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
c_grid_desc_mraw_nraw
;
}
}
struct
Problem
{
__host__
Problem
(
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
)
:
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideC
{
StrideC_
},
MPadded
{
CalculateMPadded
(
M_
)},
NPadded
{
CalculateNPadded
(
N_
)},
KPadded
{
CalculateKPadded
(
K_
)},
AK0
{
CalculateAK0
(
K_
)},
BK0
{
CalculateBK0
(
K_
)},
MBlock
{
CalculateMBlock
(
M_
)},
NBlock
{
CalculateNBlock
(
N_
)}
{
}
__host__
void
Print
()
const
{
std
::
cout
<<
"problem {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SC:"
<<
StrideC
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"KP:"
<<
KPadded
<<
", "
<<
"AK0:"
<<
AK0
<<
", "
<<
"BK0:"
<<
BK0
<<
", "
<<
"MBlock: "
<<
MBlock
<<
", "
<<
"NBlock: "
<<
NBlock
<<
"}"
<<
std
::
endl
;
}
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
index_t
KPadded
;
index_t
AK0
;
index_t
BK0
;
index_t
MBlock
;
index_t
NBlock
;
};
// Argument
struct
Argument
:
public
tensor_operation
::
device
::
BaseArgument
,
public
Problem
{
__host__
Argument
(
const
FloatA
*
p_a_grid_
,
const
FloatB
*
p_b_grid_
,
FloatC
*
p_c_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
)
:
Problem
{
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
},
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_c_grid
{
p_c_grid_
}
{
}
const
FloatA
*
p_a_grid
;
const
FloatB
*
p_b_grid
;
FloatC
*
p_c_grid
;
};
// FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0Number
,
Number
<
MPerBlock
>
{},
AK1Number
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1Number
,
AK1Number
,
I1
));
}
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0Number
,
Number
<
NPerBlock
>
{},
BK1Number
),
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1Number
,
BK1Number
,
I1
));
}
__device__
static
constexpr
auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
>
{},
I1
,
Number
<
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
{}));
return
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
}
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1Number
,
BK1Number
);
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size_aligned
=
math
::
integer_least_multiple
(
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
(),
max_lds_align
);
// LDS allocation for C shuffle in LDS
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
constexpr
auto
c_block_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
*
sizeof
(
ComputeTypeA
)
+
b_block_space_size_aligned
*
sizeof
(
ComputeTypeB
)),
c_block_size
*
sizeof
(
FloatCShuffle
));
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
static
constexpr
bool
CheckValidity
(
const
Problem
&
problem
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
if
(
!
(
problem
.
M
%
MPerBlock
==
0
))
{
return
false
;
}
}
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
if
(
!
(
problem
.
N
%
NPerBlock
==
0
))
{
return
false
;
}
}
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
)
{
if
(
!
(
CalculateKPadded
(
problem
.
K
)
%
AK1Value
==
0
)
||
!
(
CalculateKPadded
(
problem
.
K
)
%
BK1Value
==
0
))
{
return
false
;
}
}
else
{
if
(
!
(
problem
.
K
%
AK1Value
==
0
)
||
!
(
problem
.
K
%
BK1Value
==
0
))
{
return
false
;
}
}
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
if
(
problem
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
if
(
problem
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
if
(
problem
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
if
(
problem
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
if
(
problem
.
N
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
{
if
(
problem
.
M
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
// check gridwise gemm pipeline
const
auto
num_k_loop
=
(
CalculateAK0
(
problem
.
K
)
*
AK1Value
)
/
KPerBlock
;
if
(
num_k_loop
<
4
)
{
return
false
;
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
}
__host__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
KPerBlock
;
return
num_loop
>
3
;
}
__host__
static
constexpr
index_t
CalculateKBlockLoopTailNum
(
index_t
K
)
{
const
index_t
num_loop
=
K
/
KPerBlock
;
if
(
num_loop
%
2
==
1
)
return
3
;
else
return
2
;
}
template
<
typename
CGridDesc
>
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc
&
c_grid_desc_m_n
,
index_t
MBlock
,
index_t
NBlock
)
{
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
c_grid_desc_mblock_mperblock_nblock_nperblock
;
}
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using
Block2CTileMap
=
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
8
,
MPerBlock
,
NPerBlock
>
;
template
<
bool
HasMainKBlockLoop
,
index_t
TailNum
=
3
>
__device__
static
void
Run
(
const
FloatA
*
p_a_grid
,
const
FloatB
*
p_b_grid
,
FloatC
*
p_c_grid
,
void
*
p_shared_0
,
void
*
p_shared_1
,
const
Problem
&
problem
)
{
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
const
AElementwiseOperation
a_element_op
{};
const
BElementwiseOperation
b_element_op
{};
const
CElementwiseOperation
c_element_op
{};
// divide block work by [M, N]
const
auto
block_2_ctile_map
=
Block2CTileMap
{
problem
.
M
,
problem
.
N
,
4
};
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
#if 0
if(threadIdx.x == 0){
printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n",
get_block_1d_id(),
block_work_idx[I0],
block_work_idx[I1],
__smid()>>6 & 0xf,
__smid()>>4 & 0x3,
__smid() & 0xf);
}
#endif
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
NPerBlock
);
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1Number
,
BK1Number
);
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// A matrix blockwise copy
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0Number
,
MPerBlock
,
AK1Number
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatA
,
ComputeTypeA
,
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
ABlockTransferSrcVectorDim
,
2
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
a_grid_desc_ak0_m_ak1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_block_desc_ak0_m_ak1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// B matrix blockwise copy
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0Number
,
NPerBlock
,
BK1Number
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatB
,
ComputeTypeB
,
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
BBlockTransferSrcVectorDim
,
2
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
1
,
1
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b_grid_desc_bk0_n_bk1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_block_desc_bk0_n_bk1
,
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
ComputeTypeA
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
// auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
// BlockSize,
// ComputeType,
// FloatGemmAcc,
// decltype(a_block_desc_ak0_m_ak1),
// decltype(b_block_desc_bk0_n_bk1),
// MPerXdl,
// NPerXdl,
// MXdlPerWave,
// NXdlPerWave,
// KPack,
// LoopSched>();
auto
blockwise_gemm_pipeline
=
BlockwiseGemmXdlops_pipeline_v4
<
BlockSize
,
ComputeTypeA
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
MakeAMmaTileDescriptor_M0_M1_M2_K
(
a_block_desc_ak0_m_ak1
)),
decltype
(
MakeBMmaTileDescriptor_N0_N1_N2_K
(
b_block_desc_bk0_n_bk1
)),
MPerBlock
,
NPerBlock
,
KPerBlock
,
MPerXdl
,
NPerXdl
,
MXdlPerWave
,
NXdlPerWave
,
KPack
>
{};
// TransposeC
auto
c_thread_buf
=
blockwise_gemm_pipeline
.
GetCThreadBuffer
();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf_ping
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ComputeTypeA
*>
(
p_shared_0
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf_ping
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ComputeTypeB
*>
(
p_shared_0
)
+
a_block_space_size_aligned
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
a_block_buf_pong
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ComputeTypeA
*>
(
p_shared_1
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf_pong
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
ComputeTypeB
*>
(
p_shared_1
)
+
a_block_space_size_aligned
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
a_block_bufs
=
make_tuple
(
a_block_buf_ping
,
a_block_buf_pong
);
auto
b_block_bufs
=
make_tuple
(
b_block_buf_ping
,
b_block_buf_pong
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1Number
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1Number
,
0
,
0
);
// gridwise GEMM pipeline
static_assert
(
std
::
is_default_constructible_v
<
GridwiseGemmPipe
>
);
// const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
KPerBlock
);
blockwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
,
TailNum
>(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_bufs
,
a_block_slice_copy_step
,
b_grid_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_bufs
,
b_block_slice_copy_step
,
c_thread_buf
,
num_k_block_main_loop
);
// shuffle C and write out
{
static_assert
(
MXdlPerWave
%
CShuffleMXdlPerWavePerShuffle
==
0
&&
NXdlPerWave
%
CShuffleNXdlPerWavePerShuffle
==
0
,
"wrong!"
);
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
// TODO: hacky, fix it!
constexpr
auto
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
blockwise_gemm_pipeline
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
=
blockwise_gemm_pipeline
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp
.
GetLength
(
I7
);
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatCShuffle
*>
(
p_shared_0
),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleMXdlPerWavePerShuffle
>
{},
// M0 (MXdlPerWave) per shuffle
M1
,
// M1 = MWave
M2
,
// M2 * M3 * M4 = MPerXdl
M3
,
M4
)),
make_freeze_transform
(
I0
),
make_unmerge_transform
(
make_tuple
(
Number
<
CShuffleNXdlPerWavePerShuffle
>
{},
// N0 (NXdlPerWave) per shuffle
N1
,
// N1 = NWave
N2
))),
// N2 = NPerXdl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<>
{},
Sequence
<
1
,
3
,
7
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm_pipeline
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
// shuffle: threadwise copy C from VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatGemmAcc
,
FloatCShuffle
,
decltype
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
// shuffle: blockwise copy C from LDS to global
auto
c_shuffle_block_copy_lds_to_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
ThisThreadBlock
,
// ThreadGroup
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename SrcData,
FloatC
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
1
,
1
,
M2
,
1
,
M4
,
1
>>
{};
// space filling curve for shuffled blockwise C in global mem
constexpr
auto
sfc_c_global
=
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>>
{};
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
static_assert
(
num_access
==
sfc_c_global
.
GetNumOfAccess
(),
"wrong!"
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
block_sync_lds
();
// each thread write its data from VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
sfc_c_vgpr
.
GetIndexTupleOfNumber
(
access_id
),
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_shuffle_block_buf
);
// make sure it's safe to read from LDS
block_sync_lds
();
// each block copy its data from LDS to global
c_shuffle_block_copy_lds_to_global
.
Run
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
if
constexpr
(
access_id
<
num_access
-
1
)
{
constexpr
auto
c_global_step
=
sfc_c_global
.
GetForwardStep
(
access_id
);
// move on C
c_shuffle_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
}
});
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp
View file @
cbcc844e
...
...
@@ -58,7 +58,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
// TODO ANT: separate into MMA + Epilogue
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp
View file @
cbcc844e
...
...
@@ -167,7 +167,7 @@ __global__ void
const
CBlockClusterAdaptor
c_block_cluster_adaptor
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp
View file @
cbcc844e
...
...
@@ -45,7 +45,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
p_a_grid
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp
0 → 100644
View file @
cbcc844e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_lds.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
namespace
ck
{
template
<
typename
GridwiseGemm
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
Block2CTileMap
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_xdlops_splitk_lds_direct_load
(
typename
GridwiseGemm
::
Argument
karg
,
const
Block2CTileMap
&
b2c_map
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
constexpr
index_t
shared_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
();
__shared__
uint8_t
p_shared
[
shared_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
>(
karg
,
static_cast
<
void
*>
(
p_shared
),
b2c_map
,
a_element_op
,
b_element_op
,
c_element_op
);
#else
ignore
=
karg
;
ignore
=
b2c_map
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
FloatC
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
K1Value
,
index_t
MRepeat
,
index_t
NRepeat
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
bool
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
bool
BBlockLdsExtraN
,
index_t
CShuffleMRepeatPerShuffle
,
index_t
CShuffleNRepeatPerShuffle
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v4
,
typename
ComputeType
=
FloatC
>
struct
GridwiseGemm_xdlops_splitk_lds_direct_load
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
static
constexpr
auto
M01
=
1
;
static
constexpr
auto
N01
=
1
;
static
constexpr
auto
gemm_padder
=
tensor_operation
::
device
::
GemmPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K1
*
K0PerBlock
};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
remove_cvref_t
<
decltype
(
GridwiseGemmPipeline_Selector
<
PipelineVer
,
NumGemmKPrefetchStage
,
LoopSched
>
())
>
;
struct
Argument
:
public
ck
::
tensor_operation
::
device
::
BaseArgument
{
const
FloatA
*
p_a_grid
;
const
FloatB
*
p_b_grid
;
FloatC
*
p_c_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
index_t
KPadded
;
index_t
K0Padded
;
index_t
k_batch
;
Argument
(
const
FloatA
*
p_a_grid_
,
const
FloatB
*
p_b_grid_
,
FloatC
*
p_c_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
index_t
StrideC_
,
index_t
MPadded_
,
index_t
NPadded_
,
index_t
KPadded_
,
index_t
K0Padded_
,
index_t
k_batch_
)
:
p_a_grid
(
p_a_grid_
),
p_b_grid
(
p_b_grid_
),
p_c_grid
(
p_c_grid_
),
M
(
M_
),
N
(
N_
),
K
(
K_
),
StrideA
(
StrideA_
),
StrideB
(
StrideB_
),
StrideC
(
StrideC_
),
MPadded
(
MPadded_
),
NPadded
(
NPadded_
),
KPadded
(
KPadded_
),
K0Padded
(
K0Padded_
),
k_batch
(
k_batch_
)
{
}
void
Print
()
const
{
std
::
cout
<<
"arg {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SC:"
<<
StrideC
<<
", "
<<
"MP:"
<<
MPadded
<<
", "
<<
"NP:"
<<
NPadded
<<
", "
<<
"KP:"
<<
KPadded
<<
", "
<<
"K0Padded:"
<<
K0Padded
<<
", "
<<
"KB:"
<<
k_batch
<<
"}"
<<
std
::
endl
;
}
};
__host__
__device__
static
auto
CalculateGridSize
(
const
Argument
&
karg
)
{
return
std
::
make_tuple
(
math
::
integer_divide_ceil
(
karg
.
N
,
NPerBlock
),
math
::
integer_divide_ceil
(
karg
.
M
,
MPerBlock
),
karg
.
k_batch
);
}
// prefer this to be called on host
__host__
__device__
static
auto
CalculateMPadded
(
index_t
M
)
{
return
math
::
integer_least_multiple
(
M
,
MPerBlock
);
}
__host__
__device__
static
auto
CalculateNPadded
(
index_t
N
)
{
return
math
::
integer_least_multiple
(
N
,
NPerBlock
);
}
__host__
__device__
static
auto
CalculateK0Padded
(
index_t
K
,
index_t
K_Batch
=
1
)
{
// k_batch * k0 * k0_per_block * k1
auto
K_t
=
K_Batch
*
K0PerBlock
*
K1
;
return
(
K
+
K_t
-
1
)
/
K_t
*
K0PerBlock
;
}
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
,
index_t
K_Batch
=
1
)
{
auto
K0Padded
=
CalculateK0Padded
(
K
,
K_Batch
);
return
K_Batch
*
K0Padded
*
K1
;
}
__host__
__device__
static
auto
MakeAGridDescriptor_KBatch_K0_M_K1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
StrideA
,
index_t
KBatch
,
index_t
K0Padded
,
index_t
KPad
)
{
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
const
auto
a_grid_desc_m_kpad
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return
transform_tensor_descriptor
(
a_grid_desc_m_kpad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
// const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
}
__host__
__device__
static
auto
MakeBGridDescriptor_KBatch_K0_N_K1
(
index_t
K
,
index_t
NPad
,
index_t
N
,
index_t
StrideB
,
index_t
KBatch
,
index_t
K0Padded
,
index_t
KPad
)
{
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
}
}();
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
const
auto
b_grid_desc_kpad_n
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_right_pad_transform
(
K
,
KPad
-
K
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return
transform_tensor_descriptor
(
b_grid_desc_kpad_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
)
{
// const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
{
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
}
__host__
__device__
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}();
return
gemm_padder
.
PadCDescriptor_M_N
(
c_grid_desc_m_n
);
}
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_k0_m_k1_block_desc
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_k0_n_k1_block_desc
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_k0_m_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
b_block_space_size
=
math
::
integer_least_multiple
(
b_k0_n_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
c_block_size
=
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
().
GetElementSpaceSize
();
return
math
::
max
(
NumGemmKPrefetchStage
*
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
ComputeType
),
c_block_size
*
sizeof
(
FloatC
));
}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
{
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
{
return
false
;
}
}
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
{
return
false
;
}
}
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
auto
K_t
=
karg
.
k_batch
*
K0PerBlock
*
K1
;
if
(
!
(
karg
.
K
%
K_t
==
0
))
{
return
false
;
}
}
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
if
(
karg
.
N
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
{
return
false
;
}
}
else
{
if
(
karg
.
M
%
CBlockTransferScalarPerVector_NWaveNPerXDL
!=
0
)
{
return
false
;
}
}
const
auto
num_k_loop
=
karg
.
K0Padded
/
K0PerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
return
false
;
}
return
true
;
}
__host__
__device__
static
auto
GetKPad
(
index_t
K
,
index_t
KBatch
)
{
const
index_t
K0Padded
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
KPad
=
KBatch
*
K0Padded
*
K1
;
return
KPad
;
}
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0Padded
)
{
const
index_t
num_loop
=
K0Padded
/
K0PerBlock
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
template
<
typename
CGridDesc
>
__host__
__device__
static
constexpr
auto
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc
&
c_m_n_grid_desc
)
{
const
auto
M
=
c_m_n_grid_desc
.
GetLength
(
I0
);
const
auto
N
=
c_m_n_grid_desc
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
return
transform_tensor_descriptor
(
c_m_n_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleMRepeatPerShuffle
*
MWave
*
MPerXDL
>
{},
I1
,
Number
<
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
{}));
}
// return block_id to C matrix tile idx (m0, n0, k_split) mapping
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2CTileMap
()
{
return
BlockToCTileMap_3DGrid_KSplit
<
MPerBlock
,
NPerBlock
>
();
}
using
CGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
))
>
;
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
())
>
;
template
<
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
Argument
&
karg
,
void
*
__restrict__
p_shared_block
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{},
const
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{},
const
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{})
{
// Elementwise operations are not supported for A and B, arguments left only for the API
// consistency.
(
void
)
a_element_op
;
(
void
)
b_element_op
;
const
FloatA
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatB
*
p_b_grid
=
karg
.
p_b_grid
;
FloatC
*
p_c_grid
=
karg
.
p_c_grid
;
const
auto
a_b_k0_m_k1_grid_desc
=
MakeAGridDescriptor_KBatch_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
StrideA
,
karg
.
k_batch
,
karg
.
K0Padded
,
karg
.
KPadded
);
const
auto
b_b_k0_n_k1_grid_desc
=
MakeBGridDescriptor_KBatch_K0_N_K1
(
karg
.
K
,
karg
.
NPadded
,
karg
.
N
,
karg
.
StrideB
,
karg
.
k_batch
,
karg
.
K0Padded
,
karg
.
KPadded
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_b_k0_m_k1_grid_desc
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_b_k0_n_k1_grid_desc
.
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// divide block work by [KBatch, M, N]
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_ctile_map
.
ValidCTileIndex
(
block_work_idx
,
make_tuple
(
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
return
;
}
const
index_t
block_m_id
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]);
const
index_t
block_n_id
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I2
]);
const
index_t
k_batch_id
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]);
// HACK: this force m/n_block_data_idx_on_grid into SGPR
const
index_t
m_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_m_id
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_n_id
*
NPerBlock
);
// lds max alignment
constexpr
auto
max_lds_align
=
K1
;
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_k0_m_k1_block_desc
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
constexpr
auto
a_b_k0_m_k1_block_desc
=
[
&
]()
{
if
constexpr
(
ABlockLdsExtraM
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
MPerBlock
+
1
>
{}
*
K1
,
Number
<
MPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
MPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
// B matrix in LDS memory, dst of blockwise copy
constexpr
auto
b_k0_n_k1_block_desc
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
constexpr
auto
b_b_k0_n_k1_block_desc
=
[
&
]()
{
if
constexpr
(
BBlockLdsExtraN
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
make_tuple
(
Number
<
K0PerBlock
>
{}
*
Number
<
NPerBlock
+
1
>
{}
*
K1
,
Number
<
NPerBlock
+
1
>
{}
*
K1
,
K1
,
I1
));
}
else
{
return
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
1
>
{},
Number
<
K0PerBlock
>
{},
Number
<
NPerBlock
>
{},
K1
),
max_lds_align
);
}
}();
auto
a_blockwise_copy
=
ThreadGroupTensorSliceTransfer_DirectLoad
<
ThisThreadBlock
,
Sequence
<
1
,
K0PerBlock
,
MPerBlock
,
K1
>
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
FloatA
,
ComputeType
,
decltype
(
a_b_k0_m_k1_grid_desc
),
decltype
(
a_b_k0_m_k1_block_desc
),
ABlockTransferSrcVectorDim
,
3
,
ABlockTransferSrcScalarPerVector
>
(
a_b_k0_m_k1_grid_desc
,
make_multi_index
(
k_batch_id
,
0
,
m_block_data_idx_on_grid
,
0
),
a_b_k0_m_k1_block_desc
,
make_multi_index
(
0
,
0
,
0
,
0
));
auto
b_blockwise_copy
=
ThreadGroupTensorSliceTransfer_DirectLoad
<
ThisThreadBlock
,
Sequence
<
1
,
K0PerBlock
,
NPerBlock
,
K1
>
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
FloatB
,
ComputeType
,
decltype
(
b_b_k0_n_k1_grid_desc
),
decltype
(
b_b_k0_n_k1_block_desc
),
BBlockTransferSrcVectorDim
,
3
,
BBlockTransferSrcScalarPerVector
>
(
b_b_k0_n_k1_grid_desc
,
make_multi_index
(
k_batch_id
,
0
,
n_block_data_idx_on_grid
,
0
),
b_b_k0_n_k1_block_desc
,
make_multi_index
(
0
,
0
,
0
,
0
));
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
ComputeType
,
// ComputeType A
ComputeType
,
// ComputeType B
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
decltype
(
b_k0_n_k1_block_desc
),
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
K1
,
LoopSched
>
();
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_k0_m_k1_block_desc
.
GetElementSpaceSize
(),
max_lds_align
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
0
,
K0PerBlock
,
0
,
0
);
const
auto
a_buffers_offset
=
0
;
auto
a_block_buffers
=
ck
::
lds_utils
::
AllocateLdsBuffers
<
ComputeType
,
NumGemmKPrefetchStage
>
(
p_shared_block
,
a_b_k0_m_k1_block_desc
.
GetElementSpaceSize
(),
a_buffers_offset
,
max_lds_align
);
const
auto
b_buffers_offset
=
a_block_space_size
*
NumGemmKPrefetchStage
;
auto
b_block_buffers
=
ck
::
lds_utils
::
AllocateLdsBuffers
<
ComputeType
,
NumGemmKPrefetchStage
>
(
p_shared_block
,
b_b_k0_n_k1_block_desc
.
GetElementSpaceSize
(),
b_buffers_offset
,
max_lds_align
);
// gridwise GEMM pipeline
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_b_k0_m_k1_grid_desc
.
GetLength
(
I1
)
*
a_b_k0_m_k1_grid_desc
.
GetLength
(
I3
))
/
(
K0PerBlock
*
K1
));
const
auto
gridwise_gemm_pipeline
=
GridwiseGemmPipe
{};
gridwise_gemm_pipeline
.
template
Run
<
HasMainKBlockLoop
>(
a_b_k0_m_k1_grid_desc
,
a_b_k0_m_k1_block_desc
,
a_blockwise_copy
,
a_grid_buf
,
a_block_buffers
,
a_block_slice_copy_step
,
b_b_k0_n_k1_grid_desc
,
b_b_k0_n_k1_block_desc
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buffers
,
b_block_slice_copy_step
,
blockwise_gemm
,
c_thread_buf
,
num_k_block_main_loop
);
// output: register to global memory
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
=
blockwise_gemm
.
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
=
blockwise_gemm
.
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
();
constexpr
auto
M0
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I0
);
constexpr
auto
N0
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I1
);
constexpr
auto
M1
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I2
);
constexpr
auto
N1
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I3
);
constexpr
auto
M2
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I4
);
constexpr
auto
M3
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I5
);
constexpr
auto
M4
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I6
);
constexpr
auto
N2
=
c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc
.
GetLength
(
I7
);
constexpr
auto
c_block_desc_mblock_mperblock_nblock_nperblock
=
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatC
*>
(
p_shared_block
),
c_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c_block_desc_mblock_mperblock_nblock_nperblock
,
make_tuple
(
make_freeze_transform
(
I0
),
// freeze mblock
make_unmerge_transform
(
make_tuple
(
CShuffleMRepeatPerShuffle
,
M1
,
M2
,
M3
,
M4
)),
// M1 = MWave, M2 * M3 * M4 = MPerXDL
make_freeze_transform
(
I0
),
// freeze nblock
make_unmerge_transform
(
make_tuple
(
CShuffleNRepeatPerShuffle
,
N1
,
N2
))),
// M1 = MWave, M2 * M3 * M4 = MPerXDL
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
2
,
4
,
5
,
6
>
{},
Sequence
<>
{},
Sequence
<
1
,
3
,
7
>
{}));
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm
.
CalculateCThreadOriginDataIndex
(
I0
,
I0
,
I0
,
I0
);
const
index_t
m_thread_data_on_block
=
c_thread_mtx_on_block
[
I0
];
const
index_t
n_thread_data_on_block
=
c_thread_mtx_on_block
[
I1
];
const
auto
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M0
,
M1
,
M2
,
M3
,
M4
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
m_thread_data_on_block_idx
=
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
m_thread_data_on_block
));
const
auto
n_thread_data_on_block_to_n0_n1_n2_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N0
,
N1
,
N2
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
n_thread_data_on_block_idx
=
n_thread_data_on_block_to_n0_n1_n2_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
n_thread_data_on_block
));
// VGPR to LDS
auto
c_thread_copy_vgpr_to_lds
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
FloatC
,
decltype
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
),
decltype
(
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
I1
,
I1
,
M2
,
I1
,
M4
,
I1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
7
,
1
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
make_multi_index
(
0
,
0
,
m_thread_data_on_block_idx
[
I1
],
n_thread_data_on_block_idx
[
I1
],
m_thread_data_on_block_idx
[
I2
],
m_thread_data_on_block_idx
[
I3
],
m_thread_data_on_block_idx
[
I4
],
n_thread_data_on_block_idx
[
I2
]),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{}};
// LDS to global
auto
c_block_copy_lds_to_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
ThisThreadBlock
,
// index_t BlockSize,
CElementwiseOperation
,
// ElementwiseOperation,
CGlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
CShuffleMRepeatPerShuffle
*
MWave
*
MPerXDL
,
1
,
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
>
,
// BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatC
,
// typename SrcData,
FloatC
,
// typename DstData,
decltype
(
c_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun
{
c_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_m_id
,
0
,
block_n_id
,
0
),
c_element_op
};
constexpr
auto
mxdlperwave_forward_step
=
make_multi_index
(
0
,
CShuffleMRepeatPerShuffle
*
MWave
*
MPerXDL
,
0
,
0
);
constexpr
auto
nxdlperwave_forward_step
=
make_multi_index
(
0
,
0
,
0
,
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
);
constexpr
auto
nxdlperwave_backward_step
=
make_multi_index
(
0
,
0
,
0
,
-
CShuffleNRepeatPerShuffle
*
NWave
*
NPerXDL
);
static_for
<
0
,
MRepeat
,
CShuffleMRepeatPerShuffle
>
{}([
&
](
auto
mxdlperwave_iter
)
{
constexpr
auto
mxdlperwave
=
mxdlperwave_iter
;
static_for
<
0
,
NRepeat
,
CShuffleNRepeatPerShuffle
>
{}([
&
](
auto
nxdlperwave_iter
)
{
constexpr
bool
nxdlperwave_forward_sweep
=
(
mxdlperwave
%
(
2
*
CShuffleMRepeatPerShuffle
)
==
0
);
constexpr
index_t
nxdlperwave_value
=
nxdlperwave_forward_sweep
?
nxdlperwave_iter
:
(
NRepeat
-
nxdlperwave_iter
-
CShuffleNRepeatPerShuffle
);
constexpr
auto
nxdlperwave
=
Number
<
nxdlperwave_value
>
{};
// make sure it's safe to do ds_write
block_sync_lds
();
// VGPR to LDS
c_thread_copy_vgpr_to_lds
.
Run
(
c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc
,
make_tuple
(
mxdlperwave
,
nxdlperwave
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c_block_buf
);
// make sure it's safe to do ds_read
block_sync_lds
();
// LDS to global
c_block_copy_lds_to_global
.
Run
(
c_block_desc_mblock_mperblock_nblock_nperblock
,
c_block_buf
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
// move on nxdlperwave dimension
if
constexpr
(
nxdlperwave_forward_sweep
&&
(
nxdlperwave
<
NRepeat
-
CShuffleNRepeatPerShuffle
))
{
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
nxdlperwave_forward_step
);
}
else
if
constexpr
((
!
nxdlperwave_forward_sweep
)
&&
(
nxdlperwave
>
0
))
{
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
nxdlperwave_backward_step
);
}
});
// move on mxdlperwave dimension
if
constexpr
(
mxdlperwave
<
MRepeat
-
CShuffleMRepeatPerShuffle
)
{
c_block_copy_lds_to_global
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
mxdlperwave_forward_step
);
}
});
}
}
};
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
View file @
cbcc844e
...
...
@@ -38,7 +38,7 @@ __global__ void
typename
GridwiseGemm
::
Block2CTileMap
block_mapping
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
constexpr
index_t
shared_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
();
__shared__
uint8_t
p_shared
[
shared_size
];
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
cbcc844e
...
...
@@ -39,7 +39,7 @@ __global__ void
const
CGridDesc_M_N
c_grid_desc_m_n
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
@@ -70,7 +70,7 @@ __global__ void
kernel_gemm_xdlops_v2r3
(
const
typename
GridwiseGemm
::
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
auto
a_grid_desc_k0_m_k1
=
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
View file @
cbcc844e
...
...
@@ -43,7 +43,7 @@ __global__ void
const
CBlockClusterAdaptor
c_block_cluster_adaptor
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
View file @
cbcc844e
...
...
@@ -37,7 +37,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
constexpr
index_t
shared_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
();
__shared__
uint8_t
p_shared
[
shared_size
];
...
...
@@ -268,6 +268,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
)
{
const
auto
a_grid_desc_m_kpad
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
transform_tensor_descriptor
(
a_grid_desc_m_kpad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
{
return
transform_tensor_descriptor
(
...
...
@@ -329,6 +344,21 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
if
constexpr
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
KPadding
)
{
const
auto
b_grid_desc_kpad_n
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_right_pad_transform
(
K
,
KPad
-
K
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
transform_tensor_descriptor
(
b_grid_desc_kpad_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0Padded
,
K1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
{
return
transform_tensor_descriptor
(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp
View file @
cbcc844e
...
...
@@ -47,7 +47,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainK0BlockLoop
>(
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp
View file @
cbcc844e
...
...
@@ -50,7 +50,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94
0__) || defined(__gfx941__) || defined(__gfx942
__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
...
...
Prev
1
…
5
6
7
8
9
10
11
12
13
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment