Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
a25f992d
Commit
a25f992d
authored
May 29, 2021
by
Chao Liu
Browse files
overhauling fwd-v4r4
parent
849243b8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
60 additions
and
58 deletions
+60
-58
composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp
...e_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp
+60
-58
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp
View file @
a25f992d
...
...
@@ -11,13 +11,13 @@ namespace ck {
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
// 1. A:
// 1. ABlockDesc is known at compile-time
// 1. A
KM
BlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. ABlockDesc is known at compile-time
// 1. A
KM
BlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CThreadDesc is known at compile-time
// 1. C
M0M1N0N1
ThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
...
...
@@ -25,21 +25,21 @@ template <index_t BlockSize,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
ABlockDesc
,
typename
BBlockDesc
,
typename
CThreadDesc
,
index_t
M1PerThread
,
index_t
N1PerThread
,
typename
A
KM
BlockDesc
,
typename
B
KN
BlockDesc
,
typename
C
M0M1N0N1
ThreadDesc
,
index_t
M1PerThread
M11
,
index_t
N1PerThread
N11
,
index_t
KPerThread
,
index_t
M1N1ThreadClusterM10
,
index_t
M1N1ThreadClusterN10
,
index_t
M1N1ThreadClusterM11
,
index_t
M1N1ThreadClusterN11
,
index_t
AThreadCopyScalarPerVector_M1
,
index_t
BThreadCopyScalarPerVector_N1
,
typename
std
::
enable_if
<
ABlockDesc
::
IsKnownAtCompileTime
()
&&
BBlockDesc
::
IsKnownAtCompileTime
()
&&
CThreadDesc
::
IsKnownAtCompileTime
(),
index_t
M1N1ThreadClusterM10
0
,
index_t
M1N1ThreadClusterN10
0
,
index_t
M1N1ThreadClusterM1
0
1
,
index_t
M1N1ThreadClusterN1
0
1
,
index_t
AThreadCopyScalarPerVector_M1
1
,
index_t
BThreadCopyScalarPerVector_N1
1
,
typename
std
::
enable_if
<
A
KM
BlockDesc
::
IsKnownAtCompileTime
()
&&
B
KN
BlockDesc
::
IsKnownAtCompileTime
()
&&
C
M0M1N0N1
ThreadDesc
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
{
...
...
@@ -60,36 +60,38 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
b_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I2
],
c_thread_origin_data_idx_
[
I3
])}
{
static_assert
(
ABlockDesc
::
IsKnownAtCompileTime
()
&&
BBlockDesc
::
IsKnownAtCompileTime
()
&&
CThreadDesc
::
IsKnownAtCompileTime
(),
static_assert
(
AKMBlockDesc
::
IsKnownAtCompileTime
()
&&
BKNBlockDesc
::
IsKnownAtCompileTime
()
&&
CM0M1N0N1ThreadDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
BlockSize
==
M1N1ThreadClusterM11
*
M1N1ThreadClusterM10
*
M1N1ThreadClusterN11
*
M1N1ThreadClusterN10
,
static_assert
(
BlockSize
==
M1N1ThreadClusterM1
0
1
*
M1N1ThreadClusterM10
0
*
M1N1ThreadClusterN1
0
1
*
M1N1ThreadClusterN10
0
,
"wrong! blocksize and cluster size not consistent"
);
static_assert
(
ABlockDesc
{}.
GetLength
(
I0
)
==
BBlockDesc
{}.
GetLength
(
I0
),
static_assert
(
A
KM
BlockDesc
{}.
GetLength
(
I0
)
==
B
KN
BlockDesc
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
// TODO: remove this restriction
static_assert
(
ABlockDesc
{}.
GetLength
(
I1
)
==
2
&&
BBlockDesc
{}.
GetLength
(
I1
)
==
2
&&
CThreadDesc
{}.
GetLength
(
I0
)
==
2
&&
CThreadDesc
{}.
GetLength
(
I2
)
==
2
,
static_assert
(
AKMBlockDesc
{}.
GetLength
(
I1
)
==
2
&&
BKNBlockDesc
{}.
GetLength
(
I1
)
==
2
&&
CM0M1N0N1ThreadDesc
{}.
GetLength
(
I0
)
==
2
&&
CM0M1N0N1ThreadDesc
{}.
GetLength
(
I2
)
==
2
,
"wrong"
);
}
__device__
static
CIndex
CalculateCThreadOriginDataIndex
(
index_t
thread_id
)
{
constexpr
index_t
M0
=
ABlockDesc
{}.
GetLength
(
I1
);
constexpr
index_t
N0
=
BBlockDesc
{}.
GetLength
(
I1
);
constexpr
index_t
M1
=
ABlockDesc
{}.
GetLength
(
I2
);
constexpr
index_t
N1
=
BBlockDesc
{}.
GetLength
(
I2
);
constexpr
index_t
M0
=
A
KM
BlockDesc
{}.
GetLength
(
I1
);
constexpr
index_t
N0
=
B
KN
BlockDesc
{}.
GetLength
(
I1
);
constexpr
index_t
M1
=
A
KM
BlockDesc
{}.
GetLength
(
I2
);
constexpr
index_t
N1
=
B
KN
BlockDesc
{}.
GetLength
(
I2
);
// 4-d data space into 4-d thread space
constexpr
auto
adaptor0
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_vectorize_transform
(
M0
,
1
),
make_vectorize_transform
(
M1PerThread
,
M1
/
M1PerThread
),
make_vectorize_transform
(
M1PerThread
M11
,
M1
/
M1PerThread
M11
),
make_vectorize_transform
(
N0
,
1
),
make_vectorize_transform
(
N1PerThread
,
N1
/
N1PerThread
)),
make_vectorize_transform
(
N1PerThread
N11
,
N1
/
N1PerThread
N11
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
...
...
@@ -97,18 +99,18 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
constexpr
auto
adaptor1
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_freeze_transform
(
make_multi_index
(
0
)),
make_unmerge_transform
(
make_tuple
(
M1N1ThreadClusterM10
,
M1N1ThreadClusterM11
)),
make_unmerge_transform
(
make_tuple
(
M1N1ThreadClusterM10
0
,
M1N1ThreadClusterM1
0
1
)),
make_freeze_transform
(
make_multi_index
(
0
)),
make_unmerge_transform
(
make_tuple
(
M1N1ThreadClusterN10
,
M1N1ThreadClusterN11
))),
make_unmerge_transform
(
make_tuple
(
M1N1ThreadClusterN10
0
,
M1N1ThreadClusterN1
0
1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
,
1
>
{},
Sequence
<>
{},
Sequence
<
2
,
3
>
{}));
// 4-d thread space to 1-d thread space
constexpr
auto
adaptor2
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M1N1ThreadClusterM10
,
M1N1ThreadClusterN10
,
M1N1ThreadClusterM11
,
M1N1ThreadClusterN11
))),
make_tuple
(
make_merge_transform
(
make_tuple
(
M1N1ThreadClusterM10
0
,
M1N1ThreadClusterN10
0
,
M1N1ThreadClusterM1
0
1
,
M1N1ThreadClusterN1
0
1
))),
make_tuple
(
Sequence
<
0
,
2
,
1
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
...
...
@@ -133,15 +135,15 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
FloatC
,
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
CThreadDesc
,
C
M0M1N0N1
ThreadDesc
,
Sequence
<
KPerThread
>
,
Sequence
<
1
,
M1PerThread
>
,
Sequence
<
1
,
N1PerThread
>>
{};
Sequence
<
1
,
M1PerThread
M11
>
,
Sequence
<
1
,
N1PerThread
N11
>>
{};
constexpr
index_t
K
=
ABlockDesc
{}.
GetLength
(
I0
);
constexpr
index_t
K
=
A
KM
BlockDesc
{}.
GetLength
(
I0
);
// read A_sub_0
a_thread_copy_
.
Run
(
ABlockDesc
{},
a_thread_copy_
.
Run
(
A
KM
BlockDesc
{},
make_tuple
(
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
...
...
@@ -149,7 +151,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
a_thread_buf
);
// read B_sub_0
b_thread_copy_
.
Run
(
BBlockDesc
{},
b_thread_copy_
.
Run
(
B
KN
BlockDesc
{},
make_tuple
(
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
...
...
@@ -157,7 +159,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
b_thread_buf
);
// read B_sub_1
b_thread_copy_
.
Run
(
BBlockDesc
{},
b_thread_copy_
.
Run
(
B
KN
BlockDesc
{},
make_tuple
(
I0
,
I1
,
I0
),
b_block_buf
,
b_thread_desc_
,
...
...
@@ -165,7 +167,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
ABlockDesc
{},
a_thread_copy_
.
Run
(
A
KM
BlockDesc
{},
make_tuple
(
I0
,
I1
,
I0
),
a_block_buf
,
a_thread_desc_
,
...
...
@@ -191,7 +193,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
// loop over rest of k
static_for
<
KPerThread
,
K
,
KPerThread
>
{}([
&
](
auto
k
)
{
// read A_sub_0
a_thread_copy_
.
Run
(
ABlockDesc
{},
a_thread_copy_
.
Run
(
A
KM
BlockDesc
{},
make_tuple
(
k
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
...
...
@@ -207,7 +209,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
make_tuple
(
I1
,
I0
,
I0
,
I0
));
// read B_sub_0
b_thread_copy_
.
Run
(
BBlockDesc
{},
b_thread_copy_
.
Run
(
B
KN
BlockDesc
{},
make_tuple
(
k
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
...
...
@@ -223,7 +225,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
make_tuple
(
I1
,
I0
,
I1
,
I0
));
// read B_sub_1
b_thread_copy_
.
Run
(
BBlockDesc
{},
b_thread_copy_
.
Run
(
B
KN
BlockDesc
{},
make_tuple
(
k
,
I1
,
I0
),
b_block_buf
,
b_thread_desc_
,
...
...
@@ -231,7 +233,7 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
ABlockDesc
{},
a_thread_copy_
.
Run
(
A
KM
BlockDesc
{},
make_tuple
(
k
,
I1
,
I0
),
a_block_buf
,
a_thread_desc_
,
...
...
@@ -273,37 +275,37 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
}
private:
static
constexpr
index_t
M0_
=
ABlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
N0_
=
BBlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
M0_
=
A
KM
BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
N0_
=
B
KN
BlockDesc
{}.
GetLength
(
I1
);
// A[K, M0, M1]
static
constexpr
auto
a_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
M0_
>
{},
Number
<
M1PerThread
>
{}));
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
M0_
>
{},
Number
<
M1PerThread
M11
>
{}));
// B[K, N0, N1]
static
constexpr
auto
b_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
N0_
>
{},
Number
<
N1PerThread
>
{}));
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
N0_
>
{},
Number
<
N1PerThread
N11
>
{}));
using
AThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
ABlockDesc
,
A
KM
BlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
KPerThread
,
1
,
M1PerThread
>
,
Sequence
<
KPerThread
,
1
,
M1PerThread
M11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
AThreadCopyScalarPerVector_M1
,
AThreadCopyScalarPerVector_M1
1
,
1
>
;
using
BThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
BBlockDesc
,
B
KN
BlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
KPerThread
,
1
,
N1PerThread
>
,
Sequence
<
KPerThread
,
1
,
N1PerThread
N11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
BThreadCopyScalarPerVector_N1
,
BThreadCopyScalarPerVector_N1
1
,
1
>
;
CIndex
c_thread_origin_data_idx_
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment