gaoqiong / composable_kernel · Commits · d78b9359

Commit d78b9359 authored Jun 21, 2023 by Jing Zhang

simple blockwise gemm

parent 956465c6
Showing 6 changed files with 522 additions and 386 deletions
example/01_gemm/gemm_dl_fp16.cpp                                     +5   -2
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp    +89  -94
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp       +40  -162
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp       +349 -21
include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp  +33  -107
include/ck/utility/inner_product.hpp                                 +6   -0
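This commit strips the convolution-specific indexing (E1 / K / Ho / Wo) out of the DL blockwise GEMM and re-expresses it as a plain GEMM: A is viewed as a [K0, M, K1] tensor, B as [K0, N, K1], and each thread accumulates an MPerThread x NPerThread tile of C. As an orientation aid, the arithmetic the new code paths implement is sketched below (host-side; the array types and the function name are illustrative, not part of the diff):

    // K is split as K = K0 * K1; K1 is the innermost, contiguous fragment,
    // which is what makes the per-thread loads vectorizable.
    template <int K0, int M, int N, int K1>
    void reference_gemm_k0mk1_k0nk1(const float (&a)[K0][M][K1],
                                    const float (&b)[K0][N][K1],
                                    float (&c)[M][N])
    {
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k0 = 0; k0 < K0; ++k0)
                    for(int k1 = 0; k1 < K1; ++k1)
                        c[m][n] += a[k0][m][k1] * b[k0][n][k1];
    }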
example/01_gemm/gemm_dl_fp16.cpp
...
...
@@ -27,8 +27,11 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
-< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 256, 128, 128, 16, 2, 8, 8, 1, S<1, 16>, S<1, 16>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<2, 1, 4, 2>, S<8, 1, 32, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 4, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 64, 64, 16, 2, 8, 8, 1, S<1, 8>, S<1, 8>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 64, 64, 16, 2, 4, 4, 1, S<1, 8>, S<1, 8>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 16, 64, 16, 2, 1, 4, 1, S<2, 4>, S<2, 4>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>;
+< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 16, 64, 16, 2, 2, 8, 1, S<1, 8>, S<1, 8>, S<4, 1, 1, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 4, 2>, S<4, 1, 16, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
...
...
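Reading the newly enabled instance against the column headers above: BlockSize = 64, the block tile is MPerBlock x NPerBlock x K0PerBlock = 16 x 64 x 16 with K1 = 2, and each thread computes a 2 x 8 micro-tile (M1PerThreadM111 = 2, N1PerThreadN111 = 8). That is exactly what the new BlockSize == M0 * N0 assertion in blockwise_gemm_dlops_v3.hpp demands: M0 = 16 / 2 = 8 and N0 = 64 / 8 = 8, so the thread cluster has 8 * 8 = 64 threads. A compile-time restatement of that check (sketch; constants copied from the instance above):

    constexpr int BlockSize = 64, MPerBlock = 16, NPerBlock = 64;
    constexpr int MPerThread = 2, NPerThread = 8;
    static_assert(BlockSize == (MPerBlock / MPerThread) * (NPerBlock / NPerThread),
                  "thread cluster must exactly tile the block");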
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp
...
...
@@ -4,8 +4,8 @@
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP

-#include "common_header.hpp"
-#include "threadwise_gemm_dlops_v3.hpp"
+#include "ck/utility/common_header.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp"

namespace ck {
...
...
@@ -13,11 +13,11 @@ template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
-          typename ABlockDesc_E1_K1_E2,
-          typename BBlockDesc_E1_N_Ho_Wo_E2,
-          typename CThreadDesc_K_N_Ho_Wo,
-          index_t EPerThreadLoop,
-          index_t KPerThreadLoop>
+          typename ABlockDesc_K0_M_K1,
+          typename BBlockDesc_K0_N_K1,
+          index_t MPerThread,
+          index_t NPerThread,
+          index_t K0PerLoop>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
    static constexpr auto I0 = Number<0>{};
...
...
@@ -26,105 +26,91 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};

    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;
    using CIndex = MultiIndex<4>;

-    static constexpr auto E1        = ABlockDesc_E1_K1_E2{}.GetLength(I0);
-    static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1);
-    static constexpr auto E2        = ABlockDesc_E1_K1_E2{}.GetLength(I2);
+    static constexpr auto K0 = ABlockDesc_K0_M_K1{}.GetLength(I0);
+    static constexpr auto M  = ABlockDesc_K0_M_K1{}.GetLength(I1);
+    static constexpr auto K1 = ABlockDesc_K0_M_K1{}.GetLength(I2);

-    static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
-    static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
+    static constexpr auto N = BBlockDesc_K0_N_K1{}.GetLength(I1);

-    static constexpr auto KPerThread  = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0);
-    static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2);
-    static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3);
+    static constexpr auto M0 = M / MPerThread;
+    static constexpr auto M1 = MPerThread;
+    static constexpr auto N0 = N / NPerThread;
+    static constexpr auto N1 = NPerThread;

    static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadLoop>{}, Number<E2>{}));
+        make_tuple(Number<K0PerLoop>{}, Number<MPerThread>{}, Number<K1>{}));

-    static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(
-        make_tuple(Number<EPerThreadLoop>{},
-                   Number<1>{},
-                   Number<HoPerThread>{},
-                   Number<WoPerThread>{},
-                   Number<E2>{}));
+    static constexpr auto b_thread_mtx_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<K0PerLoop>{}, Number<NPerThread>{}, Number<K1>{}));

-    static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple(
-        Number<KPerThreadLoop>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));
+    static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(
+        make_tuple(Number<I1>{}, Number<M1>{}, Number<I1>{}, Number<N1>{}));

    __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
-        : c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())},
-          a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)}
+        : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
+              get_thread_local_1d_id())},
+          a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I1] * MPerThread, 0)},
+          b_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I3] * NPerThread, 0)}
    {
-        static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() &&
-                          BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
-                          CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
+        static_assert(ABlockDesc_K0_M_K1::IsKnownAtCompileTime() &&
+                          BBlockDesc_K0_N_K1::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

-        static_assert(ABlockDesc_E1_K1_E2{}.GetLength(I0) ==
-                              BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) &&
-                          ABlockDesc_E1_K1_E2{}.GetLength(I2) ==
-                              BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4),
+        static_assert(ABlockDesc_K0_M_K1{}.GetLength(I0) == BBlockDesc_K0_N_K1{}.GetLength(I0) &&
+                          ABlockDesc_K0_M_K1{}.GetLength(I2) == BBlockDesc_K0_N_K1{}.GetLength(I2),
                      "wrong! E dimension not consistent\n");

-        static_assert(E1 % EPerThreadLoop == 0, "");
-        static_assert(KPerThread % KPerThreadLoop == 0, "");
+        static_assert(K0 % K0PerLoop == 0, "");

-        static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 &&
-                          WoPerBlock % WoPerThread == 0,
+        static_assert(M % MPerThread == 0 && N % NPerThread == 0,
                      "wrong! Cannot evenly divide work among\n");

-        constexpr auto KThreadCluster = KPerBlock / KPerThread;
-        constexpr auto HThreadCluster = HoPerBlock / HoPerThread;
-        constexpr auto WThreadCluster = WoPerBlock / WoPerThread;

-        static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
-                      "wrong! wrong blocksize\n");
+        static_assert(BlockSize == M0 * N0, "wrong! wrong blocksize\n");
    }

-    __device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths()
+    __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1()
    {
-        return Sequence<KPerThread, I1, HoPerThread, WoPerThread>{};
+        return Sequence<I1, M1, I1, N1>{};
    }

-    __device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id)
+    __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id)
    {
-        constexpr auto K0 = KPerBlock / KPerThread;
-        constexpr auto N0 = I1;
-        constexpr auto H0 = HoPerBlock / HoPerThread;
-        constexpr auto W0 = WoPerBlock / WoPerThread;

-        constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor =
+        constexpr auto c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor =
            make_single_stage_tensor_adaptor(
-                make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))),
+                make_tuple(make_merge_transform(make_tuple(I1, M0, I1, N0))),
                make_tuple(Sequence<0, 1, 2, 3>{}),
                make_tuple(Sequence<0>{}));

-        const auto c_k_n_h_w_thread_cluster_idx =
-            c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex(
+        const auto c_m0_m1_n0_n1_thread_cluster_idx =
+            c_threadid_to_m0_m1_n0_n1_thread_cluster_adaptor.CalculateBottomIndex(
                make_multi_index(thread_id));

-        return c_k_n_h_w_thread_cluster_idx;
+        return c_m0_m1_n0_n1_thread_cluster_idx;
    }
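Since the cluster adaptor above merges lengths (I1, M0, I1, N0) into the flat thread id, inverting it is just a divmod: component [I1] is the thread's row slot and [I3] its column slot. A plain-integer sketch of the same mapping (hypothetical helper, not part of the diff):

    // Merging (1, M0, 1, N0) gives components (0, bm0, 0, bn0)
    // with thread_id == bm0 * N0 + bn0.
    struct COrigin { int b0, bm0, b1, bn0; };

    constexpr COrigin c_thread_origin(int thread_id, int N0)
    {
        return {0, thread_id / N0, 0, thread_id % N0};
    }

The constructor then anchors the per-thread copies at row bm0 * MPerThread (component [I1]) and column bn0 * NPerThread (component [I3]).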
-    template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
-                        const BThreadBuffer& b_thread_buf,
+                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        static_assert(
            is_same<remove_cvref_t<typename ABlockBuffer::type>, remove_cvref_t<FloatA>>::value &&
-            is_same<remove_cvref_t<typename BThreadBuffer::type>, remove_cvref_t<FloatB>>::value &&
+            is_same<remove_cvref_t<typename BBlockBuffer::type>, remove_cvref_t<FloatB>>::value &&
            is_same<remove_cvref_t<typename CThreadBuffer::type>, remove_cvref_t<FloatC>>::value &&
            "wrong! inconsistent type");

-        constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{};
+        constexpr auto a_block_mtx = ABlockDesc_K0_M_K1{};
+        constexpr auto b_block_mtx = BBlockDesc_K0_N_K1{};

        // thread A buffer for GEMM
        StaticBuffer<AddressSpaceEnum::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize(), true>
            a_thread_buf;
+        StaticBuffer<AddressSpaceEnum::Vgpr, FloatB, b_thread_mtx_.GetElementSpaceSize(), true>
+            b_thread_buf;

        constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
                                                                         FloatB,
                                                                         FloatC,
...
...
@@ -132,46 +118,55 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
                                                                         decltype(b_thread_mtx_),
                                                                         decltype(c_thread_mtx_)>{};

-        static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) {
-            static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) {
+        static_for<0, K0, K0PerLoop>{}([&](auto k0_begin) {
            a_thread_copy_.Run(a_block_mtx,
-                               make_tuple(e_begin, k_begin, I0),
+                               make_tuple(k0_begin, I0, I0),
                               a_block_buf,
                               a_thread_mtx_,
                               make_tuple(I0, I0, I0),
                               a_thread_buf);

+            b_thread_copy_.Run(b_block_mtx,
+                               make_tuple(k0_begin, I0, I0),
+                               b_block_buf,
+                               b_thread_mtx_,
+                               make_tuple(I0, I0, I0),
+                               b_thread_buf);

            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I0, I0),
                                b_thread_buf,
-                                make_tuple(e_begin, I0, I0, I0, I0),
+                                make_tuple(I0, I0, I0),
                                c_thread_buf,
-                                make_tuple(k_begin, I0, I0, I0));
-            });
+                                make_tuple(I0, I0, I0, I0));
        });
    }

-    template <typename ABlockSliceMoveStepIdx>
-    __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
-    {
-        a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx);
-    }

    private:
    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
                                                         FloatA,
-                                                        ABlockDesc_E1_K1_E2,
+                                                        ABlockDesc_K0_M_K1,
                                                         decltype(a_thread_mtx_),
-                                                        Sequence<EPerThreadLoop, KPerThreadLoop, E2>,
+                                                        Sequence<K0PerLoop, MPerThread, K1>,
                                                         Sequence<0, 1, 2>,
                                                         2,
-                                                        E2,
-                                                        E2>;
+                                                        K1,
+                                                        K1>;

+    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
+                                                         FloatB,
+                                                         BBlockDesc_K0_N_K1,
+                                                         decltype(b_thread_mtx_),
+                                                         Sequence<K0PerLoop, NPerThread, K1>,
+                                                         Sequence<0, 1, 2>,
+                                                         2,
+                                                         K1,
+                                                         K1>;

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
};

} // namespace ck
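Net effect of the rewrite: B is now staged through LDS exactly like A, and each Run() pass walks the K0 axis in K0PerLoop chunks, copying a [K0PerLoop, MPerThread, K1] fragment of A and a [K0PerLoop, NPerThread, K1] fragment of B into registers before invoking the threadwise GEMM. A runnable scalar model of what one thread accumulates (sketch; the names and the flat row-major layout are assumptions for illustration):

    #include <vector>

    template <int K0PerLoop, int MPerThread, int NPerThread, int K1>
    void run_one_thread(const std::vector<float>& a_lds, // [K0][M][K1], row-major
                        const std::vector<float>& b_lds, // [K0][N][K1], row-major
                        int M, int N, int K0, int m0, int n0,
                        float (&c)[MPerThread][NPerThread])
    {
        for(int k0 = 0; k0 < K0; k0 += K0PerLoop)     // one trip = one k0_begin
            for(int kk = 0; kk < K0PerLoop; ++kk)     // fragment depth
                for(int m = 0; m < MPerThread; ++m)
                    for(int n = 0; n < NPerThread; ++n)
                        for(int k1 = 0; k1 < K1; ++k1)
                            c[m][n] += a_lds[((k0 + kk) * M + m0 + m) * K1 + k1] *
                                       b_lds[((k0 + kk) * N + n0 + n) * K1 + k1];
    }

Here m0 and n0 are the per-thread origins computed by CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1, i.e. bm0 * MPerThread and bn0 * NPerThread.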
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp
...
...
@@ -371,120 +371,10 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
        float ave_time = 0;
#if 0
        if(has_main_k_block_loop && has_double_tail_k_block_loop)
        {
            const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
true,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
else
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
CDataType,
remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DefaultBlock2CTileMap>,
false,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.block_2_ctile_map_);
}
#else
        if(has_main_k_block_loop && has_double_tail_k_block_loop)
        {
            const auto kernel =
                kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, true, true>;
            ave_time = launch_and_time_kernel(stream_config,
                                              kernel,
...
...
@@ -504,11 +394,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
        else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
        {
            const auto kernel =
                kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, true, false>;
            ave_time = launch_and_time_kernel(stream_config,
                                              kernel,
...
...
@@ -528,11 +414,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
        else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
        {
            const auto kernel =
                kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, false, true>;
            ave_time = launch_and_time_kernel(stream_config,
                                              kernel,
...
...
@@ -552,11 +434,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
        else
        {
            const auto kernel =
                kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, false, false>;
            ave_time = launch_and_time_kernel(stream_config,
                                              kernel,
...
...
@@ -573,7 +451,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
                                              arg.StrideB_,
                                              arg.StrideC_);
        }
#endif

        return ave_time;
    }
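The branch ladder above exists only to lift two runtime booleans into compile-time template arguments so the kernel body can use if constexpr on them; the #if 0 block is the older, more verbose form that also threaded the grid descriptors through the template. The idiom in isolation (sketch; launch_variant is a hypothetical stand-in for instantiating and timing kernel_gemm_dl_v1r3):

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    float launch_variant()
    {
        // in the real code: instantiate kernel_gemm_dl_v1r3<..., HasMainKBlockLoop,
        // HasDoubleTailKBlockLoop> and return its measured time
        return 0.0f;
    }

    inline float dispatch(bool has_main, bool has_double_tail)
    {
        if(has_main && has_double_tail)  return launch_variant<true, true>();
        if(has_main && !has_double_tail) return launch_variant<true, false>();
        if(!has_main && has_double_tail) return launch_variant<false, true>();
        return launch_variant<false, false>();
    }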
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp
...
...
@@ -10,6 +10,7 @@
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"
...
...
@@ -198,7 +199,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
    static constexpr auto K1Number = Number<K1>{};

    __host__ __device__ static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
    {
        assert(K % K1 == 0);
...
...
@@ -237,7 +239,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
        }
    }

    __host__ __device__ static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
    {
        assert(K % K1 == 0);
...
...
@@ -333,7 +336,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
               (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0);
    }

    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1)
    {
...
...
@@ -420,6 +422,322 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
        decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));

    using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));

#if 1
    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        FloatAB* __restrict__ p_shared_block,
        const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
        const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
        const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
        const Block2CTileMap& block_2_ctile_map,
        integral_constant<bool, HasMainKBlockLoop>,
        integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize());

        // divide block work by [M, N]
        const auto c_m0_n0_block_cluster_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        // HACK: this forces the index data into SGPRs
        const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

        if(!block_2_ctile_map.ValidCTileIndex(
               make_tuple(im0, in0),
               make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0),
                          c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3))))
        {
            return;
        }
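__builtin_amdgcn_readfirstlane broadcasts lane 0's value and marks the result wavefront-uniform, so the compiler can keep the tile coordinates in scalar registers (SGPRs) rather than replicating them across VGPRs; since every lane computed the same block-cluster index, this is purely a register-allocation hint. The idiom on its own (sketch):

    // idx holds values that are identical across the wavefront
    __device__ void force_uniform(const int* idx, int& im0, int& in0)
    {
        im0 = __builtin_amdgcn_readfirstlane(idx[0]); // uniform -> SGPR
        in0 = __builtin_amdgcn_readfirstlane(idx[1]);
    }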
        // TODO: change this. I think it needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, I1, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, I1, Number<NPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // A matrix in LDS memory, for blockwise GEMM
        constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, for blockwise GEMM
        constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned(
            make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);

        static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() ==
                              a_k0_m_k1_block_desc.GetElementSpaceSize() &&
                          b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() ==
                              b_k0_n_k1_block_desc.GetElementSpaceSize() &&
                          "wrong!");
        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum::Set,
            Sequence<K0PerBlock, 1, MPerBlock, K1.value>,
            ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
            decltype(a_block_desc_k0_m0_m1_k1),
            ABlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3>,
            ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
            ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
            ABlockTransferSrcVectorTensorContiguousDimOrder,  // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3>,                             // DstVectorTensorContiguousDimOrder
            false,
            true>(a_grid_desc_k0_m0_m1_k1,
                  make_multi_index(0, im0, 0, 0),
                  a_block_desc_k0_m0_m1_k1,
                  make_multi_index(0, 0, 0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
            BlockSize,
            InMemoryDataOperationEnum::Set,
            Sequence<K0PerBlock, 1, NPerBlock, K1.value>,
            BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
            decltype(b_block_desc_k0_n0_n1_k1),
            BBlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3>,
            BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
            BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
            BBlockTransferSrcVectorTensorContiguousDimOrder,  // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3>,                             // DstVectorTensorContiguousDimOrder
            false,
            true>(b_grid_desc_k0_n0_n1_k1,
                  make_multi_index(0, in0, 0, 0),
                  b_block_desc_k0_n0_n1_k1,
                  make_multi_index(0, 0, 0, 0));
        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[K0PerBlock, MPerBlock] is in LDS
        //     b_mtx[K0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        constexpr index_t MPerThread = M1PerThreadM111;
        constexpr index_t NPerThread = N1PerThreadN111;

        const auto blockwise_gemm =
            BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
                                                 FloatAB,
                                                 FloatAB,
                                                 FloatAcc,
                                                 decltype(a_k0_m_k1_block_desc),
                                                 decltype(b_k0_n_k1_block_desc),
                                                 MPerThread,
                                                 NPerThread,
                                                 KPerThread>{};

        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

        constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed(
            sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAcc>(
            c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize());

        // Initialize C
        c_thread_buf.Clear();

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0);

        auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_block_desc_k0_m0_m1_k1.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block_double + b_block_aligned_space_size,
            b_block_desc_k0_n0_n1_k1.GetElementSpaceSize());

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);
        }
        if constexpr(HasMainKBlockLoop)
        {
            const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0);

            index_t k_block_data_begin = 0;

            // LDS double buffer: main body
            // use Do-While loop instead of For loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
                                                    b_block_slice_copy_step);

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
                                                    a_block_slice_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1,
                                                    b_block_slice_copy_step);

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
                b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

                block_sync_lds();

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf);

                k_block_data_begin += 2 * K0PerBlock;
            } while(k_block_data_begin < K0 - 2 * K0PerBlock);
        }
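The do-while above is a standard LDS double-buffering (ping-pong) schedule: each trip handles two K0PerBlock slices, reading the next slice from global memory while the blockwise GEMM consumes the slice already resident in the other LDS buffer. A host-side skeleton of the buffer rotation (a sketch only; on the GPU the prefetch genuinely overlaps the compute, here it is sequential):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    void double_buffered_sum(const std::vector<int>& src, std::size_t chunk, long& acc)
    {
        std::vector<int> buf[2];
        std::size_t pos = std::min(chunk, src.size());
        buf[0].assign(src.begin(), src.begin() + pos);
        int cur = 0;
        while(!buf[cur].empty())
        {
            // "prefetch" the next chunk into the other buffer...
            std::size_t n = std::min(chunk, src.size() - pos);
            buf[1 - cur].assign(src.begin() + pos, src.begin() + pos + n);
            pos += n;
            // ...while "computing" on the current one
            for(int v : buf[cur])
                acc += v;
            cur = 1 - cur;
        }
    }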
        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step);

            block_sync_lds();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
            b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf);

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf);

            block_sync_lds();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // if there is 1 iteration left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
        }
        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 =
                make_naive_tensor_descriptor_packed(make_tuple(
                    I1,
                    Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                    Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
                    I1,
                    Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
                    Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));

            const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
                blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                    get_thread_local_1d_id());

            ThreadwiseTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_thread_desc_m0_m10_m11_n0_n10_n11),
                decltype(c_grid_desc_m0_m10_m11_n0_n10_n11),
                ck::tensor_operation::element_wise::PassThrough,
                Sequence<1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I0],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I1],
                         1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I2],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_grid_desc_m0_m10_m11_n0_n10_n11,
                      make_multi_index(
                          im0,
                          c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                          c_m10_m11_n10_n11_thread_origin_idx_on_block[I1] * MPerThread,
                          in0,
                          c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                          c_m10_m11_n10_n11_thread_origin_idx_on_block[I3] * NPerThread),
                      ck::tensor_operation::element_wise::PassThrough{}}
                .Run(c_thread_desc_m0_m10_m11_n0_n10_n11,
                     make_tuple(I0, I0, I0, I0, I0, I0),
                     c_thread_buf,
                     c_grid_desc_m0_m10_m11_n0_n10_n11,
                     c_grid_buf);
        }
    }
#else
    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
...
...
@@ -710,6 +1028,15 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
                blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
                    get_thread_local_1d_id());

            if(get_block_1d_id() == 0)
            {
                printf("%d %d %d %d\n",
                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]);
            }

            ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
                                               FloatC,
...
...
@@ -1284,6 +1611,7 @@ struct GridwiseGemmDl_bkm_bkn_mn_v1r3
                     c_grid_buf);
        }
    }
#endif
};

} // namespace ck
include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp
...
...
@@ -4,26 +4,21 @@
#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
#define CK_THREADWISE_GEMM_DLOPS_V3_HPP

-#include "common_header.hpp"
-#include "math.hpp"
+#include "ck/utility/common_header.hpp"

namespace ck {

// C[M, N] += transpose(A[K, M]) * B[K, N]
//   Element of matrix can be vectorized data
// Assume:
//   1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo are known at
//      compile-time
//   2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
          typename FloatB,
          typename FloatC,
-          typename AThreadDesc_E1_K_E2,
-          typename BThreadDesc_E1_N_Ho_Wo_E2,
-          typename CThreadDesc_K_N_Ho_Wo,
-          typename enable_if<AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
-                                 BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
-                                 CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
+          typename AThreadDesc_K0_M_K1,
+          typename BThreadDesc_K0_N_K1,
+          typename CThreadDesc_M_N,
+          typename enable_if<AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
+                                 BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
+                                 CThreadDesc_M_N::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3
{
...
...
@@ -42,9 +37,9 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
                              COriginIdx)
    {
-        static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() &&
-                          BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() &&
-                          CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(),
+        static_assert(AThreadDesc_K0_M_K1::IsKnownAtCompileTime() &&
+                          BThreadDesc_K0_N_K1::IsKnownAtCompileTime() &&
+                          CThreadDesc_M_N::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
...
...
@@ -61,96 +56,29 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

-        constexpr auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0);
-        constexpr auto K  = AThreadDesc_E1_K_E2{}.GetLength(I1);
-        constexpr auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2);
+        constexpr auto K0 = AThreadDesc_K0_M_K1{}.GetLength(I0);
+        constexpr auto M  = AThreadDesc_K0_M_K1{}.GetLength(I1);
+        constexpr auto K1 = AThreadDesc_K0_M_K1{}.GetLength(I2);

-        constexpr auto Ho = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2);
-        constexpr auto Wo = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3);
+        constexpr auto N = BThreadDesc_K0_N_K1{}.GetLength(I1);

        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
-        if constexpr((Ho % 2 == 0) && (Wo % 2 == 0))
-        {
-            constexpr auto SubHW = 2;
-
-            static_for<0, K, 1>{}([&](auto k) {
-                static_for<0, Ho, SubHW>{}([&](auto h) {
-                    static_for<0, Wo, SubHW>{}([&](auto w) {
-                        static_for<0, E1, 1>{}([&](auto e1) {
-                            static_for<0, E2, 1>{}([&](auto e2) {
-                                constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
-                                    a_origin_idx + make_tuple(e1, k, e2));
-
-                                constexpr index_t b0_offset =
-                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
-                                        b_origin_idx + make_tuple(e1, 0, h, w, e2));
-
-                                constexpr index_t b1_offset =
-                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
-                                        b_origin_idx + make_tuple(e1, 0, h, w + 1, e2));
-
-                                constexpr index_t b2_offset =
-                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
-                                        b_origin_idx + make_tuple(e1, 0, h + 1, w, e2));
-
-                                constexpr index_t b3_offset =
-                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
-                                        b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2));
-
-                                constexpr index_t c0_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
-                                    c_origin_idx + make_tuple(k, 0, h, w));
-
-                                constexpr index_t c1_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
-                                    c_origin_idx + make_tuple(k, 0, h, w + 1));
-
-                                constexpr index_t c2_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
-                                    c_origin_idx + make_tuple(k, 0, h + 1, w));
-
-                                constexpr index_t c3_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
-                                    c_origin_idx + make_tuple(k, 0, h + 1, w + 1));
-
-                                amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
-                                                               b_buf[Number<b0_offset>{}],
-                                                               b_buf[Number<b1_offset>{}],
-                                                               b_buf[Number<b2_offset>{}],
-                                                               b_buf[Number<b3_offset>{}],
-                                                               c_buf(Number<c0_offset>{}),
-                                                               c_buf(Number<c1_offset>{}),
-                                                               c_buf(Number<c2_offset>{}),
-                                                               c_buf(Number<c3_offset>{}));
-                            });
-                        });
-                    });
-                });
-            });
-        }
-        else
-        {
+        static_for<0, M, 1>{}([&](auto m) {
+            static_for<0, N, 1>{}([&](auto n) {
+                static_for<0, K0, 1>{}([&](auto k0) {
+                    static_for<0, K1, 1>{}([&](auto k1) {
+                        constexpr index_t a_offset = AThreadDesc_K0_M_K1{}.CalculateOffset(
+                            a_origin_idx + make_tuple(k0, m, k1));

-            static_for<0, K, 1>{}([&](auto k) {
-                static_for<0, Ho, 1>{}([&](auto h) {
-                    static_for<0, Wo, 1>{}([&](auto w) {
-                        static_for<0, E1, 1>{}([&](auto e1) {
-                            static_for<0, E2, 1>{}([&](auto e2) {
-                                constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset(
-                                    a_origin_idx + make_tuple(e1, k, e2));

+                        constexpr index_t b_offset = BThreadDesc_K0_N_K1{}.CalculateOffset(
+                            b_origin_idx + make_tuple(k0, n, k1));

-                                constexpr index_t b_offset =
-                                    BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset(
-                                        b_origin_idx + make_tuple(e1, 0, h, w, e2));

-                                constexpr index_t c_offset = CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(
-                                    c_origin_idx + make_tuple(k, 0, h, w));

+                        constexpr index_t c_offset = CThreadDesc_M_N{}.CalculateOffset(
+                            c_origin_idx + make_tuple(0, m, 0, n));

                        inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
                                                              b_buf[Number<b_offset>{}],
...
...
@@ -159,9 +87,7 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3
                    });
                });
            });
        });
-            });
-        }
    }
};

} // namespace ck
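The deleted branch fused four output updates per step through amd_assembly_outer_product_1x4; the surviving path issues one scalar inner_product per (m, n) element and leaves FMA scheduling to the compiler. For reference, the arithmetic the removed intrinsic expressed is a 1x4 outer-product accumulate (a sketch of the math, not of the intrinsic's assembly):

    // c_i += a * b_i: four neighboring outputs share one A element
    inline void outer_product_1x4(float a, float b0, float b1, float b2, float b3,
                                  float& c0, float& c1, float& c2, float& c3)
    {
        c0 += a * b0;
        c1 += a * b1;
        c2 += a * b2;
        c3 += a * b3;
    }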
...
...
include/ck/utility/inner_product.hpp
...
...
@@ -9,6 +9,12 @@ namespace ck {
template <typename TA, typename TB, typename TC>
__device__ void inner_product(const TA& a, const TB& b, TC& c);
+template <>
+__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
+{
+    c += a * b;
+}
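This added specialization is the scalar fp16 x fp16 -> fp32 fallback the rewritten threadwise GEMM now relies on: products are formed from half inputs but accumulated in float. Typical use (a sketch; half_t is CK's fp16 alias, and the helper below is hypothetical):

    // accumulate a dot product of fp16 inputs in fp32
    __device__ float dot_f16(const half_t* a, const half_t* b, int n)
    {
        float acc = 0.f;
        for(int i = 0; i < n; ++i)
            inner_product<half_t, half_t, float>(a[i], b[i], acc); // acc += a * b
        return acc;
    }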
template <>
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
{
...
...