gaoqiong / composable_kernel / Commits / 4b21c0fd

Commit 4b21c0fd, authored May 30, 2021 by Chao Liu

overhauling fwd-v4r4

parent a25f992d
Showing 3 changed files with 146 additions and 128 deletions:

composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp (+96, -55)
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp (+44, -73)
composable_kernel/include/utility/sequence_helper.hpp (+6, -0)
composable_kernel/include/tensor_operation/blockwise_gemm_v2r2.hpp
...
@@ -14,7 +14,7 @@ namespace ck {
 //     1. AKMBlockDesc is known at compile-time
 //     2. ABlockBuffer is DynamicBuffer
 // 2. B:
-//     1. AKMBlockDesc is known at compile-time
+//     1. BKNBlockDesc is known at compile-time
 //     2. BBlockBuffer is DynamicBuffer
 // 3. C:
 //     1. CM0M1N0N1ThreadDesc is known at compile-time
...
@@ -27,7 +27,6 @@ template <index_t BlockSize,
           typename FloatC,
           typename AKMBlockDesc,
           typename BKNBlockDesc,
-          typename CM0M1N0N1ThreadDesc,
           index_t M1PerThreadM11,
           index_t N1PerThreadN11,
           index_t KPerThread,
...
@@ -38,10 +37,9 @@ template <index_t BlockSize,
           index_t AThreadCopyScalarPerVector_M11,
           index_t BThreadCopyScalarPerVector_N11,
           typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
-                                      BKNBlockDesc::IsKnownAtCompileTime() &&
-                                      CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
+                                      BKNBlockDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
-struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
+struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
 {
     using AIndex = MultiIndex<3>;
     using BIndex = MultiIndex<3>;
...
@@ -52,40 +50,76 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

+    static constexpr index_t K = AKMBlockDesc{}.GetLength(I0);
+    static constexpr index_t M = AKMBlockDesc{}.GetLength(I1);
+    static constexpr index_t N = BKNBlockDesc{}.GetLength(I1);
+
+    static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
+    static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;
+
+    static constexpr index_t M0 = M / M1;
+    static constexpr index_t N0 = N / N1;
+
+    __host__ __device__ static constexpr auto
+    MakeAKM0M1BlockDescriptor(const AKMBlockDesc& a_k_m_block_desc)
+    {
+        const auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
+            AKMBlockDesc{},
+            make_tuple(make_pass_through_transform(Number<K>{}),
+                       make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+
+        return a_k_m0_m1_block_desc;
+    }
+
+    __host__ __device__ static constexpr auto
+    MakeBKN0N1BlockDescriptor(const BKNBlockDesc& n_k_n_block_desc)
+    {
+        const auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
+            BKNBlockDesc{},
+            make_tuple(make_pass_through_transform(Number<K>{}),
+                       make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{}))),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
+
+        return b_k_n0_n1_block_desc;
+    }
+
+    __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
+    {
+        return Sequence<M0, M1PerThreadM11, N0, N1PerThreadN11>{};
+    }
+
+    static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{});
+    static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{});

     public:
-    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2()
+    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
         : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
           a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
           b_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
     {
-        static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime() &&
-                          CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
+        static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

         static_assert(BlockSize == M1N1ThreadClusterM101 * M1N1ThreadClusterM100 *
                                        M1N1ThreadClusterN101 * M1N1ThreadClusterN100,
                       "wrong! blocksize and cluster size not consistent");

+        static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");
+
         static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0),
                       "wrong! K dimension not consistent");

         // TODO: remove this restriction
-        static_assert(AKMBlockDesc{}.GetLength(I1) == 2 && BKNBlockDesc{}.GetLength(I1) == 2 &&
-                          CM0M1N0N1ThreadDesc{}.GetLength(I0) == 2 &&
-                          CM0M1N0N1ThreadDesc{}.GetLength(I2) == 2,
-                      "wrong");
+        static_assert(M0 == 2 && N0 == 2, "wrong");
     }

     __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
     {
-        constexpr index_t M0 = AKMBlockDesc{}.GetLength(I1);
-        constexpr index_t N0 = BKNBlockDesc{}.GetLength(I1);
-        constexpr index_t M1 = AKMBlockDesc{}.GetLength(I2);
-        constexpr index_t N1 = BKNBlockDesc{}.GetLength(I2);

         // 4-d data space into 4-d thread space
         constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
             make_tuple(make_vectorize_transform(M0, 1),
...
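The new MakeAKM0M1BlockDescriptor / MakeBKN0N1BlockDescriptor helpers build the K x M0 x M1 (resp. K x N0 x N1) view from the plain K x M (K x N) block descriptor: make_unmerge_transform splits M into an outer M0 and an inner M1 with M = M0 * M1. A standalone sketch of the index arithmetic that unmerge encodes, with hypothetical sizes and none of the library's descriptor machinery:

#include <cassert>
#include <cstdio>

// Hypothetical sizes for illustration only.
constexpr int M1 = 64;     // M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11
constexpr int M  = 128;
constexpr int M0 = M / M1; // == 2, matching static_assert(M0 == 2 && N0 == 2, ...)

// (k, m) in the K x M view maps to (k, m0, m1) in the K x M0 x M1 view.
void unmerge(int m, int& m0, int& m1)
{
    m0 = m / M1; // outer index
    m1 = m % M1; // inner index
}

int main()
{
    int m0, m1;
    unmerge(70, m0, m1);
    assert(m0 == 1 && m1 == 6);
    std::printf("m = 70 -> (m0 = %d, m1 = %d)\n", m0, m1);
}

The pass-through transform on K and the Sequence<...> arguments in the hunk say which source dimension feeds which transform and where the resulting dimensions land (K -> 0, M -> 1 and 2).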
@@ -119,58 +153,68 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
         return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
     }

-    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
-    __device__ void Run(const ABlockBuffer& a_block_buf,
+    template <typename CM0M1N0N1ThreadDesc,
+              typename ABlockBuffer,
+              typename BBlockBuffer,
+              typename CThreadBuffer>
+    __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
+                        const ABlockBuffer& a_block_buf,
                         const BBlockBuffer& b_block_buf,
                         CThreadBuffer& c_thread_buf) const
     {
-        auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
-            a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
-            b_thread_desc_.GetElementSpaceSize());
+        static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        // TODO: remove this restriction
+        static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
+                          CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
+                      "wrong");
+
+        auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
+            a_k_m0_m1_thread_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
+            b_k_n0_n1_thread_desc_.GetElementSpaceSize());

         constexpr auto threadwise_gemm =
             ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
                                                 FloatB,
                                                 FloatC,
-                                                decltype(a_thread_desc_),
-                                                decltype(b_thread_desc_),
+                                                decltype(a_k_m0_m1_thread_desc_),
+                                                decltype(b_k_n0_n1_thread_desc_),
                                                 CM0M1N0N1ThreadDesc,
                                                 Sequence<KPerThread>,
                                                 Sequence<1, M1PerThreadM11>,
                                                 Sequence<1, N1PerThreadN11>>{};

-        constexpr index_t K = AKMBlockDesc{}.GetLength(I0);
-
         // read A_sub_0
-        a_thread_copy_.Run(AKMBlockDesc{},
+        a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                            make_tuple(I0, I0, I0),
                            a_block_buf,
-                           a_thread_desc_,
+                           a_k_m0_m1_thread_desc_,
                            make_tuple(I0, I0, I0),
                            a_thread_buf);

         // read B_sub_0
-        b_thread_copy_.Run(BKNBlockDesc{},
+        b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                            make_tuple(I0, I0, I0),
                            b_block_buf,
-                           b_thread_desc_,
+                           b_k_n0_n1_thread_desc_,
                            make_tuple(I0, I0, I0),
                            b_thread_buf);

         // read B_sub_1
-        b_thread_copy_.Run(BKNBlockDesc{},
+        b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                            make_tuple(I0, I1, I0),
                            b_block_buf,
-                           b_thread_desc_,
+                           b_k_n0_n1_thread_desc_,
                            make_tuple(I0, I1, I0),
                            b_thread_buf);

         // read A_sub_1
-        a_thread_copy_.Run(AKMBlockDesc{},
+        a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                            make_tuple(I0, I1, I0),
                            a_block_buf,
-                           a_thread_desc_,
+                           a_k_m0_m1_thread_desc_,
                            make_tuple(I0, I1, I0),
                            a_thread_buf);
...
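The signature change above is the core of the overhaul: CM0M1N0N1ThreadDesc moves from a class template parameter (see the first hunk) to a template parameter of Run(), so one blockwise GEMM type no longer bakes in the C thread layout. A minimal self-contained sketch of that refactoring pattern, with stand-in names that are not the library's:

#include <cstdio>

// Dummy compile-time descriptor for the sketch.
struct CThreadDescSketch
{
    static constexpr bool IsKnownAtCompileTime() { return true; }
};

template <typename ABlockDesc, typename BBlockDesc>
struct BlockwiseGemmSketch
{
    // The C-side descriptor is deduced per call instead of per type,
    // and validated the same way the real Run() does.
    template <typename CThreadDesc, typename CBuf>
    void Run(const CThreadDesc&, CBuf& c_buf) const
    {
        static_assert(CThreadDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");
        c_buf[0] += 1; // placeholder for the GEMM body
    }
};

int main()
{
    BlockwiseGemmSketch<int, int> gemm; // placeholder descriptor types
    int c[1] = {0};
    gemm.Run(CThreadDescSketch{}, c);
    std::printf("c[0] = %d\n", c[0]);
}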
@@ -193,10 +237,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
         // loop over rest of k
         static_for<KPerThread, K, KPerThread>{}([&](auto k) {
             // read A_sub_0
-            a_thread_copy_.Run(AKMBlockDesc{},
+            a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                                make_tuple(k, I0, I0),
                                a_block_buf,
-                               a_thread_desc_,
+                               a_k_m0_m1_thread_desc_,
                                make_tuple(I0, I0, I0),
                                a_thread_buf);
...
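In the hunk above, static_for<KPerThread, K, KPerThread> unrolls the k-loop at compile time, handing each k to the lambda as a compile-time constant so it can be used in make_tuple(k, I0, I0). A minimal re-implementation sketch of that pattern (illustrative, not the library's static_for):

#include <cstdio>
#include <utility>

template <int Begin, int End, int Inc, typename F>
constexpr void static_for_sketch(F&& f)
{
    if constexpr(Begin < End)
    {
        f(std::integral_constant<int, Begin>{}); // k is a compile-time constant
        static_for_sketch<Begin + Inc, End, Inc>(std::forward<F>(f));
    }
}

int main()
{
    // Mirrors static_for<KPerThread, K, KPerThread>{} with hypothetical
    // KPerThread = 2 and K = 8: visits k = 2, 4, 6.
    static_for_sketch<2, 8, 2>([](auto k) { std::printf("k = %d\n", k()); });
}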
@@ -209,10 +253,10 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
                                make_tuple(I1, I0, I0, I0));

             // read B_sub_0
-            b_thread_copy_.Run(BKNBlockDesc{},
+            b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                                make_tuple(k, I0, I0),
                                b_block_buf,
-                               b_thread_desc_,
+                               b_k_n0_n1_thread_desc_,
                                make_tuple(I0, I0, I0),
                                b_thread_buf);
...
@@ -225,18 +269,18 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
                                make_tuple(I1, I0, I1, I0));

             // read B_sub_1
-            b_thread_copy_.Run(BKNBlockDesc{},
+            b_thread_copy_.Run(b_k_n0_n1_block_desc_,
                                make_tuple(k, I1, I0),
                                b_block_buf,
-                               b_thread_desc_,
+                               b_k_n0_n1_thread_desc_,
                                make_tuple(I0, I1, I0),
                                b_thread_buf);

             // read A_sub_1
-            a_thread_copy_.Run(AKMBlockDesc{},
+            a_thread_copy_.Run(a_k_m0_m1_block_desc_,
                                make_tuple(k, I1, I0),
                                a_block_buf,
-                               a_thread_desc_,
+                               a_k_m0_m1_thread_desc_,
                                make_tuple(I0, I1, I0),
                                a_thread_buf);
...
@@ -275,22 +319,19 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
     }

     private:
-    static constexpr index_t M0_ = AKMBlockDesc{}.GetLength(I1);
-    static constexpr index_t N0_ = BKNBlockDesc{}.GetLength(I1);
-
     // A[K, M0, M1]
-    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThreadM11>{}));
+    static constexpr auto a_k_m0_m1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}));

     // B[K, N0, N1]
-    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
-        make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThreadN11>{}));
+    static constexpr auto b_k_n0_n1_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}));

     using AThreadCopy =
         ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                 FloatA,
-                                                AKMBlockDesc,
-                                                decltype(a_thread_desc_),
+                                                decltype(a_k_m0_m1_block_desc_),
+                                                decltype(a_k_m0_m1_thread_desc_),
                                                 Sequence<KPerThread, 1, M1PerThreadM11>,
                                                 Sequence<0, 1, 2>,
                                                 2,
...
@@ -300,8 +341,8 @@ struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2
     using BThreadCopy =
         ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                 FloatB,
-                                                BKNBlockDesc,
-                                                decltype(b_thread_desc_),
+                                                decltype(b_k_n0_n1_block_desc_),
+                                                decltype(b_k_n0_n1_thread_desc_),
                                                 Sequence<KPerThread, 1, N1PerThreadN11>,
                                                 Sequence<0, 1, 2>,
                                                 2,
...
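The renamed a_k_m0_m1_thread_desc_ / b_k_n0_n1_thread_desc_ above are "packed" descriptors: dense row-major layouts whose element-space size is the product of the lengths. A plain-arithmetic sketch of what such a 3-d packed descriptor implies (hypothetical lengths, not the library's implementation):

#include <cassert>

struct PackedDesc3
{
    int l0, l1, l2; // lengths
    constexpr int ElementSpaceSize() const { return l0 * l1 * l2; }
    constexpr int Offset(int i0, int i1, int i2) const
    {
        return (i0 * l1 + i1) * l2 + i2; // packed row-major strides
    }
};

int main()
{
    // Hypothetical KPerThread = 4, M0 = 2, M1PerThreadM11 = 4.
    constexpr PackedDesc3 a_desc{4, 2, 4};
    static_assert(a_desc.ElementSpaceSize() == 32, "packed size = product of lengths");
    assert(a_desc.Offset(1, 0, 3) == 11); // (1 * 2 + 0) * 4 + 3
}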
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v1r2.hpp
...
@@ -283,59 +283,27 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         // b_mtx[KPerBlocl, NPerBlock] is in LDS
         // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //   register
-        // sanity check
-        static_assert(MPerBlock % (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10) == 0 &&
-                          NPerBlock % (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10) == 0,
-                      "wrong!");
-
-        constexpr index_t M0PerThread =
-            MPerBlock / (M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10);
-        constexpr index_t N0PerThread =
-            NPerBlock / (N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10);
-
-        constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
-            a_k_m_block_desc,
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
-                       make_unmerge_transform(make_tuple(
-                           Number<M0PerThread>{},
-                           Number<M1PerThread * M1N1ThreadClusterM11 * M1N1ThreadClusterM10>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
-
-        constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
-            b_k_n_block_desc,
-            make_tuple(make_pass_through_transform(Number<KPerBlock>{}),
-                       make_unmerge_transform(make_tuple(
-                           Number<N0PerThread>{},
-                           Number<N1PerThread * M1N1ThreadClusterN11 * M1N1ThreadClusterN10>{}))),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
-
-        constexpr auto c_m0_m1_n0_n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<M0PerThread>{},
-                       Number<M1PerThread>{},
-                       Number<N0PerThread>{},
-                       Number<N1PerThread>{}));
-
         const auto blockwise_gemm =
-            BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
-                                                                 FloatAB,
-                                                                 FloatAB,
-                                                                 FloatAcc,
-                                                                 decltype(a_k_m0_m1_block_desc),
-                                                                 decltype(b_k_n0_n1_block_desc),
-                                                                 decltype(c_m0_m1_n0_n1_thread_desc),
-                                                                 M1PerThread,
-                                                                 N1PerThread,
-                                                                 KPerThread,
-                                                                 M1N1ThreadClusterM10,
-                                                                 M1N1ThreadClusterN10,
-                                                                 M1N1ThreadClusterM11,
-                                                                 M1N1ThreadClusterN11,
-                                                                 M1PerThread,
-                                                                 N1PerThread>{};
+            BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
+                                                           FloatAB,
+                                                           FloatAB,
+                                                           FloatAcc,
+                                                           decltype(a_k_m_block_desc),
+                                                           decltype(b_k_n_block_desc),
+                                                           M1PerThread,
+                                                           N1PerThread,
+                                                           KPerThread,
+                                                           M1N1ThreadClusterM10,
+                                                           M1N1ThreadClusterN10,
+                                                           M1N1ThreadClusterM11,
+                                                           M1N1ThreadClusterN11,
+                                                           M1PerThread,
+                                                           N1PerThread>{};
+
+        constexpr auto c_m0_m1_n0_n1_thread_tensor_lengths =
+            decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
+
+        constexpr auto c_m0_m1_n0_n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            sequence_to_tuple_of_number(c_m0_m1_n0_n1_thread_tensor_lengths));

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
...
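With the blockwise GEMM now owning the M0/M1/N0/N1 arithmetic, the gridwise kernel derives the C thread tile from GetCM0M1N0N1ThreadTensorLengths() instead of recomputing M0PerThread/N0PerThread itself. A self-contained sketch of that single-source-of-truth pattern, with stand-in types that are not the library's:

#include <cstddef>

template <int... Is> struct Sequence {};

struct BlockwiseGemmSketch
{
    static constexpr int M0 = 2, M1PerThread = 4, N0 = 2, N1PerThread = 4; // hypothetical
    static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
    {
        return Sequence<M0, M1PerThread, N0, N1PerThread>{};
    }
};

// The consumer derives its buffer size from those lengths instead of
// duplicating the arithmetic that produced them.
template <int... Ls>
constexpr std::size_t element_count(Sequence<Ls...>) { return (std::size_t{1} * ... * Ls); }

static_assert(element_count(BlockwiseGemmSketch::GetCM0M1N0N1ThreadTensorLengths()) == 64,
              "2 * 4 * 2 * 4");

int main() {}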
@@ -351,10 +319,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
         auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
             c_m0_m1_n0_n1_thread_desc.GetElementSpaceSize());

-        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
-                                           decltype(c_m0_m1_n0_n1_thread_desc),
-                                           Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>>{}
+        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+                                           decltype(c_m0_m1_n0_n1_thread_desc),
+                                           decltype(c_m0_m1_n0_n1_thread_tensor_lengths)>{}
             .Run(c_m0_m1_n0_n1_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
...
@@ -415,7 +382,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                 b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+            blockwise_gemm.Run(
+                c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store next data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
...
@@ -438,7 +406,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
                 b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+            blockwise_gemm.Run(
+                c_m0_m1_n0_n1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

             // LDS double buffer: store next data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_even_buf);
...
@@ -465,7 +434,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             b_blockwise_copy.RunRead(b_k_n_global_desc, b_global_buf, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+            blockwise_gemm.Run(
+                c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, a_block_odd_buf);
...
@@ -474,14 +444,16 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             __syncthreads();

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
+            blockwise_gemm.Run(
+                c_m0_m1_n0_n1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
         }
         else // if has 1 iteration left
         {
             __syncthreads();

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
+            blockwise_gemm.Run(
+                c_m0_m1_n0_n1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
         }

         // output: register to global memory
...
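The Run-call hunks above all follow the "LDS double buffer" schedule named in the comments: compute on the tile already resident in one LDS buffer while the next K-slice is written into the other, then swap. A host-side sketch of that ping-pong schedule, with printouts standing in for device work (buffer names and trip count are illustrative):

#include <cstdio>

void gemm_on(const char* buf)   { std::printf("GEMM on %s buffer\n", buf); }
void load_into(const char* buf) { std::printf("load next K-slice into %s buffer\n", buf); }

int main()
{
    const int num_k_iter = 6; // hypothetical number of K-block iterations
    bool even = true;

    // Main loop: compute on one buffer while filling the other, then swap.
    for(int i = 0; i < num_k_iter - 1; ++i)
    {
        gemm_on(even ? "even" : "odd");
        load_into(even ? "odd" : "even");
        even = !even;
    }

    // Tail: the last tile is already resident, so finish with no further
    // loads (this is the even/odd split in the if/else above).
    gemm_on(even ? "even" : "odd");
}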
@@ -495,18 +467,17 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1r2
             const auto c_thread_data_idx_on_block =
                 blockwise_gemm.CalculateCThreadOriginDataIndex(get_thread_local_1d_id());

-            ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatAcc,
-                                                      FloatC,
-                                                      decltype(c_m0_m1_n0_n1_thread_desc),
-                                                      decltype(c_m0_m1_n0_n1_global_desc),
-                                                      Sequence<M0PerThread, M1PerThread, N0PerThread, N1PerThread>,
-                                                      CThreadTransferSrcDstAccessOrder,
-                                                      CThreadTransferSrcDstVectorDim,
-                                                      CThreadTransferDstScalarPerVector,
-                                                      CGlobalMemoryDataOperation,
-                                                      1,
-                                                      true>{
+            ThreadwiseDynamicTensorSliceTransfer_v1r3<FloatAcc,
+                                                      FloatC,
+                                                      decltype(c_m0_m1_n0_n1_thread_desc),
+                                                      decltype(c_m0_m1_n0_n1_global_desc),
+                                                      decltype(c_m0_m1_n0_n1_thread_tensor_lengths),
+                                                      CThreadTransferSrcDstAccessOrder,
+                                                      CThreadTransferSrcDstVectorDim,
+                                                      CThreadTransferDstScalarPerVector,
+                                                      CGlobalMemoryDataOperation,
+                                                      1,
+                                                      true>{
                 c_m0_m1_n0_n1_global_desc,
                 make_multi_index(m_block_data_idx_on_global / M1 + c_thread_data_idx_on_block[I0],
                                  c_thread_data_idx_on_block[I1],
...
composable_kernel/include/utility/sequence_helper.hpp
...
@@ -26,5 +26,11 @@ __host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
                                typename arithmetic_sequence_gen<0, N, 1>::type{});
 }

+template <index_t... Is>
+__host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
+{
+    return Sequence<Is...>{};
+}
+
 } // namespace ck
 #endif
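The new to_sequence helper is the inverse of sequence_to_tuple_of_number used in gridwise_dynamic_gemm_v1r2.hpp above: it lifts the Number<> arguments of a Tuple back into a Sequence by pack deduction. A self-contained sketch with minimal stand-in types (not the library's headers) showing the deduction:

#include <type_traits>

using index_t = int;
template <index_t... Is> struct Sequence {};
template <index_t N> struct Number {};
template <typename... Ts> struct Tuple {};

template <index_t... Is>
constexpr auto to_sequence(Tuple<Number<Is>...>)
{
    return Sequence<Is...>{}; // Is... deduced from the tuple's Number<> elements
}

static_assert(std::is_same_v<decltype(to_sequence(Tuple<Number<2>, Number<4>>{})),
                             Sequence<2, 4>>);

int main() {}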