gaoqiong / composable_kernel · Commits

Commit 6985af40, authored Jul 16, 2022 by wangshaojie6

    init code

Parent: 63914743
Showing 6 changed files with 1597 additions and 0 deletions (+1597, -0)
Changed files:

    example/27_gemm_gemm/CMakeLists.txt                                                 +1    -0
    example/27_gemm_gemm/gemm_gemm_xdl_fp16.cpp                                         +219  -0
    example/CMakeLists.txt                                                              +1    -0
    include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_c_shuffle.hpp       +48   -0
    include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_c_shuffle.hpp   +597  -0
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_skip_lds.hpp            +731  -0
example/27_gemm_gemm/CMakeLists.txt (new file, mode 100644)

add_example_executable(example_gemm_gemm_xdl_fp16 gemm_gemm_xdl_fp16.cpp)
example/27_gemm_gemm/gemm_gemm_xdl_fp16.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
#include "ck/library/utility/check_err.hpp"

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using F16 = ck::half_t;
using F32 = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

using QDataType       = F16;
using KDataType       = F16;
using PDataType       = F16;
using VDataType       = F16;
using RDataType       = F16;
using GemmAccDataType = F32;

using QLayout = Row;
using KLayout = Col;
using PLayout = Row;
using VLayout = Row;
using RLayout = Row;

using QElementOp = ck::tensor_operation::element_wise::PassThrough;
using KElementOp = ck::tensor_operation::element_wise::PassThrough;
using PElementOp = ck::tensor_operation::element_wise::PassThrough;
using VElementOp = ck::tensor_operation::element_wise::PassThrough;
using RElementOp = ck::tensor_operation::element_wise::PassThrough;

// static constexpr auto GemmSpecialization =
//     ck::tensor_operation::device::GemmSpecialization::Default;

using ReferenceGemmInstanceQKP = ck::tensor_operation::host::
    ReferenceBatchedGemm<QDataType, KDataType, PDataType, QElementOp, KElementOp, PElementOp>;

using ReferenceGemmInstancePVR = ck::tensor_operation::host::
    ReferenceBatchedGemm<PDataType, VDataType, RDataType, PElementOp, VElementOp, RElementOp>;

int main(int argc, char* argv[])
{
    bool do_verification = true;
    int init_method      = 1;
    bool time_kernel     = false;

    // GEMM shape
    ck::index_t N_ = 1024;
    ck::index_t d_ = 64;
#if 0
    ck::index_t M_QKP = N_;
    ck::index_t N_QKP = N_;
    ck::index_t K_QKP = d_;

    ck::index_t M_PVR = N_;
    ck::index_t N_PVR = d_;
    ck::index_t K_PVR = N_;

    ck::index_t StrideQ = d_;
    ck::index_t StrideK = d_;
    ck::index_t StrideP = N_;
    ck::index_t StrideV = d_;
    ck::index_t StrideR = d_;
#endif
    ck::index_t BatchCount = 8 * 12;

    if(argc == 1)
    {
        // do nothing
    }
    else if(argc == 4)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);
    }
    else if(argc == 7)
    {
        do_verification = std::stoi(argv[1]);
        init_method     = std::stoi(argv[2]);
        time_kernel     = std::stoi(argv[3]);

        N_         = std::stoi(argv[4]);
        d_         = std::stoi(argv[5]);
        BatchCount = std::stoi(argv[6]);
#if 0
        M_QKP = N_;
        N_QKP = N_;
        K_QKP = d_;

        M_PVR = N_;
        N_PVR = d_;
        K_PVR = N_;

        StrideQ = d_;
        StrideK = d_;
        StrideP = N_;
        StrideV = d_;
        StrideR = d_;
#endif
    }
    else
    {
        printf("arg1: verification (0=no, 1=yes)\n");
        printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
        printf("arg3: time kernel (0=no, 1=yes)\n");
        printf("arg4 to 6: S (256x), d (128x), BatchCount (32x)\n");
        exit(0);
    }

    auto f_host_tensor_descriptor = [](std::size_t batch_count,
                                       std::size_t row,
                                       std::size_t col,
                                       std::size_t stride,
                                       auto layout) {
        if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({row * stride, stride, 1}));
        }
        else
        {
            return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
                                        std::vector<std::size_t>({col * stride, 1, stride}));
        }
    };

    Tensor<QDataType> q_g_n_d(f_host_tensor_descriptor(BatchCount, N_, d_, d_, QLayout{}));
    Tensor<KDataType> k_g_d_n(f_host_tensor_descriptor(BatchCount, d_, N_, d_, KLayout{}));
    Tensor<PDataType> p_g_n_n(f_host_tensor_descriptor(BatchCount, N_, N_, N_, PLayout{}));
    Tensor<VDataType> v_g_n_d(f_host_tensor_descriptor(BatchCount, N_, d_, d_, VLayout{}));
    Tensor<RDataType> r_g_n_d_host_result(
        f_host_tensor_descriptor(BatchCount, N_, d_, d_, RLayout{}));
    Tensor<RDataType> r_g_n_d_device_result(
        f_host_tensor_descriptor(BatchCount, N_, d_, d_, RLayout{}));

    std::cout << "q_g_n_d: " << q_g_n_d.mDesc << std::endl;
    std::cout << "k_g_d_n: " << k_g_d_n.mDesc << std::endl;
    std::cout << "p_g_n_n: " << p_g_n_n.mDesc << std::endl;
    std::cout << "v_g_n_d: " << v_g_n_d.mDesc << std::endl;
    std::cout << "r_g_n_d: " << r_g_n_d_host_result.mDesc << std::endl;
    std::cout << "time kernel: " << time_kernel << std::endl;

    switch(init_method)
    {
    case 0: break;
    case 1:
        q_g_n_d.GenerateTensorValue(GeneratorTensor_2<QDataType>{-5, 5});
        k_g_d_n.GenerateTensorValue(GeneratorTensor_2<KDataType>{-5, 5});
        v_g_n_d.GenerateTensorValue(GeneratorTensor_2<VDataType>{-5, 5});
        break;
    default:
        q_g_n_d.GenerateTensorValue(GeneratorTensor_3<QDataType>{0.0, 1.0});
        k_g_d_n.GenerateTensorValue(GeneratorTensor_3<KDataType>{-0.5, 0.5});
        v_g_n_d.GenerateTensorValue(GeneratorTensor_3<VDataType>{-0.5, 0.5});
        break;
    }

    auto q_element_op = QElementOp{};
    auto k_element_op = KElementOp{};
    auto v_element_op = VElementOp{};
    auto p_element_op = PElementOp{};
    auto r_element_op = RElementOp{};

    DeviceMem q_device_buf(sizeof(QDataType) * q_g_n_d.mDesc.GetElementSpace());
    DeviceMem k_device_buf(sizeof(KDataType) * k_g_d_n.mDesc.GetElementSpace());
    DeviceMem v_device_buf(sizeof(VDataType) * v_g_n_d.mDesc.GetElementSpace());
    DeviceMem r_device_buf(sizeof(RDataType) * r_g_n_d_device_result.mDesc.GetElementSpace());

    q_device_buf.ToDevice(q_g_n_d.mData.data());
    k_device_buf.ToDevice(k_g_d_n.mData.data());
    v_device_buf.ToDevice(v_g_n_d.mData.data());

    // bool pass = true;

    if(do_verification)
    {
        auto ref_batched_gemmQKP = ReferenceGemmInstanceQKP{};
        auto ref_invokerQKP      = ref_batched_gemmQKP.MakeInvoker();
        auto ref_argumentQKP     = ref_batched_gemmQKP.MakeArgument(
            q_g_n_d, k_g_d_n, p_g_n_n, q_element_op, k_element_op, p_element_op);

        auto ref_batched_gemmPVR = ReferenceGemmInstancePVR{};
        auto ref_invokerPVR      = ref_batched_gemmPVR.MakeInvoker();
        auto ref_argumentPVR     = ref_batched_gemmPVR.MakeArgument(
            p_g_n_n, v_g_n_d, r_g_n_d_host_result, p_element_op, v_element_op, r_element_op);

        ref_invokerQKP.Run(ref_argumentQKP);
        ref_invokerPVR.Run(ref_argumentPVR);
    }
}
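As committed, the example only sets up the host tensors and runs the two reference batched GEMMs; no device kernel is launched yet, and the device result buffer is never compared against the host result. For orientation, the computation those two chained GEMMs perform is sketched below as a self-contained function with naive loops and plain float buffers; the function name and the single-batch simplification are illustrative and not part of the commit.

// Minimal standalone sketch of what the example's reference path computes:
// P = Q * K^T (N x N), then R = P * V (N x d). Naive loops, one batch,
// float throughout; no CK types or layouts are used here.
#include <cstddef>
#include <vector>

std::vector<float> gemm_gemm_reference(const std::vector<float>& Q, // N x d, row-major
                                       const std::vector<float>& K, // N x d, row-major
                                       const std::vector<float>& V, // N x d, row-major
                                       std::size_t N,
                                       std::size_t d)
{
    std::vector<float> P(N * N, 0.f); // first GEMM: P = Q * K^T
    for(std::size_t m = 0; m < N; ++m)
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t k = 0; k < d; ++k)
                P[m * N + n] += Q[m * d + k] * K[n * d + k];

    std::vector<float> R(N * d, 0.f); // second GEMM: R = P * V
    for(std::size_t m = 0; m < N; ++m)
        for(std::size_t n = 0; n < d; ++n)
            for(std::size_t k = 0; k < N; ++k)
                R[m * d + n] += P[m * N + k] * V[k * d + n];

    return R;
}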
example/CMakeLists.txt

@@ -45,3 +45,4 @@ add_subdirectory(23_softmax)
 add_subdirectory(24_batched_gemm_c_permute)
 add_subdirectory(25_gemm_bias_c_permute)
 add_subdirectory(26_contraction)
+add_subdirectory(27_gemm_gemm)
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_c_shuffle.hpp (new file, mode 100644)

#pragma once

#include <iostream>
#include <memory>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

struct BatchedGemmGemmCShuffleDesc
{
    ck::index_t G0_, G1_, M_, N_;
    ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceBatchedGemmGemmCShuffle : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        index_t M,
                        index_t N,
                        index_t K,
                        index_t stride_A,
                        index_t stride_B,
                        BatchedGemmGemmCShuffleDesc batched_gemm_gemm_c_shuffle_desc,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op,
                        ck::index_t BatchCount) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceBatchedGemmGemmCShufflePtr =
    std::unique_ptr<DeviceBatchedGemmGemmCShuffle<AElementwiseOperation,
                                                  BElementwiseOperation,
                                                  CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
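A descriptor such as BatchedGemmGemmCShuffleDesc above normally exists so that a (g0, g1, m, n) output coordinate can be turned into a linear element offset through the four strides. The snippet below is a standalone illustration of that mapping only; the mirror struct Desc and the linear_offset helper are invented names for the sketch, not part of this header.

// Sketch: how a stride descriptor maps a 4-D coordinate to a flat offset.
#include <cstdint>

struct Desc
{
    std::int64_t G0_, G1_, M_, N_;                             // extents
    std::int64_t stride_G0_, stride_G1_, stride_M_, stride_N_; // element strides
};

std::int64_t linear_offset(const Desc& d,
                           std::int64_t g0, std::int64_t g1,
                           std::int64_t m, std::int64_t n)
{
    // each coordinate contributes (index * stride) elements
    return g0 * d.stride_G0_ + g1 * d.stride_G1_ + m * d.stride_M_ + n * d.stride_N_;
}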
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_c_shuffle.hpp (new file, mode 100644)

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_c_shuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_skip_lds.hpp"

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename QGridDesc_K0_M_K1,
          typename KGridDesc_K0_N_K1,
          typename VGridDesc_K0_N_K1,
          typename RGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
          typename QElementwiseOperation,
          typename KElementwiseOperation,
          typename VElementwiseOperation,
          typename PElementwiseOperation,
          typename RElementwiseOperation,
          typename ComputePtrOffsetOfBatch,
          typename Block2CTileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_gemm_c_shuffle_xdl(
            const FloatAB* __restrict__ p_q_grid,
            const FloatAB* __restrict__ p_k_grid,
            const FloatAB* __restrict__ p_v_grid,
            FloatC* __restrict__ p_o_grid,
            const index_t batch_count,
            const QGridDesc_K0_M_K1 q_grid_desc_k0_m_k1,
            const KGridDesc_K0_N_K1 k_grid_desc_k0_n_k1,
            const VGridDesc_K0_N_K1 v_grid_desc_k0_n_k1,
            const RGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
                r_grid_desc_mblock_mperblock_nblock_nperblock,
            const QElementwiseOperation q_element_op,
            const KElementwiseOperation k_element_op,
            const VElementwiseOperation v_element_op,
            const PElementwiseOperation p_element_op,
            const RElementwiseOperation r_element_op,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t q_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetQPtrOffset(g_idx)));
    const long_index_t k_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetKPtrOffset(g_idx)));
    const long_index_t v_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetVPtrOffset(g_idx)));
    const long_index_t o_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetRPtrOffset(g_idx)));

    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(
        p_q_grid + q_batch_offset,
        p_k_grid + k_batch_offset,
        p_v_grid + v_batch_offset,
        ck::Tuple<>{},
        p_o_grid + o_batch_offset,
        p_shared,
        q_element_op,
        k_element_op,
        v_element_op,
        p_element_op,
        r_element_op,
        q_grid_desc_k0_m_k1,
        k_grid_desc_k0_n_k1,
        v_grid_desc_k0_n_k1,
        ck::StaticallyIndexedArray<
            typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, 0>{},
        r_grid_desc_mblock_mperblock_nblock_nperblock,
        block_2_ctile_map);
#else
    ignore = p_q_grid;
    ignore = p_v_grid;
    ignore = p_k_grid;
    ignore = p_o_grid;
    ignore = batch_count;
    ignore = q_grid_desc_k0_m_k1;
    ignore = k_grid_desc_k0_n_k1;
    ignore = v_grid_desc_k0_n_k1;
    ignore = r_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = q_element_op;
    ignore = k_element_op;
    ignore = v_element_op;
    ignore = p_element_op;
    ignore = r_element_op;
    ignore = compute_ptr_offset_of_batch;
    ignore = block_2_ctile_map;
#endif
}
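The batch decomposition in the kernel above is plain integer arithmetic: the 1-D grid is split evenly across batches, and a block's batch index is its block id divided by the per-batch block count. The host-side sketch below spells this out with made-up numbers (96 blocks, 12 batches); none of these values come from the commit.

// Worked example of the per-batch block mapping used by the kernel above.
#include <cassert>

int main()
{
    const int grid_size   = 96; // illustrative total number of blocks
    const int batch_count = 12; // illustrative number of batches
    const int block_id    = 19; // some block in the grid

    const int num_blocks_per_batch = grid_size / batch_count;         // 8 blocks per batch
    const int g_idx                = block_id / num_blocks_per_batch; // block 19 -> batch 2

    assert(num_blocks_per_batch == 8 && g_idx == 2);
    return 0;
}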
template <typename QLayout, typename KLayout, typename VLayout, typename RLayout,
          typename KDataType, typename QDataType, typename VDataType, typename ODataType,
          typename AccDataType,
          typename KElementwiseOperation,
          typename QElementwiseOperation,
          typename VElementwiseOperation,
          typename PElementwiseOperation,
          typename OElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumPrefetch,
          ck::index_t BlockSize,
          ck::index_t MPerBlock, ck::index_t NPerBlock, ck::index_t KPerBlock,
          ck::index_t QK1, ck::index_t KK1, ck::index_t VK1,
          ck::index_t QKMPerXDL, ck::index_t QKNPerXDL,
          ck::index_t QKMXdlPerWave, ck::index_t QKNXdlPerWave,
          ck::index_t PVMPerXDL, ck::index_t PVNPerXDL,
          ck::index_t PVMXdlPerWave, ck::index_t PVNXdlPerWave,
          typename QBlockTransferThreadClusterLengths_K0_M_K1,
          typename QBlockTransferThreadClusterArrangeOrder,
          typename QBlockTransferSrcAccessOrder,
          ck::index_t QBlockTransferSrcVectorDim,
          ck::index_t QBlockTransferSrcScalarPerVector,
          ck::index_t QBlockTransferDstScalarPerVector_K1,
          ck::index_t QBlockLdsAddExtraM,
          typename KBlockTransferThreadClusterLengths_K0_N_K1,
          typename KBlockTransferThreadClusterArrangeOrder,
          typename KBlockTransferSrcAccessOrder,
          ck::index_t KBlockTransferSrcVectorDim,
          ck::index_t KBlockTransferSrcScalarPerVector,
          ck::index_t KBlockTransferDstScalarPerVector_K1,
          ck::index_t KBlockLdsAddExtraN,
          typename VBlockTransferThreadClusterLengths_K0_N_K1,
          typename VBlockTransferThreadClusterArrangeOrder,
          typename VBlockTransferSrcAccessOrder,
          ck::index_t VBlockTransferSrcVectorDim,
          ck::index_t VBlockTransferSrcScalarPerVector,
          ck::index_t VBlockTransferDstScalarPerVector_K1,
          ck::index_t VBlockLdsAddExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedGemmGemmCShuffleXdl
    : public DeviceBatchedGemmGemmCShuffle<AElementwiseOperation,
                                           BElementwiseOperation,
                                           CElementwiseOperation>
{
    using DeviceOp = DeviceBatchedGemmGemmCShuffleXdl;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static auto MakeQGridDescriptor_QK0_M_QK1(index_t M, index_t K, index_t StrideQ)
    {
        // not pad M or K
        assert(K % QK1 == 0);

        const auto QK0 = K / QK1;

        const auto q_grid_desc_k0_m_k1 = [&]() {
            return make_naive_tensor_descriptor(make_tuple(QK0, M, QK1),
                                                make_tuple(M * QK1, QK1, I1));
        }();

        return q_grid_desc_k0_m_k1;
    }

    static auto MakeKGridDescriptor_KK0_N_KK1(index_t N, index_t K, index_t StrideK)
    {
        // not pad N or K
        assert(K % KK1 == 0);

        const auto KK0 = K / KK1;

        const auto k_grid_desc_kk0_n_kk1 = make_naive_tensor_descriptor(
            make_tuple(KK0, N, KK1), make_tuple(KK1 * N, KK1, I1));

        return k_grid_desc_kk0_n_kk1;
    }

    static auto MakeVGridDescriptor_VK0_N_VK1(index_t N, index_t K, index_t StrideV)
    {
        // not pad N or K
        assert(K % VK1 == 0);

        const auto VK0 = K / VK1;

        const auto v_grid_desc_vk0_n_vk1 = make_naive_tensor_descriptor(
            make_tuple(VK0, N, VK1), make_tuple(VK1 * N, VK1, I1));

        return v_grid_desc_vk0_n_vk1;
    }

    static auto
    MakeOGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
    {
        const auto o_grid_desc_mraw_nraw = [&]() {
            return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                make_tuple(stride_M, stride_N));
        }();

        return o_grid_desc_mraw_nraw;
    }

    using QGridDesc_K0_M_K1 = decltype(MakeQGridDescriptor_QK0_M_QK1(1, 1, 1));
    using KGridDesc_K0_N_K1 = decltype(MakeKGridDescriptor_KK0_N_KK1(1, 1, 1));
    using VGridDesc_K0_N_K1 = decltype(MakeVGridDescriptor_VK0_N_VK1(1, 1, 1));
    using OGridDesc_M_N     = decltype(MakeOGridDescriptor_M_N(1, 1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmGemmXdlopsSkipLdsV1<
        BlockSize,
        KDataType, // TODO: distinguish A/B datatype
        AccDataType,
        ODataType,
        InMemoryDataOperationEnum::Set,
        QGridDesc_K0_M_K1,
        KGridDesc_K0_N_K1,
        VGridDesc_K0_N_K1,
        OGridDesc_M_N,
        QElementwiseOperation,
        KElementwiseOperation,
        VElementwiseOperation,
        PElementwiseOperation,
        OElementwiseOperation,
        QKMPerBlock, QKNPerBlock, QKMPerXDL, QKNPerXDL,
        PVMPerBlock, PVNPerBlock, PVMPerXDL, PVNPerXDL,
        KPerBlock,
        QK1, KK1, VK1,
        QKMXdlPerWave, QKNXdlPerWave, PVMXdlPerWave, PVNXdlPerWave,
        KBlockTransferThreadClusterLengths_K0_N_K1,
        KBlockTransferThreadClusterArrangeOrder,
        KBlockTransferSrcAccessOrder,
        KBlockTransferSrcVectorDim,
        KBlockTransferSrcScalarPerVector,
        KBlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        KBlockLdsAddExtraM,
        VBlockTransferThreadClusterLengths_K0_N_K1,
        VBlockTransferThreadClusterArrangeOrder,
        VBlockTransferSrcAccessOrder,
        VBlockTransferSrcVectorDim,
        VBlockTransferSrcScalarPerVector,
        VBlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        VBlockLdsAddExtraM,
        QBlockTransferSrcScalarPerVector,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector>;
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const QDataType* p_q_grid,
                 const KDataType* p_k_grid,
                 const VDataType* p_v_grid,
                 ODataType* p_o_grid,
                 index_t QKM,
                 index_t QKN,
                 index_t PVM,
                 index_t PVN,
                 index_t K,
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideC,
                 index_t M01,
                 index_t N01,
                 QElementwiseOperation q_element_op,
                 KElementwiseOperation k_element_op,
                 VElementwiseOperation v_element_op,
                 PElementwiseOperation p_element_op,
                 OElementwiseOperation o_element_op)
            : p_q_grid_{p_q_grid},
              p_k_grid_{p_k_grid},
              p_v_grid_{p_v_grid},
              p_o_grid_{p_o_grid},
              q_grid_desc_k0_m_k1_{},
              k_grid_desc_k0_n_k1_{},
              o_grid_desc_m_n_{},
              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01},
              q_element_op_{q_element_op},
              k_element_op_{k_element_op},
              p_element_op_{p_element_op},
              v_element_op_{v_element_op},
              o_element_op_{o_element_op}
        {
            q_grid_desc_k0_m_k1_ = DeviceOp::MakeQGridDescriptor_QK0_M_QK1(QKM, K, StrideA);
            k_grid_desc_k0_n_k1_ = DeviceOp::MakeKGridDescriptor_KK0_N_KK1(QKN, K, StrideB);
            v_grid_desc_k0_n_k1_ = DeviceOp::MakeVGridDescriptor_VK0_N_VK1(PVN, K, StrideB);
            o_grid_desc_m_n_     = DeviceOp::MakeOGridDescriptor_M_N(PVM, PVN, StrideC);

            if(GridwiseGemm::CheckValidity(
                   q_grid_desc_k0_m_k1_, k_grid_desc_k0_n_k1_, o_grid_desc_m_n_, M01_, N01_))
            {
                c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
                    GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(o_grid_desc_m_n_);

                block_2_ctile_map_ =
                    GridwiseGemm::MakeDefaultBlock2CTileMap(o_grid_desc_m_n_, M01, N01);

                b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_ =
                    GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(
                        k_grid_desc_k0_n_k1_);
            }
        }

        // private:
        const QDataType* p_q_grid_;
        const KDataType* p_k_grid_;
        const VDataType* p_v_grid_;
        ODataType* p_o_grid_;

        QGridDesc_K0_M_K1 q_grid_desc_k0_m_k1_;
        KGridDesc_K0_N_K1 k_grid_desc_k0_n_k1_;
        VGridDesc_K0_N_K1 v_grid_desc_k0_n_k1_;
        OGridDesc_M_N o_grid_desc_m_n_;

        typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
            b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_;
        typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
            c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;

        index_t M01_;
        index_t N01_;

        QElementwiseOperation q_element_op_;
        KElementwiseOperation k_element_op_;
        PElementwiseOperation p_element_op_;
        VElementwiseOperation v_element_op_;
        OElementwiseOperation o_element_op_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            {
                std::cout << "arg.q_grid_desc_k0_m_k1_{" << arg.q_grid_desc_k0_m_k1_.GetLength(I0)
                          << ", " << arg.q_grid_desc_k0_m_k1_.GetLength(I1) << ", "
                          << arg.q_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.k_grid_desc_k0_n_k1_{" << arg.k_grid_desc_k0_n_k1_.GetLength(I0)
                          << ", " << arg.k_grid_desc_k0_n_k1_.GetLength(I1) << ", "
                          << arg.k_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.o_grid_desc_m_n_{ " << arg.o_grid_desc_m_n_.GetLength(I0) << ", "
                          << arg.o_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
            }

            if(!GridwiseGemm::CheckValidity(arg.q_grid_desc_k0_m_k1_,
                                            arg.k_grid_desc_k0_n_k1_,
                                            arg.o_grid_desc_m_n_,
                                            arg.M01_,
                                            arg.N01_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
            }

            const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.o_grid_desc_m_n_);

            const auto K0 = arg.q_grid_desc_k0_m_k1_.GetLength(I0);

            const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

            float ave_time = 0;

            if(has_main_k0_block_loop)
            {
                const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
                    GridwiseGemm,
                    ADataType, // TODO: distiguish A/B datatype
                    CDataType,
                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
                    remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
                    remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                    true>;

                ave_time = launch_and_time_kernel(stream_config,
                                                  kernel,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
                                                  arg.p_a_grid_,
                                                  arg.p_b_grid_,
                                                  arg.p_c_grid_,
                                                  arg.q_grid_desc_k0_m_k1_,
                                                  arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_,
                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
                                                  arg.a_element_op_,
                                                  arg.b_element_op_,
                                                  arg.c_element_op_,
                                                  arg.block_2_ctile_map_);
            }
            else
            {
                const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1<
                    GridwiseGemm,
                    ADataType, // TODO: distiguish A/B datatype
                    CDataType,
                    remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                    remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                    remove_reference_t<typename GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3>,
                    remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
                    remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                    false>;

                ave_time = launch_and_time_kernel(stream_config,
                                                  kernel,
                                                  dim3(grid_size),
                                                  dim3(BlockSize),
                                                  0,
                                                  arg.p_a_grid_,
                                                  arg.p_b_grid_,
                                                  arg.p_c_grid_,
                                                  arg.q_grid_desc_k0_m_k1_,
                                                  arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_,
                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
                                                  arg.a_element_op_,
                                                  arg.b_element_op_,
                                                  arg.c_element_op_,
                                                  arg.block_2_ctile_map_);
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        return GridwiseGemm::CheckValidity(arg.q_grid_desc_k0_m_k1_,
                                           arg.k_grid_desc_k0_n_k1_,
                                           arg.o_grid_desc_m_n_,
                                           arg.M01_,
                                           arg.N01_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        1,
                        1,
                        a_element_op,
                        b_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op,
                                                      index_t /* KBatch */ = 1) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          1,
                                          1,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceOp"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave
            << ">";
        // clang-format on

        return str.str();
    }
};
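Like other composable_kernel device operations, the struct above is meant to be driven through the MakeArgument / MakeInvoker / Invoker::Run sequence, with Run returning the averaged kernel time. Since this header is still a work in progress, the standalone mock below only demonstrates that calling convention; MockDeviceOp and its trivial Run body are invented for illustration and are not CK API.

// Sketch of the CK-style call sequence a user of the device op would write.
#include <iostream>

struct MockDeviceOp
{
    struct Argument { int M, N, K; };
    struct Invoker
    {
        float Run(const Argument& arg)
        {
            std::cout << "running " << arg.M << "x" << arg.N << "x" << arg.K << "\n";
            return 0.f; // a real invoker returns the measured kernel time
        }
    };
    static Argument MakeArgument(int M, int N, int K) { return {M, N, K}; }
    static Invoker MakeInvoker() { return {}; }
};

int main()
{
    auto arg     = MockDeviceOp::MakeArgument(1024, 1024, 64);
    auto invoker = MockDeviceOp::MakeInvoker();
    float ave_ms = invoker.Run(arg);
    (void)ave_ms;
    return 0;
}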
include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_skip_lds.hpp (new file, mode 100644)

#pragma once

#include "common_header.hpp"
#include "multi_index_transform_helper.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_xdlops.hpp"
#include "blockwise_gemm_xdlops_skip_b_lds.hpp"
#include "thread_group_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "gridwise_gemm_pipeline_v1.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename AGridDesc_K0_M_K1,
          typename B0GridDesc_K0_N_K1,
          typename B1GridDesc_K0_N_K1,
          typename CGridDesc_M_N,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename C0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename C1ElementwiseOperation,
          index_t M0PerBlock, index_t N0PerBlock, index_t M0PerXDL, index_t N0PerXDL,
          index_t M1PerBlock, index_t N1PerBlock, index_t M1PerXDL, index_t N1PerXDL,
          index_t KPerBlock,
          index_t AK1, index_t B0K1, index_t B1K1,
          index_t M0XdlPerWave, index_t N0XdlPerWave,
          index_t M1XdlPerWave, index_t N1XdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename B1BlockTransferThreadClusterLengths_K0_M_K1,
          typename B1BlockTransferThreadClusterArrangeOrder,
          typename B1BlockTransferSrcAccessOrder,
          index_t B1BlockTransferSrcVectorDim,
          index_t B1BlockTransferSrcScalarPerVector,
          index_t B1BlockTransferDstScalarPerVector_K1,
          bool B1ThreadTransferSrcResetCoordinateAfterRun,
          bool B1BlockLdsExtraM,
          index_t B0BlockTransferSrcScalarPerVector,
          bool B0ThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemmGemmXdlopsSkipLdsV1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    static constexpr auto BaseMultK0 = 2;
    static constexpr auto MultiK0    = BaseMultK0 * 1;

    // K1 should be Number<...>
    static constexpr auto K1 = Number<AK1>{};

    static constexpr index_t WaveSize = 64;
    static constexpr index_t MWaves   = MPerBlock / (MXdlPerWave * MPerXDL);
    static constexpr index_t NWaves   = NPerBlock / (NXdlPerWave * NPerXDL);

    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, K1>{};

    static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock * MultiK0>{}, Number<MPerBlock>{}, K1),
                    max_lds_align);
            }
        }();

        return a_block_desc_k0_m_k1;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

        constexpr auto max_lds_align = K1;

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned) * sizeof(FloatAB);
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    __host__ __device__ static constexpr bool
    CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
                  const CGridDesc_M_N& c_grid_desc_m_n,
                  index_t M01,
                  index_t N01)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
                      "wrong! K1 need to be known at compile-time");

        static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
                      "Invalid tuning param!");

        const auto M  = a_grid_desc_k0_m_k1.GetLength(I1);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);
        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
             K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
            return false;

        if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
            return false;

        // 2-stage prefetch currently only support even number of K0 loop
        // TODO: add support for odd number of K0 loop
        if(!((K0 / K0PerBlock) % MultiK0 == 0))
        {
            return false;
        }

        // check M01, N01
        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        if(!(M0 % M01 == 0 && N0 % N01 == 0))
            return false;

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
        return true;
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);

        return grid_size;
    }

    // TODO move this function into GEMM-pipeline class
    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = (K0 / (MultiK0 * K0PerBlock)) > 1;

        return has_main_k0_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1)
    {
        const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0);
        const auto N  = b_grid_desc_k0_n_k1.GetLength(I1);

        const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 =
            transform_tensor_descriptor(
                b_grid_desc_k0_n_k1,
                make_tuple(make_unmerge_transform(
                               make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)),
                           make_unmerge_transform(make_tuple(
                               N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)),
                           make_pass_through_transform(K1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{}));
        return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1;
    }
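CheckValidity and CalculateHasMainK0BlockLoop above both hinge on how many K0PerBlock tiles fit into K0 and whether those tiles group evenly into MultiK0-sized prefetch stages. The sketch below evaluates the two conditions for illustrative numbers (K0 = 32, K0PerBlock = 4, MultiK0 = 2); the values are examples only and do not come from the commit.

// Standalone evaluation of the K0-loop conditions used by the gridwise GEMM.
#include <cstdio>

int main()
{
    constexpr int K0PerBlock = 4;  // illustrative
    constexpr int MultiK0    = 2;  // 2-stage prefetch
    constexpr int K0         = 32; // illustrative

    // CheckValidity: the number of K0PerBlock tiles must split evenly into prefetch stages.
    constexpr bool valid = ((K0 / K0PerBlock) % MultiK0 == 0); // 8 % 2 == 0 -> true

    // CalculateHasMainK0BlockLoop: more than one prefetch group means the main loop runs.
    constexpr bool has_main_loop = (K0 / (MultiK0 * K0PerBlock)) > 1; // 4 > 1 -> true

    std::printf("valid=%d has_main_k0_block_loop=%d\n", valid, has_main_loop);
    return 0;
}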
    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = get_thread_local_1d_id();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    __device__ static auto GetWaveKNIdx(const index_t thread_id)
    {
        constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXDL))),
            make_tuple(Sequence<0, 1>{}),
            make_tuple(Sequence<0>{}));

        return wave_threadid_to_nk_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        constexpr auto max_lds_align = K1;

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
            }
        }();

        // B matrix threadwise copy
        constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 = make_naive_tensor_descriptor_packed(
            make_tuple(I1,
                       I1,
                       Number<K0PerThread>{}, // K0PerThread
                       I1,                    // NBlockId
                       Number<NXdlPerWave>{}, // repeat
                       I1,                    // waves
                       I1,                    // NPerXdlops
                       Number<K1>{}));

        using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1<
            BlockSize,
            FloatAB,
            FloatAcc,
            decltype(a_block_desc_k0_m_k1),
            decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
            MPerBlock, NPerBlock, K0PerBlock,
            MPerXDL, NPerXDL,
            MXdlPerWave, NXdlPerWave,
            K1>;

        return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    __host__ __device__ static constexpr auto
    MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        constexpr auto M1 = Number<MPerBlock>{};
        constexpr auto N1 = Number<NPerBlock>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto M00 = M0 / M01;
        const auto N00 = N0 / N01;

        const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(M00, M01)),
                           make_unmerge_transform(make_tuple(N00, N01))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}));

        const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

        const auto cblockid_to_m0_n0_block_cluster_adaptor =
            chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor,
                                  cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);

        return cblockid_to_m0_n0_block_cluster_adaptor;
    }

    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
        decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
    using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
    using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
        decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
    template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
                               const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
                                   b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                               const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2&
                                   c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
                               const CElementwiseOperation& c_element_op,
                               const Block2CTileMap& block_2_ctile_map)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());

        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);

        // divide block work by [M, N]
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();

        // A matrix blockwise copy
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            AElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<K0PerBlock * MultiK0, MPerBlock, K1>,
            ABlockTransferThreadClusterLengths_K0_M_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            decltype(a_grid_desc_k0_m_k1),
            decltype(a_block_desc_k0_m_k1),
            ABlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            ABlockTransferSrcVectorDim,
            2,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_K1,
            1,
            1,
            AThreadTransferSrcResetCoordinateAfterRun,
            true,
            1>(a_grid_desc_k0_m_k1,
               make_multi_index(0, m_block_data_idx_on_grid, 0),
               a_element_op,
               a_block_desc_k0_m_k1,
               make_multi_index(0, 0, 0),
               ck::tensor_operation::element_wise::PassThrough{});

        ignore = b_element_op;

        // B matrix threadwise copy
        constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 = make_naive_tensor_descriptor_packed(
            make_tuple(I1,
                       I1,
                       Number<K0PerThread>{}, // K0PerThread
                       I1,                    // NBlockId
                       Number<NXdlPerWave>{}, // repeat
                       I1,                    // waves
                       I1,                    // NPerXdlops
                       Number<K1>{}));

        StaticBuffer<AddressSpaceEnum::Vgpr,
                     FloatAB,
                     b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize(),
                     true>
            b_thread_1st_buf, b_thread_2nd_buf, b_thread_3rd_buf, b_thread_4th_buf;

        const auto wave_id     = GetWaveIdx();
        const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]);

#if 0
        const index_t block_id  = get_block_1d_id();
        const index_t thread_id = get_thread_local_1d_id();
        printf("block id: %d m blockid: %d n block id: %d ,thread id: %d, wave id :{%d %d %d} "
               "kn id: {%d %d}\n",
               block_id,
               block_work_idx[I0],
               block_work_idx[I1],
               thread_id,
               wave_id[I0],
               wave_id[I1],
               wave_id[I2],
               wave_k_n_id[I0],
               wave_k_n_id[I1]);
        printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t",
               xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop,
               b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0));
#endif

        auto b_threadwise_copy = ThreadwiseTensorSliceTransfer_v2<
            FloatAB,
            FloatAB,
            decltype(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3),
            decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
            Sequence<I1, I1, Number<K0PerThread>{}, I1, Number<NXdlPerWave>{}, I1, I1, Number<K1>{}>,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
            7,
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                  make_multi_index(0,
                                   wave_k_n_id[I0],
                                   0,
                                   block_work_idx[I1],
                                   0,
                                   wave_id[I1],
                                   wave_k_n_id[I1],
                                   0));

        // GEMM definition
        // c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[K0PerBlock, MPerBlock] is in LDS
        //   b_mtx[K0PerBlock, NPerBlock] is in LDS
        //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register

        // sanity check
        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1<
            BlockSize,
            FloatAB,
            FloatAcc,
            decltype(a_block_desc_k0_m_k1),
            decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
            MPerBlock, NPerBlock, K0PerBlock,
            MPerXDL, NPerXDL,
            MXdlPerWave, NXdlPerWave,
            K1>{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

        // gridwise GEMM pipeline
        constexpr auto a_block_slice_copy_step  = make_multi_index(K0PerBlock * MultiK0, 0, 0);
        constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);

        // preload data to register and LDS
        {
            // Read
            a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
            b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                  b_grid_buf,
                                  b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  b_thread_1st_buf);

            // Move
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step);
            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                 b_thread_slice_copy_step);

            // Initialize C
            c_thread_buf.Clear();

            // a data write to lds
            a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
            // load 2nd b matrix data
            b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                  b_grid_buf,
                                  b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                  make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                  b_thread_2nd_buf);
            b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                 b_thread_slice_copy_step);

            // main body
            if constexpr(HasMainK0BlockLoop)
            {
                index_t K0BlockMainLoop =
                    __builtin_amdgcn_readfirstlane(K0 / (MultiK0 * K0PerBlock));
                index_t i = 0;
                do
                {
                    a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf);
                    blockwise_gemm.ResetABlockStartWindow();
                    block_sync_lds();

                    static_for<0, MultiK0, BaseMultK0>{}([&](auto) {
                        // 1st
                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              b_grid_buf,
                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                              b_thread_3rd_buf);
                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                             b_thread_slice_copy_step);

                        blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
                        blockwise_gemm.MoveABlockSliceWindow();
                        s_nop();

                        // 2nd
                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              b_grid_buf,
                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                              b_thread_4th_buf);
                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                             b_thread_slice_copy_step);

                        blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
                        blockwise_gemm.MoveABlockSliceWindow();
                        s_nop();

                        // 3rd
                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              b_grid_buf,
                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                              b_thread_1st_buf);
                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                             b_thread_slice_copy_step);

                        blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
                        blockwise_gemm.MoveABlockSliceWindow();
                        s_nop();

                        // 4th
                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              b_grid_buf,
                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                              b_thread_2nd_buf);
                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                             b_thread_slice_copy_step);

                        blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
                        blockwise_gemm.MoveABlockSliceWindow();
                    });

                    block_sync_lds();
                    a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);

                    // move a and b window
                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
                                                        a_block_slice_copy_step);

                    i += 1;
                } while(i < (K0BlockMainLoop - 1));
            }

            // tail
            {
                block_sync_lds();
                blockwise_gemm.ResetABlockStartWindow();

                static_for<0, MultiK0, BaseMultK0>{}([&](auto i) {
                    // 1st
                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                          b_grid_buf,
                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                          b_thread_3rd_buf);
                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                         b_thread_slice_copy_step);

                    blockwise_gemm.Run(a_block_buf, b_thread_1st_buf, c_thread_buf);
                    blockwise_gemm.MoveABlockSliceWindow();

                    // 2nd
                    b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                          b_grid_buf,
                                          b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                          make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                          b_thread_4th_buf);
                    b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                         b_thread_slice_copy_step);

                    blockwise_gemm.Run(a_block_buf, b_thread_2nd_buf, c_thread_buf);
                    blockwise_gemm.MoveABlockSliceWindow();

                    // 3rd
                    if constexpr(i < MultiK0 - BaseMultK0)
                    {
                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              b_grid_buf,
                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                              b_thread_1st_buf);
                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                             b_thread_slice_copy_step);
                    }
                    blockwise_gemm.Run(a_block_buf, b_thread_3rd_buf, c_thread_buf);
                    blockwise_gemm.MoveABlockSliceWindow();

                    // 4th
                    if constexpr(i < MultiK0 - BaseMultK0)
                    {
                        b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              b_grid_buf,
                                              b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                                              b_thread_2nd_buf);
                        b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                                             b_thread_slice_copy_step);
                    }
                    blockwise_gemm.Run(a_block_buf, b_thread_4th_buf, c_thread_buf);
                    blockwise_gemm.MoveABlockSliceWindow();
                });
            }
        }

        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_grid =
                m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_grid_idx =
                m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_grid));

            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_grid_idx =
                n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_grid));

            auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                CElementwiseOperation,
                Sequence<M0, N0, I1, I1, M2, I1, M4, I1>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                      make_multi_index(m_thread_data_on_grid_idx[I0],
                                       n_thread_data_on_grid_idx[I0],
                                       m_thread_data_on_grid_idx[I1],
                                       n_thread_data_on_grid_idx[I1],
                                       m_thread_data_on_grid_idx[I2],
                                       m_thread_data_on_grid_idx[I3],
                                       m_thread_data_on_grid_idx[I4],
                                       n_thread_data_on_grid_idx[I2]),
                      c_element_op};

            c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                              make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
                              c_thread_buf,
                              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                              c_grid_buf);
        }
    }
};

} // namespace ck