gaoqiong / composable_kernel · commit fc148cef

Commit fc148cef, authored Apr 29, 2021 by Chao Liu
Parent: 0374f8de

    added back pipelined 2x2 to blockwise gemm

Showing 3 changed files, with 372 additions and 68 deletions:

    composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp     +335  -43
    composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp  +16  -16
    composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp     +21   -9
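The commit restores a software-pipelined variant of the blockwise GEMM in which each thread's C tile is split into a 2x2 grid of sub-tiles (C_sub_00 through C_sub_11). The point of the split is latency hiding: while the threadwise GEMMs for one pair of sub-tiles run, the thread copies for the next A/B sub-tiles are already in flight. The sketch below is a plain-C++ model of that schedule, not the library's code: load and sub_gemm are hypothetical stand-ins for a_thread_copy_/b_thread_copy_.Run and threadwise_gemm.Run, and each K step carries a single k element instead of KPerThreadLoop of them.

#include <cstring>

constexpr int K = 8, M1 = 4, N1 = 4; // example extents (hypothetical)

// stand-in for a_thread_copy_/b_thread_copy_.Run: copy one k-slice of
// sub-tile i out of a [K][2][len]-packed block buffer into registers
static void load(const float* block, int k, int i, int len, float* reg)
{
    std::memcpy(reg, block + (k * 2 + i) * len, len * sizeof(float));
}

// stand-in for threadwise_gemm.Run: C_sub += transpose(A_sub) * B_sub;
// a rank-1 update here because each slice carries a single k element
static void sub_gemm(const float* a, const float* b, float c[M1][N1])
{
    for(int m = 0; m < M1; ++m)
        for(int n = 0; n < N1; ++n)
            c[m][n] += a[m] * b[n];
}

// the pipelined 2x2 schedule that BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2::Run
// implements (see the diff below)
void pipelined_2x2(const float* A, const float* B, float C[2][2][M1][N1])
{
    float a[2][M1], b[2][N1];

    // prologue: load all four sub-tiles of slice 0, start on C00/C01
    load(A, 0, 0, M1, a[0]);
    load(B, 0, 0, N1, b[0]);
    load(B, 0, 1, N1, b[1]);
    load(A, 0, 1, M1, a[1]);
    sub_gemm(a[0], b[0], C[0][0]); // C_sub_00
    sub_gemm(a[0], b[1], C[0][1]); // C_sub_01

    // steady state: loads of slice k overlap C10/C11 math of slice k-1
    for(int k = 1; k < K; ++k)
    {
        load(A, k, 0, M1, a[0]);       // read A_sub_0 of slice k
        sub_gemm(a[1], b[0], C[1][0]); // C_sub_10 of slice k-1
        load(B, k, 0, N1, b[0]);       // read B_sub_0 of slice k
        sub_gemm(a[1], b[1], C[1][1]); // C_sub_11 of slice k-1
        load(B, k, 1, N1, b[1]);       // read B_sub_1 of slice k
        load(A, k, 1, M1, a[1]);       // read A_sub_1 of slice k
        sub_gemm(a[0], b[0], C[0][0]); // C_sub_00 of slice k
        sub_gemm(a[0], b[1], C[0][1]); // C_sub_01 of slice k
    }

    // epilogue: finish the last slice
    sub_gemm(a[1], b[0], C[1][0]); // C_sub_10
    sub_gemm(a[1], b[1], C[1][1]); // C_sub_11
}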
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
@@ -11,21 +11,21 @@ namespace ck {
 // A and B are visible to the whole block, C is distributed among each thread
 // Assume:
 //   1. A:
-//     1. BlockMatrixA is known at compile-time
+//     1. ABlockDesc is known at compile-time
 //     2. ABlockBuffer is DynamicBuffer
 //   2. B:
-//     1. BlockMatrixB is known at compile-time
+//     1. BBlockDesc is known at compile-time
 //     2. BBlockBuffer is DynamicBuffer
 //   3. C:
-//     1. ThreadMatrixC is known at compile-time
+//     1. CThreadDesc is known at compile-time
 //     2. CThreadBuffer is StaticBuffer
 template <index_t BlockSize,
           typename FloatA,
           typename FloatB,
           typename FloatC,
-          typename BlockMatrixA,
-          typename BlockMatrixB,
-          typename ThreadMatrixC,
+          typename ABlockDesc,
+          typename BBlockDesc,
+          typename CThreadDesc,
           index_t MPerThreadSubC,
           index_t NPerThreadSubC,
           index_t KPerThreadLoop,
@@ -35,9 +35,9 @@ template <index_t BlockSize,
           index_t NLevel1ThreadCluster,
           index_t ThreadGemmADataPerRead_M,
           index_t ThreadGemmBDataPerRead_N,
-          typename std::enable_if<BlockMatrixA::IsKnownAtCompileTime() &&
-                                      BlockMatrixB::IsKnownAtCompileTime() &&
-                                      ThreadMatrixC::IsKnownAtCompileTime(),
+          typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
+                                      BBlockDesc::IsKnownAtCompileTime() &&
+                                      CThreadDesc::IsKnownAtCompileTime(),
                                   bool>::type = false>
 struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 {
@@ -49,13 +49,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
     public:
     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1()
-        : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
+        : c_thread_begin_mtx_idx_{GetBeginOfCThreadDesc(get_thread_local_1d_id())},
           a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.row)},
           b_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.col)}
     {
-        static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
-                          BlockMatrixB::IsKnownAtCompileTime() &&
-                          ThreadMatrixC::IsKnownAtCompileTime(),
+        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
+                          CThreadDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

         constexpr auto I0 = Number<0>{};
@@ -66,27 +65,27 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

-        static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
+        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
                       "wrong! K dimension not consistent");

-        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
-        constexpr index_t N = BlockMatrixB{}.GetLength(I1);
+        constexpr index_t M = ABlockDesc{}.GetLength(I1); // A is transposed
+        constexpr index_t N = BBlockDesc{}.GetLength(I1);

         static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
                           N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
                       "wrong! Cannot evenly divide work among");

-        static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
-                          ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
-                      "wrong! ThreadMatrixC lengths is wrong");
+        static_assert(CThreadDesc{}.GetLength(I0) == GetCThreadDescLengths()[I0] &&
+                          CThreadDesc{}.GetLength(I1) == GetCThreadDescLengths()[I1],
+                      "wrong! CThreadDesc lengths is wrong");
     }

-    __device__ static constexpr auto GetThreadMatrixCLengths()
+    __device__ static constexpr auto GetCThreadDescLengths()
     {
         constexpr auto I1 = Number<1>{};

-        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
-        constexpr index_t N = BlockMatrixB{}.GetLength(I1);
+        constexpr index_t M = ABlockDesc{}.GetLength(I1); // A is transposed
+        constexpr index_t N = BBlockDesc{}.GetLength(I1);

         constexpr index_t MRepeat =
             M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
@@ -96,7 +95,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
     }

-    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
+    __device__ static MatrixIndex GetBeginOfCThreadDesc(index_t thread_id)
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -130,9 +129,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

-        constexpr auto a_block_mtx       = BlockMatrixA{};
-        constexpr auto b_block_mtx       = BlockMatrixB{};
-        constexpr auto c_thread_mtx_desc = ThreadMatrixC{};
+        constexpr auto a_block_mtx       = ABlockDesc{};
+        constexpr auto b_block_mtx       = BBlockDesc{};
+        constexpr auto c_thread_mtx_desc = CThreadDesc{};

         constexpr auto K = a_block_mtx.GetLength(I0);
@@ -174,7 +173,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                                  decltype(c_thread_sub_mtx)>{};

         // read A_sub_0
-        a_thread_copy_.Run(BlockMatrixA{},
+        a_thread_copy_.Run(ABlockDesc{},
                            make_tuple(I0, I0),
                            a_block_buf,
                            a_thread_mtx_desc_,
@@ -182,7 +181,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            a_thread_buf);

         // read B_sub_0
-        b_thread_copy_.Run(BlockMatrixB{},
+        b_thread_copy_.Run(BBlockDesc{},
                            make_tuple(I0, I0),
                            b_block_buf,
                            b_thread_mtx_desc_,
@@ -190,7 +189,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            b_thread_buf);

         // read B_sub_1
-        b_thread_copy_.Run(BlockMatrixB{},
+        b_thread_copy_.Run(BBlockDesc{},
                            make_tuple(I0, Number<NPerLevel1Cluster>{}),
                            b_block_buf,
                            b_thread_mtx_desc_,
@@ -198,7 +197,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            b_thread_buf);

         // read A_sub_1
-        a_thread_copy_.Run(BlockMatrixA{},
+        a_thread_copy_.Run(ABlockDesc{},
                            make_tuple(I0, Number<MPerLevel1Cluster>{}),
                            a_block_buf,
                            a_thread_mtx_desc_,
@@ -224,7 +223,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         // loop over rest of k
         static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
             // read A_sub_0
-            a_thread_copy_.Run(BlockMatrixA{},
+            a_thread_copy_.Run(ABlockDesc{},
                                make_tuple(k, I0),
                                a_block_buf,
                                a_thread_mtx_desc_,
@@ -240,7 +239,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                make_tuple(Number<MPerThreadSubC>{}, I0));

             // read B_sub_0
-            b_thread_copy_.Run(BlockMatrixB{},
+            b_thread_copy_.Run(BBlockDesc{},
                                make_tuple(k, I0),
                                b_block_buf,
                                b_thread_mtx_desc_,
@@ -256,7 +255,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));

             // read B_sub_1
-            b_thread_copy_.Run(BlockMatrixB{},
+            b_thread_copy_.Run(BBlockDesc{},
                                make_tuple(k, Number<NPerLevel1Cluster>{}),
                                b_block_buf,
                                b_thread_mtx_desc_,
@@ -264,7 +263,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                b_thread_buf);

             // read A_sub_1
-            a_thread_copy_.Run(BlockMatrixA{},
+            a_thread_copy_.Run(ABlockDesc{},
                                make_tuple(k, Number<MPerLevel1Cluster>{}),
                                a_block_buf,
                                a_thread_mtx_desc_,
@@ -314,8 +313,8 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

-        constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
-        constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);
+        constexpr index_t MPerThread = CThreadDesc{}.GetLength(I0);
+        constexpr index_t NPerThread = CThreadDesc{}.GetLength(I1);

         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
@@ -342,15 +341,15 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                        Sequence<0, 1, 2, 3>{});

     static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<0>{})));
+        make_tuple(Number<KPerThreadLoop>{}, CThreadDesc{}.GetLength(Number<0>{})));

     static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<1>{})));
+        make_tuple(Number<KPerThreadLoop>{}, CThreadDesc{}.GetLength(Number<1>{})));

     using AThreadCopy =
         ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                 FloatA,
-                                                BlockMatrixA,
+                                                ABlockDesc,
                                                 decltype(a_thread_mtx_desc_),
                                                 Sequence<KPerThreadLoop, MPerThreadSubC>,
                                                 Sequence<0, 1>,
@@ -363,7 +362,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
     using BThreadCopy =
         ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                 FloatB,
-                                                BlockMatrixB,
+                                                BBlockDesc,
                                                 decltype(b_thread_mtx_desc_),
                                                 Sequence<KPerThreadLoop, NPerThreadSubC>,
                                                 Sequence<0, 1>,
@@ -411,7 +410,7 @@ template <index_t BlockSize,
                                       BBlockDesc::IsKnownAtCompileTime() &&
                                       CThreadDesc::IsKnownAtCompileTime(),
                                   bool>::type = false>
-struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
+struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1
 {
     using AIndex = MultiIndex<3>;
     using BIndex = MultiIndex<3>;
@@ -423,7 +422,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
     static constexpr auto I3 = Number<3>{};

     public:
-    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1()
+    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r1()
         : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
           a_thread_copy_{
               make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
@@ -479,7 +478,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
                                                 FloatC,
                                                 decltype(a_thread_desc_),
                                                 decltype(b_thread_desc_),
-                                                CThreadDesc>{};
+                                                CThreadDesc,
+                                                Sequence<KPerThreadLoop>,
+                                                Sequence<M0_, M1PerThread>,
+                                                Sequence<N0_, N1PerThread>>{};

         constexpr index_t K = ABlockDesc{}.GetLength(I0);
@@ -553,5 +555,295 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
     BThreadCopy b_thread_copy_;
 };

+// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
+// A and B are visible to the whole block, C is distributed among each thread
+// Assume:
+//   1. A:
+//     1. ABlockDesc is known at compile-time
+//     2. ABlockBuffer is DynamicBuffer
+//   2. B:
+//     1. BBlockDesc is known at compile-time
+//     2. BBlockBuffer is DynamicBuffer
+//   3. C:
+//     1. CThreadDesc is known at compile-time
+//     2. CThreadBuffer is StaticBuffer
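As a reading aid, the operation named in the comment above written out as plain loops: a minimal reference sketch with made-up example extents, not library code. The transpose refers to K being the leading (row) dimension of A while M0/M1 index the rows of the product.

constexpr int K = 4, M0 = 2, M1 = 4, N0 = 2, N1 = 4; // example extents

void reference_gemm(const float A[K][M0][M1],
                    const float B[K][N0][N1],
                    float C[M0][M1][N0][N1])
{
    for(int k = 0; k < K; ++k)
        for(int m0 = 0; m0 < M0; ++m0)
            for(int m1 = 0; m1 < M1; ++m1)
                for(int n0 = 0; n0 < N0; ++n0)
                    for(int n1 = 0; n1 < N1; ++n1)
                        C[m0][m1][n0][n1] += A[k][m0][m1] * B[k][n0][n1];
}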
+template <index_t BlockSize,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename ABlockDesc,
+          typename BBlockDesc,
+          typename CThreadDesc,
+          index_t M1PerThread,
+          index_t N1PerThread,
+          index_t KPerThreadLoop,
+          index_t MLevel0ThreadCluster,
+          index_t NLevel0ThreadCluster,
+          index_t MLevel1ThreadCluster,
+          index_t NLevel1ThreadCluster,
+          index_t AThreadCopyScalarPerVector_M1,
+          index_t BThreadCopyScalarPerVector_N1,
+          typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
+                                      BBlockDesc::IsKnownAtCompileTime() &&
+                                      CThreadDesc::IsKnownAtCompileTime(),
+                                  bool>::type = false>
+struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2
+{
+    using AIndex = MultiIndex<3>;
+    using BIndex = MultiIndex<3>;
+    using CIndex = MultiIndex<4>;
+
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+    static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+
+    public:
+    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2()
+        : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
+          a_thread_copy_{
+              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
+          b_thread_copy_{
+              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
+    {
+        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
+                          CThreadDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        static_assert(BlockSize == c_thread_cluster_desc_.GetElementSize(),
+                      "wrong! wrong blocksize");
+
+        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
+                      "wrong! K dimension not consistent");
+
+        // TODO: remove this restriction
+        static_assert(ABlockDesc{}.GetLength(I1) == 2 && BBlockDesc{}.GetLength(I1) == 2 &&
+                          CThreadDesc{}.GetLength(I0) == 2 && CThreadDesc{}.GetLength(I2) == 2,
+                      "wrong");
+    }
+
+    __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
+    {
+        const auto thread_cluster_idx =
+            c_thread_cluster_desc_.CalculateBottomIndex(make_multi_index(thread_id));
+
+        constexpr index_t MPerLevel0Cluster = M1PerThread * MLevel0ThreadCluster;
+        constexpr index_t NPerLevel0Cluster = N1PerThread * NLevel0ThreadCluster;
+
+        return make_multi_index(0,
+                                thread_cluster_idx[I0] * MPerLevel0Cluster +
+                                    thread_cluster_idx[I2] * M1PerThread,
+                                0,
+                                thread_cluster_idx[I1] * NPerLevel0Cluster +
+                                    thread_cluster_idx[I3] * N1PerThread);
+    }
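A worked example of the origin computation, with hypothetical parameters: take M1PerThread = N1PerThread = 4 and all four thread-cluster extents equal to 2, so BlockSize = 16 and MPerLevel0Cluster = NPerLevel0Cluster = 8. The cluster descriptor decomposes thread_id 5 (binary 0101 over [MLevel1, NLevel1, MLevel0, NLevel0]) into (0, 1, 0, 1), giving a C origin of (0, 0*8 + 0*4, 0, 1*8 + 1*4) = (0, 0, 0, 12): that thread's 2x2 sub-tiles start at row 0 and column 12 of the block's C tile.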
+    __host__ __device__ static constexpr auto GetCThreadClusterDescriptor()
+    {
+        return make_cluster_descriptor_v2(Sequence<MLevel1ThreadCluster,
+                                                   NLevel1ThreadCluster,
+                                                   MLevel0ThreadCluster,
+                                                   NLevel0ThreadCluster>{},
+                                          Sequence<0, 1, 2, 3>{});
+    }
+
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BBlockBuffer& b_block_buf,
+                        CThreadBuffer& c_thread_buf) const
+    {
+        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_desc_.GetElementSpaceSize());
+
+        constexpr auto threadwise_gemm =
+            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
+                                                FloatB,
+                                                FloatC,
+                                                decltype(a_thread_desc_),
+                                                decltype(b_thread_desc_),
+                                                CThreadDesc,
+                                                Sequence<KPerThreadLoop>,
+                                                Sequence<1, M1PerThread>,
+                                                Sequence<1, N1PerThread>>{};
+
+        constexpr index_t K = ABlockDesc{}.GetLength(I0);
+
+        // read A_sub_0
+        a_thread_copy_.Run(ABlockDesc{},
+                           make_tuple(I0, I0, I0),
+                           a_block_buf,
+                           a_thread_desc_,
+                           make_tuple(I0, I0, I0),
+                           a_thread_buf);
+
+        // read B_sub_0
+        b_thread_copy_.Run(BBlockDesc{},
+                           make_tuple(I0, I0, I0),
+                           b_block_buf,
+                           b_thread_desc_,
+                           make_tuple(I0, I0, I0),
+                           b_thread_buf);
+
+        // read B_sub_1
+        b_thread_copy_.Run(BBlockDesc{},
+                           make_tuple(I0, I1, I0),
+                           b_block_buf,
+                           b_thread_desc_,
+                           make_tuple(I0, I1, I0),
+                           b_thread_buf);
+
+        // read A_sub_1
+        a_thread_copy_.Run(ABlockDesc{},
+                           make_tuple(I0, I1, I0),
+                           a_block_buf,
+                           a_thread_desc_,
+                           make_tuple(I0, I1, I0),
+                           a_thread_buf);
+
+        // C_sub_00 += transpose(A_sub_0) * B_sub_0
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(I0, I0, I0),
+                            b_thread_buf,
+                            make_tuple(I0, I0, I0),
+                            c_thread_buf,
+                            make_tuple(I0, I0, I0, I0));
+
+        // C_sub_01 += transpose(A_sub_0) * B_sub_1
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(I0, I0, I0),
+                            b_thread_buf,
+                            make_tuple(I0, I1, I0),
+                            c_thread_buf,
+                            make_tuple(I0, I0, I1, I0));
+
+        // loop over rest of k
+        static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
+            // read A_sub_0
+            a_thread_copy_.Run(ABlockDesc{},
+                               make_tuple(k, I0, I0),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I0, I0),
+                               a_thread_buf);
+
+            // C_sub_10 += transpose(A_sub_1) * B_sub_0
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(I0, I1, I0),
+                                b_thread_buf,
+                                make_tuple(I0, I0, I0),
+                                c_thread_buf,
+                                make_tuple(I1, I0, I0, I0));
+
+            // read B_sub_0
+            b_thread_copy_.Run(BBlockDesc{},
+                               make_tuple(k, I0, I0),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(I0, I0, I0),
+                               b_thread_buf);
+
+            // C_sub_11 += transpose(A_sub_1) * B_sub_1
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(I0, I1, I0),
+                                b_thread_buf,
+                                make_tuple(I0, I1, I0),
+                                c_thread_buf,
+                                make_tuple(I1, I0, I1, I0));
+
+            // read B_sub_1
+            b_thread_copy_.Run(BBlockDesc{},
+                               make_tuple(k, I1, I0),
+                               b_block_buf,
+                               b_thread_desc_,
+                               make_tuple(I0, I1, I0),
+                               b_thread_buf);
+
+            // read A_sub_1
+            a_thread_copy_.Run(ABlockDesc{},
+                               make_tuple(k, I1, I0),
+                               a_block_buf,
+                               a_thread_desc_,
+                               make_tuple(I0, I1, I0),
+                               a_thread_buf);
+
+            // C_sub_00 += transpose(A_sub_0) * B_sub_0
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(I0, I0, I0),
+                                b_thread_buf,
+                                make_tuple(I0, I0, I0),
+                                c_thread_buf,
+                                make_tuple(I0, I0, I0, I0));
+
+            // C_sub_01 += transpose(A_sub_0) * B_sub_1
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(I0, I0, I0),
+                                b_thread_buf,
+                                make_tuple(I0, I1, I0),
+                                c_thread_buf,
+                                make_tuple(I0, I0, I1, I0));
+        });
+
+        // C_sub_10 += transpose(A_sub_1) * B_sub_0
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(I0, I1, I0),
+                            b_thread_buf,
+                            make_tuple(I0, I0, I0),
+                            c_thread_buf,
+                            make_tuple(I1, I0, I0, I0));
+
+        // C_sub_11 += transpose(A_sub_1) * B_sub_1
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(I0, I1, I0),
+                            b_thread_buf,
+                            make_tuple(I0, I1, I0),
+                            c_thread_buf,
+                            make_tuple(I1, I0, I1, I0));
+    }
+
+    private:
+    static constexpr auto c_thread_cluster_desc_ = GetCThreadClusterDescriptor();
+
+    static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
+    static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);
+
+    // A[K, M0, M1]
+    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadLoop>{}, Number<M0_>{}, Number<M1PerThread>{}));
+
+    // B[K, N0, N1]
+    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadLoop>{}, Number<N0_>{}, Number<N1PerThread>{}));
+
+    using AThreadCopy =
+        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
+                                                FloatA,
+                                                ABlockDesc,
+                                                decltype(a_thread_desc_),
+                                                Sequence<KPerThreadLoop, 1, M1PerThread>,
+                                                Sequence<0, 1, 2>,
+                                                2,
+                                                AThreadCopyScalarPerVector_M1,
+                                                AddressSpace::Generic,
+                                                AddressSpace::Vgpr,
+                                                1>;
+
+    using BThreadCopy =
+        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
+                                                FloatB,
+                                                BBlockDesc,
+                                                decltype(b_thread_desc_),
+                                                Sequence<KPerThreadLoop, 1, N1PerThread>,
+                                                Sequence<0, 1, 2>,
+                                                2,
+                                                BThreadCopyScalarPerVector_N1,
+                                                AddressSpace::Generic,
+                                                AddressSpace::Vgpr,
+                                                1>;
+
+    CIndex c_thread_origin_data_idx_;
+
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
+};
 } // namespace ck
 #endif
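A note on the read/compute ordering inside the k-loop of the new v1r2 Run(): one register buffer per sub-tile suffices because every sub-tile is fully consumed before it is overwritten. A_sub_0 of slice k can be loaded first, since C_sub_00 and C_sub_01 of slice k-1 were already computed at the end of the previous iteration; B_sub_0 of slice k is loaded only after C_sub_10, the last consumer of the old B_sub_0; and B_sub_1/A_sub_1 are reloaded only after C_sub_11. Each load is thereby paired with threadwise GEMM work that can hide its latency.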
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
@@ -721,7 +721,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));

         const auto blockwise_gemm =
-            BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1<BlockSize,
+            BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2<BlockSize,
                                                   FloatAB,
                                                   FloatAB,
                                                   FloatAcc,
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
@@ -151,11 +151,27 @@ template <typename FloatA,
           typename ADesc,
           typename BDesc,
           typename CDesc,
+          typename KLengths,
+          typename MLengths,
+          typename NLengths,
           typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                                       CDesc::IsKnownAtCompileTime(),
                                   bool>::type = false>
 struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
 {
+    __device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
+    {
+        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
+                          CDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        // TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLengths, MLengths and
+        // NLengths
+
+        // TODO remove this restriction
+        static_assert(KLengths::Size() == 1 && MLengths::Size() == 2 && NLengths::Size() == 2,
+                      "wrong!");
+    }
+
     template <typename ABuffer,
               typename AOriginIdx,
               typename BBuffer,
@@ -169,10 +185,6 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
                            CBuffer& c_buf,
                            COriginIdx)
     {
-        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
-                          CDesc::IsKnownAtCompileTime(),
-                      "wrong! Desc should be known at compile-time");
-
         static_assert(
             is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
                 is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
@@ -192,11 +204,11 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
         constexpr auto I2 = Number<2>{};
         constexpr auto I3 = Number<3>{};

-        constexpr auto K  = ADesc{}.GetLength(I0);
-        constexpr auto M0 = CDesc{}.GetLength(I0);
-        constexpr auto M1 = CDesc{}.GetLength(I1);
-        constexpr auto N0 = CDesc{}.GetLength(I2);
-        constexpr auto N1 = CDesc{}.GetLength(I3);
+        constexpr auto K  = KLengths{}[I0];
+        constexpr auto M0 = MLengths{}[I0];
+        constexpr auto M1 = MLengths{}[I1];
+        constexpr auto N0 = NLengths{}[I0];
+        constexpr auto N1 = NLengths{}[I1];

         constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
         constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
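The effect of the new KLengths/MLengths/NLengths parameters is that the slice extents no longer have to equal the descriptor extents: a caller can update one sub-tile of a larger C descriptor by passing short lengths plus an origin index, which is how BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1r2 drives it with Sequence<1, M1PerThread> and Sequence<1, N1PerThread>. Below is a hedged sketch of the implied loop nest; the actual Run body is not shown in this diff, and buffer shapes and origins are simplified to C arrays.

// KA..N1A are the full buffer extents, K..N1 the slice lengths, and am0/bn0/
// cm0/cn0 the origin components that select the sub-tile (all hypothetical).
template <int KA, int M0A, int M1A, int N0A, int N1A,
          int K, int M0, int M1, int N0, int N1>
void threadwise_gemm_sketch(const float (&a)[KA][M0A][M1A], int am0,
                            const float (&b)[KA][N0A][N1A], int bn0,
                            float (&c)[M0A][M1A][N0A][N1A], int cm0, int cn0)
{
    for(int k = 0; k < K; ++k)
        for(int m0 = 0; m0 < M0; ++m0)
            for(int m1 = 0; m1 < M1; ++m1)
                for(int n0 = 0; n0 < N0; ++n0)
                    for(int n1 = 0; n1 < N1; ++n1)
                        c[cm0 + m0][m1][cn0 + n0][n1] +=
                            a[k][am0 + m0][m1] * b[k][bn0 + n0][n1];
}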