gaoqiong / composable_kernel / Commits / d075adf1

Unverified commit d075adf1, authored Apr 28, 2021 by Chao Liu, committed by GitHub on Apr 28, 2021.
Parent: e4790c25

Use Tuple and vector_type instead of Array for holding tensor data (#30)

* replacing array with tuple and vector for tensor data

Showing 16 changed files with 1316 additions and 661 deletions (+1316, -661):
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp (+196, -190)
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp (+81, -93)
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp (+42, -31)
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp (+43, -56)
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp (+59, -0)
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp (+279, -56)
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp (+95, -123)
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp (+107, -82)
composable_kernel/include/utility/amd_inline_asm.hpp (+69, -0)
composable_kernel/include/utility/buffer.hpp (+72, -0)
composable_kernel/include/utility/common_header.hpp (+1, -0)
composable_kernel/include/utility/config.amd.hpp.in (+3, -3)
composable_kernel/include/utility/float_type.amd.hpp.in (+237, -9)
composable_kernel/include/utility/sequence_helper.hpp (+15, -0)
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp (+2, -2)
driver/src/conv_driver.cpp (+15, -16)
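The common thread through these files is replacing raw C arrays that are indexed at run time (e.g. FloatAcc p_c_thread[...]) with statically indexed buffer types built on Tuple/vector_type, so that every element offset is a compile-time constant and the per-thread tensor data can stay in registers instead of spilling to scratch. The following is a minimal standalone sketch of that idea, not the library's actual API; Number, MyStaticBuffer and static_for here are illustrative stand-ins.

// minimal sketch: a compile-time indexed register buffer vs. a plain C array
#include <cstdio>
#include <utility>

template <int I>
struct Number { static constexpr int value = I; };

template <typename T, int N>
struct MyStaticBuffer
{
    T data_[N];

    // element access with a compile-time index: the offset is a constant,
    // so the compiler can map each element to its own register
    template <int I>
    constexpr T& operator()(Number<I>) { return data_[I]; }

    template <int I>
    constexpr const T& operator[](Number<I>) const { return data_[I]; }
};

// compile-time unrolled loop, analogous in spirit to static_for
template <int N, typename F, int... Is>
constexpr void static_for_impl(F f, std::integer_sequence<int, Is...>)
{
    (f(Number<Is>{}), ...);
}

template <int N, typename F>
constexpr void static_for(F f)
{
    static_for_impl<N>(f, std::make_integer_sequence<int, N>{});
}

int main()
{
    MyStaticBuffer<float, 4> buf{};

    // every index is known at compile time, no runtime address arithmetic
    static_for<4>([&](auto i) { buf(i) = 2.0f * decltype(i)::value; });
    static_for<4>([&](auto i) { std::printf("%f\n", buf[i]); });
    return 0;
}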
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp

@@ -2,16 +2,27 @@
 #define CK_BLOCKWISE_GEMM_V2_HPP

 #include "common_header.hpp"
 #include "threadwise_dynamic_tensor_slice_transfer.hpp"
 #include "threadwise_gemm_v2.hpp"

 namespace ck {

-// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
+// blockwise GEMM:
+//   C[M, N] += transpose(A[K, M]) * B[K, N]
 // A and B are visable to the whole block, C is distributed among each thread
 // If following number are power of 2, index calculation shall be greatly reduced:
 //   MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
 //   MLevel1ThreadCluster, NLevel1ThreadCluster
+// Assume:
+//   1. A:
+//     1. BlockMatrixA is known at compile-time
+//     2. ABlockBuffer is DynamicBuffer
+//   2. B:
+//     1. BlockMatrixA is known at compile-time
+//     2. BBlockBuffer is DynamicBuffer
+//   3. C:
+//     1. ThreadMatrixC is known at compile-time
+//     2. CThreadBuffer is StaticBuffer
 template <index_t BlockSize,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
           typename BlockMatrixA,
           typename BlockMatrixB,
           typename ThreadMatrixC,

@@ -23,8 +34,12 @@ template <index_t BlockSize,
           index_t MLevel1ThreadCluster,
           index_t NLevel1ThreadCluster,
           index_t ThreadGemmADataPerRead_M,
-          index_t ThreadGemmBDataPerRead_N>
-struct BlockwiseGemm_km_kn_m0m1n0n1_v1
+          index_t ThreadGemmBDataPerRead_N,
+          typename std::enable_if<BlockMatrixA::IsKnownAtCompileTime() &&
+                                      BlockMatrixB::IsKnownAtCompileTime() &&
+                                      ThreadMatrixC::IsKnownAtCompileTime(),
+                                  bool>::type = false>
+struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
 {
     struct MatrixIndex
     {

@@ -32,10 +47,49 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
         index_t col;
     };

-    index_t mMyThreadOffsetA;
-    index_t mMyThreadOffsetB;
-
-    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1()
+    private:
+    static constexpr auto a_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<0>{})));
+
+    static constexpr auto b_thread_mtx_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadLoop>{}, ThreadMatrixC{}.GetLength(Number<1>{})));
+
+    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<
+        FloatA, FloatA, BlockMatrixA, decltype(a_thread_mtx_desc_),
+        Sequence<KPerThreadLoop, MPerThreadSubC>, Sequence<0, 1>, 1, ThreadGemmADataPerRead_M,
+        AddressSpace::Generic, AddressSpace::Vgpr, 1>;
+
+    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<
+        FloatB, FloatB, BlockMatrixB, decltype(b_thread_mtx_desc_),
+        Sequence<KPerThreadLoop, NPerThreadSubC>, Sequence<0, 1>, 1, ThreadGemmBDataPerRead_N,
+        AddressSpace::Generic, AddressSpace::Vgpr, 1>;
+
+    MatrixIndex c_thread_begin_mtx_idx_;
+
+    AThreadCopy a_thread_copy_;
+    BThreadCopy b_thread_copy_;
+
+    public:
+    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1r1()
+        : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
+          a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.row)},
+          b_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.col)}
     {
         static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
                       BlockMatrixB::IsKnownAtCompileTime() &&

@@ -51,23 +105,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
         static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

         static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
-                      "wrong! K dimension not consistent\n");
+                      "wrong! K dimension not consistent");

         constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
         constexpr index_t N = BlockMatrixB{}.GetLength(I1);

         static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
                           N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
-                      "wrong! Cannot evenly divide work among\n");
+                      "wrong! Cannot evenly divide work among");

         static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
                           ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
                       "wrong! ThreadMatrixC lengths is wrong");
-
-        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-        mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
-        mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
     }

     __device__ static constexpr auto GetThreadMatrixCLengths()

@@ -104,103 +153,30 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
                            level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
     }

-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ void
-    Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run_pipelined_2x2(const ABlockBuffer& a_block_buf,
+                                      const BBlockBuffer& b_block_buf,
+                                      CThreadBuffer& c_thread_buf) const
     {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-
-        constexpr auto a_block_mtx  = BlockMatrixA{};
-        constexpr auto b_block_mtx  = BlockMatrixB{};
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-        constexpr auto K          = a_block_mtx.GetLength(I0);
-        constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
-        constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
-
-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
-
-        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
-        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
-
-        // thread A, B for GEMM
-        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            Number<KPerThreadLoop>{}, Number<MPerThread>{});
-        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            Number<KPerThreadLoop>{}, Number<NPerThread>{});
-
-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
-        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<
-            BlockMatrixA, decltype(a_thread_mtx), KPerThreadLoop, MPerThreadSubC,
-            ThreadGemmADataPerRead_M>{};
-
-        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<
-            BlockMatrixB, decltype(b_thread_mtx), KPerThreadLoop, NPerThreadSubC,
-            ThreadGemmBDataPerRead_N>{};
-
-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<
-            decltype(a_thread_mtx), decltype(b_thread_mtx), decltype(c_thread_mtx)>{};
-
-#pragma unroll
-        // loop over k
-        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
-        {
-#pragma unroll
-            // read A
-            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
-            {
-                a_thread_copy.Run(
-                    p_a_block + a_block_mtx.CalculateOffset(make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) + mMyThreadOffsetA,
-                    p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, m_repeat * MPerThreadSubC)));
-            }
-
-#pragma unroll
-            // read B
-            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
-            {
-                b_thread_copy.Run(
-                    p_b_block + b_block_mtx.CalculateOffset(make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) + mMyThreadOffsetB,
-                    p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, n_repeat * NPerThreadSubC)));
-            }
-
-            // C += A * B
-            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
-        }
-    }
+        static_assert(
+            is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
+                    remove_cv_t<remove_reference_t<FloatA>>>::value &&
+                is_same<remove_cv_t<remove_reference_t<typename BBlockBuffer::type>>,
+                        remove_cv_t<remove_reference_t<FloatB>>>::value &&
+                is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
+                        remove_cv_t<remove_reference_t<FloatC>>>::value &&
+                "wrong! inconsistent type");

-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ void
-    Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
-    {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

-        constexpr auto a_block_mtx  = BlockMatrixA{};
-        constexpr auto b_block_mtx  = BlockMatrixB{};
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
+        constexpr auto a_block_mtx       = BlockMatrixA{};
+        constexpr auto b_block_mtx       = BlockMatrixB{};
+        constexpr auto c_thread_mtx_desc = ThreadMatrixC{};

         constexpr auto K          = a_block_mtx.GetLength(I0);
-        constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
-        constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
+        constexpr auto MPerThread = c_thread_mtx_desc.GetLength(I0);
+        constexpr auto NPerThread = c_thread_mtx_desc.GetLength(I1);

         constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;

@@ -211,15 +187,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
         constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
         constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

-        static_assert(MRepeat == 2 && NRepeat == 2,
-                      "wrong! inline asm cannot deal with this GEMM config yet");
-
-        // thread A, B
-        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
-        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
+        static_assert(MRepeat == 2 && NRepeat == 2, "wrong! only support 2x2 pipeline");

         // thread A-sub, B-sub
         constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(

@@ -234,113 +202,152 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
             make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
             make_tuple(Number<NPerThread>{}, Number<1>{}));

-        FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
-
-        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<
-            BlockMatrixA, decltype(a_thread_mtx), KPerThreadLoop, MPerThreadSubC,
-            ThreadGemmADataPerRead_M>{};
-
-        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<
-            BlockMatrixB, decltype(b_thread_mtx), KPerThreadLoop, NPerThreadSubC,
-            ThreadGemmBDataPerRead_N>{};
-
-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<
-            decltype(a_thread_sub_mtx), decltype(b_thread_sub_mtx), decltype(c_thread_sub_mtx)>{};
+        auto a_thread_buf = make_static_buffer<FloatA>(a_thread_mtx_desc_.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<FloatB>(b_thread_mtx_desc_.GetElementSpaceSize());

-        const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
-        const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
+        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<
+            FloatA, FloatB, FloatC, decltype(a_thread_sub_mtx), decltype(b_thread_sub_mtx),
+            decltype(c_thread_sub_mtx)>{};

         // read A_sub_0
-        a_thread_copy.Run(p_a_block_off, p_a_thread);
+        a_thread_copy_.Run(BlockMatrixA{}, make_tuple(I0, I0), a_block_buf,
+                           a_thread_mtx_desc_, make_tuple(I0, I0), a_thread_buf);

         // read B_sub_0
-        b_thread_copy.Run(p_b_block_off, p_b_thread);
+        b_thread_copy_.Run(BlockMatrixB{}, make_tuple(I0, I0), b_block_buf,
+                           b_thread_mtx_desc_, make_tuple(I0, I0), b_thread_buf);

         // read B_sub_1
-        b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
-                          p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+        b_thread_copy_.Run(BlockMatrixB{}, make_tuple(I0, Number<NPerLevel1Cluster>{}), b_block_buf,
+                           b_thread_mtx_desc_, make_tuple(I0, Number<NPerThreadSubC>{}), b_thread_buf);

         // read A_sub_1
-        a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
-                          p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
+        a_thread_copy_.Run(BlockMatrixA{}, make_tuple(I0, Number<MPerLevel1Cluster>{}), a_block_buf,
+                           a_thread_mtx_desc_, make_tuple(I0, Number<MPerThreadSubC>{}), a_thread_buf);

         // C_sub_00 += transpose(A_sub_0) * B_sub_0
-        threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
+        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0),
+                            b_thread_buf, make_tuple(I0, I0),
+                            c_thread_buf, make_tuple(I0, I0));

         // C_sub_01 += transpose(A_sub_0) * B_sub_1
-        threadwise_gemm.Run(p_a_thread,
-                            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0),
+                            b_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}),
                            c_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}));

-#pragma unroll
-        // loop over rest of k
-        for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
-        {
+        static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
             // read A_sub_0
-            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)), p_a_thread);
+            a_thread_copy_.Run(BlockMatrixA{}, make_tuple(k, I0), a_block_buf,
+                               a_thread_mtx_desc_, make_tuple(I0, I0), a_thread_buf);

             // C_sub_10 += transpose(A_sub_1) * B_sub_0
-            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                                p_b_thread,
-                                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
+            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, Number<MPerThreadSubC>{}),
+                                b_thread_buf, make_tuple(I0, I0),
+                                c_thread_buf, make_tuple(Number<MPerThreadSubC>{}, I0));

             // read B_sub_0
-            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)), p_b_thread);
+            b_thread_copy_.Run(BlockMatrixB{}, make_tuple(k, I0), b_block_buf,
+                               b_thread_mtx_desc_, make_tuple(I0, I0), b_thread_buf);

             // C_sub_11 += transpose(A_sub_1) * B_sub_1
-            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                                p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
+            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, Number<MPerThreadSubC>{}),
+                                b_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}),
+                                c_thread_buf, make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));

             // read B_sub_1
-            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
-                              p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+            b_thread_copy_.Run(BlockMatrixB{}, make_tuple(k, Number<NPerLevel1Cluster>{}), b_block_buf,
+                               b_thread_mtx_desc_, make_tuple(I0, Number<NPerThreadSubC>{}), b_thread_buf);

             // read A_sub_1
-            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
-                              p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
+            a_thread_copy_.Run(BlockMatrixA{}, make_tuple(k, Number<MPerLevel1Cluster>{}), a_block_buf,
+                               a_thread_mtx_desc_, make_tuple(I0, Number<MPerThreadSubC>{}), a_thread_buf);

             // C_sub_00 += transpose(A_sub_0) * B_sub_0
-            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
+            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0),
+                                b_thread_buf, make_tuple(I0, I0),
+                                c_thread_buf, make_tuple(I0, I0));

             // C_sub_01 += transpose(A_sub_0) * B_sub_1
-            threadwise_gemm.Run(p_a_thread,
-                                p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
-        }
+            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0),
+                                b_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}),
+                                c_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}));
+        });

         // C_sub_10 += transpose(A_sub_1) * B_sub_0
-        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                            p_b_thread,
-                            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
+        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, Number<MPerThreadSubC>{}),
+                            b_thread_buf, make_tuple(I0, I0),
+                            c_thread_buf, make_tuple(Number<MPerThreadSubC>{}, I0));

         // C_sub_11 += transpose(A_sub_1) * B_sub_1
-        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
+        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, Number<MPerThreadSubC>{}),
+                            b_thread_buf, make_tuple(I0, Number<NPerThreadSubC>{}),
+                            c_thread_buf, make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
     }

-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
+    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BBlockBuffer& b_block_buf,
+                        CThreadBuffer& c_thread_buf) const
     {
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
         constexpr auto I0 = Number<0>{};

@@ -354,17 +361,16 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1
         if constexpr(MRepeat == 2 && NRepeat == 2)
         {
-            Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
+            Run_pipelined_2x2(a_block_buf, b_block_buf, c_thread_buf);
         }
         else
         {
-            Run_naive(p_a_block, p_b_block, p_c_thread);
+            Run_naive(a_block_buf, b_block_buf, c_thread_buf);
         }
#else
-        Run_naive(p_a_block, p_b_block, p_c_thread);
+        Run_naive(a_block_buf, b_block_buf, c_thread_buf);
#endif
     }
 };

} // namespace ck
#endif
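The pipelined path above interleaves the A/B sub-tile reads with the four C sub-GEMMs of a 2x2 register tile. The following host-side sketch shows only the numerical scheduling of that 2x2 tile on plain std::array storage and checks it against a naive reference; the tile sizes and helper names are assumptions for illustration, not the kernel's real copy/GEMM objects.

// host sketch: 2x2 register-tile GEMM schedule, C[M,N] += A^T[K,M] * B[K,N]
#include <array>
#include <cassert>
#include <cstdio>

constexpr int K = 8, M = 2, N = 2; // 2x2 thread tile, sub-tile size 1

int main()
{
    std::array<float, K * M> a{}; // A[K][M]
    std::array<float, K * N> b{}; // B[K][N]
    std::array<float, M * N> c{}; // C[M][N], accumulated in "registers"

    for(int i = 0; i < K * M; ++i) a[i] = 0.01f * i;
    for(int i = 0; i < K * N; ++i) b[i] = 0.02f * i;

    float a_reg[M], b_reg[N]; // per-k register sub-tiles

    for(int k = 0; k < K; ++k)
    {
        // read A_sub_0 / B_sub_0 / B_sub_1 / A_sub_1 for this k-step
        a_reg[0] = a[k * M + 0];
        b_reg[0] = b[k * N + 0];
        b_reg[1] = b[k * N + 1];
        a_reg[1] = a[k * M + 1];

        // C_sub_00, C_sub_01, C_sub_10, C_sub_11 (the pipelined kernel
        // overlaps these with the reads of the next k-step)
        c[0 * N + 0] += a_reg[0] * b_reg[0];
        c[0 * N + 1] += a_reg[0] * b_reg[1];
        c[1 * N + 0] += a_reg[1] * b_reg[0];
        c[1 * N + 1] += a_reg[1] * b_reg[1];
    }

    // check against a naive reference (same accumulation order, so exact)
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float ref = 0;
            for(int k = 0; k < K; ++k) ref += a[k * M + m] * b[k * N + n];
            assert(ref == c[m * N + n]);
        }
    std::printf("ok\n");
}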
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp

@@ -6,12 +6,10 @@
 namespace ck {

-// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
-// A and B are visable to the whole block, C is distributed among each thread
-// If following number are power of 2, index calculation shall be greatly reduced:
-//   KPerThread, HPerThread, MLevel0ThreadCluster, NLevel0ThreadCluster,
-//   MLevel1ThreadCluster, NLevel1ThreadCluster
 template <index_t BlockSize,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
           typename BlockMatrixA,
           typename BlockMatrixB,
           typename ThreadMatrixC,

@@ -30,9 +28,34 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         index_t w;
     };

-    index_t mMyThreadOffsetA;
+    // HACK: fix this @Jing Zhang
+    static constexpr index_t KPerThreadSubC = 4;
+
+    static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
+
+    static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
+
+    static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
+
+    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<
+        FloatA, FloatA, BlockMatrixA, decltype(a_thread_mtx_),
+        Sequence<EPerThreadLoop, KPerThreadSubC>, Sequence<0, 1>, 1, ThreadGemmADataPerRead_K,
+        AddressSpace::Generic, AddressSpace::Vgpr, 1>;

     __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
+        : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
+          a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
     {
         static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
                       BlockMatrixB::IsKnownAtCompileTime() &&

@@ -61,11 +84,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
                       "wrong! wrong blocksize\n");
-
-        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-        mMyThreadOffsetA =
-            BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.k * KPerThread));
     }

     __device__ static constexpr auto GetThreadMatrixCLengths()

@@ -91,37 +109,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
     }

-    template <typename SrcDesc,
-              typename DstDesc,
-              index_t NSliceRow,
-              index_t NSliceCol,
-              index_t DataPerAccess>
-    struct ThreadwiseSliceCopy_a
-    {
-        template <typename Data>
-        __device__ static void Run(const Data* p_src, Data* p_dst)
-        {
-            static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
-                          "wrong! Desc should be known at compile-time");
-
-            using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
-
-            static_for<0, NSliceRow, 1>{}([&](auto i) {
-                static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
-                    constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
-                    constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
-
-                    *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
-                        *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
-                });
-            });
-        }
-    };
-
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ void
-    Run_naive(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
+    template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BThreadBuffer& b_thread_buf,
+                        CThreadBuffer& c_thread_buf) const
     {
+        static_assert(
+            is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
+                    remove_cv_t<remove_reference_t<FloatA>>>::value &&
+                is_same<remove_cv_t<remove_reference_t<typename BThreadBuffer::type>>,
+                        remove_cv_t<remove_reference_t<FloatB>>>::value &&
+                is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
+                        remove_cv_t<remove_reference_t<FloatC>>>::value &&
+                "wrong! inconsistent type");
+
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

@@ -132,8 +131,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         constexpr auto EPerBlock = a_block_mtx.GetLength(I0);

-        constexpr auto KPerThreadSubC = 4;
-
         // HACK: fix this @Jing Zhang
         constexpr auto HoPerThreadSubC = 2;
         constexpr auto WoPerThreadSubC = 2;

@@ -141,63 +139,53 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         static_assert(HPerThread % HoPerThreadSubC == 0, "");
         static_assert(WPerThread % WoPerThreadSubC == 0, "");

-        // thread A, B for GEMM
-        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
-
-        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-
-        constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-
-        FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
-
-        constexpr auto a_thread_copy = ThreadwiseSliceCopy_a<
-            BlockMatrixA, decltype(a_thread_mtx), EPerThreadLoop, KPerThreadSubC,
-            ThreadGemmADataPerRead_K>{};
-
-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
-                                                                    decltype(b_thread_mtx),
-                                                                    decltype(c_thread_mtx),
+        // thread A buffer for GEMM
+        StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
+
+        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
+                                                                    FloatB,
+                                                                    FloatC,
+                                                                    decltype(a_thread_mtx_),
+                                                                    decltype(b_thread_mtx_),
+                                                                    decltype(c_thread_mtx_),
                                                                     HoPerThreadSubC,
                                                                     WoPerThreadSubC>{};

-        // loop over k
-#pragma unroll
-        for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
-        {
-#pragma unroll
-            for(index_t k_begin = 0; k_begin < KPerThread; k_begin += KPerThreadSubC)
-            {
-                a_thread_copy.Run(
-                    p_a_block + a_block_mtx.CalculateOffset(make_tuple(e_begin, k_begin)) + mMyThreadOffsetA,
-                    p_a_thread);
-
-#pragma unroll
-                for(index_t h_begin = 0; h_begin < HPerThread; h_begin += HoPerThreadSubC)
-                {
-#pragma unroll
-                    for(index_t w_begin = 0; w_begin < WPerThread; w_begin += WoPerThreadSubC)
-                    {
-                        threadwise_gemm.Run(
-                            p_a_thread,
-                            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(e_begin, 0, h_begin, w_begin)),
-                            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(k_begin, 0, h_begin, w_begin)));
-                    }
-                }
-            }
-        }
+        static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
+            static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) {
+                a_thread_copy_.Run(a_block_mtx, make_tuple(e_begin, k_begin), a_block_buf,
+                                   a_thread_mtx_, make_tuple(I0, I0), a_thread_buf);
+
+                static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) {
+                    static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) {
+                        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0),
+                                            b_thread_buf, make_tuple(e_begin, I0, h_begin, w_begin),
+                                            c_thread_buf, make_tuple(k_begin, I0, h_begin, w_begin));
+                    });
+                });
+            });
+        });
     }

-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
+    template <typename ABlockSliceMoveStepIdx>
+    __device__ void MoveASliceWindow(const BlockMatrixA&,
+                                     const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
     {
-        Run_naive(p_a_block, p_b_thread, p_c_thread);
+        a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
     }

+    private:
+    MatrixIndex c_thread_begin_mtx_idx_;
+
+    AThreadCopy a_thread_copy_;
 };

} // namespace ck
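The new MoveASliceWindow member lets the caller advance the A read window instead of recomputing a pointer offset each outer iteration. Below is a small standalone sketch of that idea on a row-major host array; SliceWindow2D and its methods are illustrative names, not the library's.

// sketch: advance a slice window by a step instead of recomputing offsets
#include <cstdio>
#include <vector>

struct SliceWindow2D
{
    int origin_row = 0, origin_col = 0;

    void Move(int drow, int dcol) { origin_row += drow; origin_col += dcol; }

    // copy a rows x cols window out of a row-major src with leading dimension ld
    void Read(const std::vector<float>& src, int ld, float* dst, int rows, int cols) const
    {
        for(int r = 0; r < rows; ++r)
            for(int c = 0; c < cols; ++c)
                dst[r * cols + c] = src[(origin_row + r) * ld + (origin_col + c)];
    }
};

int main()
{
    const int ld = 8;
    std::vector<float> block(16 * ld);
    for(std::size_t i = 0; i < block.size(); ++i) block[i] = float(i);

    SliceWindow2D win;               // starts at (0, 0)
    float tile[2 * 4];

    for(int e = 0; e < 16; e += 2)   // walk the window down the E dimension
    {
        win.Read(block, ld, tile, 2, 4);
        std::printf("e=%2d first=%g\n", e, tile[0]);
        win.Move(2, 0);              // analogous to MoveASliceWindow(desc, {EPerBlock, 0})
    }
}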
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp

@@ -5,9 +5,10 @@
 #include "dynamic_multi_index_transform_helper.hpp"
 #include "dynamic_tensor_descriptor.hpp"
 #include "dynamic_tensor_descriptor_helper.hpp"
-#include "blockwise_gemm_v2.hpp"
 #include "blockwise_dynamic_tensor_slice_transfer.hpp"
 #include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "blockwise_gemm_v2.hpp"
+#include "threadwise_dynamic_tensor_slice_set.hpp"

 namespace ck {

@@ -256,19 +257,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
            make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));

-        const auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v1<
-            BlockSize, decltype(a_k_m_block_desc), decltype(b_k_n_block_desc),
-            decltype(c_m0m1_n0n1_thread_desc), MPerThread, NPerThread, KPerThread,
-            MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster,
-            MPerThread, NPerThread>{};
+        const auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v1r1<
+            BlockSize, FloatAB, FloatAB, FloatAcc, decltype(a_k_m_block_desc),
+            decltype(b_k_n_block_desc), decltype(c_m0m1_n0n1_thread_desc), MPerThread, NPerThread,
+            KPerThread, MLevel0Cluster, NLevel0Cluster, MLevel1Cluster, NLevel1Cluster,
+            MPerThread, NPerThread>{};

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =

@@ -281,10 +285,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-        FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
+        auto c_thread_buf = make_static_buffer<FloatAcc>(c_m0m1_n0n1_thread_desc.GetElementSpaceSize());

         // zero out threadwise output
-        threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
+        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+                                           decltype(c_m0m1_n0n1_thread_desc),
+                                           Sequence<MRepeat * MPerThread, NRepeat * NPerThread>>{}
+            .Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0});

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);

@@ -300,6 +307,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

+        FloatAB* p_a_block_even = p_a_block_double;
+        FloatAB* p_b_block_even = p_b_block_double;
+
+        FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
+        FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
+
+        auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
+        auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
+
+        auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
+        auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
+
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);

@@ -311,12 +330,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         if constexpr(HasMainKBlockLoop)
         {
-            FloatAB* p_a_block_even = p_a_block_double;
-            FloatAB* p_b_block_even = p_b_block_double;
-
-            FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
-            FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
-
             index_t k_block_data_begin = 0;

             // LDS double buffer: main body

@@ -340,7 +353,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                     b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
+                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);

@@ -363,7 +376,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                     b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
+                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);

@@ -390,7 +403,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);

@@ -399,16 +412,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             __syncthreads();

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double + a_block_space_size,
-                               p_b_block_double + b_block_space_size,
-                               p_c_thread);
+            blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
         }
         else // if has 1 iteration left
         {
             __syncthreads();

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
         }

         // output: register to global memory

@@ -461,7 +472,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 n_thread_data_on_global % N1))
             .Run(c_m0_m1_n0_n1_thread_desc,
                  make_tuple(I0, I0, I0, I0),
-                 p_c_thread,
+                 c_thread_buf,
                  c_m0_m1_n0_n1_global_desc,
                  p_c_global,
                  c_m0_m1_n0_n1_global_tensor_iterator_hacks);
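The even/odd buffer pairs above implement the usual LDS double-buffering (ping-pong) scheme: compute on one buffer while the next K tile is being written into the other. The following host-side sketch shows only the control flow of that scheme; the buffers, "load" and "gemm" are stand-ins, not the kernel's real copy and GEMM objects.

// host sketch of ping-pong (double) buffering over K tiles
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    const int k_tiles = 6, tile_elems = 16;

    // the "global" source, split into k_tiles tiles
    std::vector<float> global(k_tiles * tile_elems);
    std::iota(global.begin(), global.end(), 0.0f);

    std::vector<float> lds_even(tile_elems), lds_odd(tile_elems);
    double acc = 0.0;

    auto load = [&](std::vector<float>& lds, int tile)
    { std::copy_n(global.begin() + tile * tile_elems, tile_elems, lds.begin()); };
    auto gemm = [&](const std::vector<float>& lds)
    { for(float v : lds) acc += v; };           // stand-in for the block GEMM

    load(lds_even, 0);                           // preload first tile
    for(int t = 0; t < k_tiles - 1; ++t)
    {
        auto& cur  = (t % 2 == 0) ? lds_even : lds_odd;
        auto& next = (t % 2 == 0) ? lds_odd  : lds_even;
        load(next, t + 1);                       // overlapped with compute on the GPU
        gemm(cur);                               // compute on the current tile
    }
    gemm((k_tiles % 2 == 1) ? lds_even : lds_odd); // tail: last preloaded tile

    std::printf("acc = %f (expect %f)\n", acc, (global.size() - 1) * global.size() / 2.0);
}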
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp

@@ -145,17 +145,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
                Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

-        const auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<
-            BlockSize, decltype(a_e_k_block_desc), decltype(b_e_n_ho_wo_block_desc),
-            decltype(c_k_n_ho_wo_thread_desc), KPerThread, HoPerThread, WoPerThread, EPerThread,
-            ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K>{};
+        auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<
+            BlockSize, FloatAB, FloatAB, FloatAcc, decltype(a_e_k_block_desc),
+            decltype(b_e_n_ho_wo_block_desc), decltype(c_k_n_ho_wo_thread_desc), KPerThread,
+            HoPerThread, WoPerThread, EPerThread, ABlockTransferSrcScalarPerVector,
+            ABlockTransferDstScalarPerVector_K>{};

         auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

@@ -223,11 +225,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         FloatAB* p_a_block = p_shared_block;

+        auto a_block_buf = make_dynamic_buffer(p_a_block);
+
         // register allocation for output
-        FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
+        StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;

-        // zero out threadwise output
-        threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
+        // initialize output thread tensor
+        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+                                           decltype(c_k_n_ho_wo_thread_desc),
+                                           Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
+            .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

         constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);

@@ -242,12 +249,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

-        constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize();
-
-        FloatAB p_b_thread[b_thread_space_size * 2];
-
-        FloatAB* p_b_thread_double = p_b_thread;
+        // double regsiter buffer for b
+        StaticBuffer<FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_thread_even_buf,
+            b_thread_odd_buf;

-        // LDS double buffer: preload data into LDS
+        // LDS double buffer: preload data
         {
             a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);

@@ -255,7 +261,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                      p_b_global,
                                      b_e_n_ho_wo_thread_desc,
                                      make_tuple(I0, I0, I0, I0),
-                                     p_b_thread_double,
+                                     b_thread_even_buf,
                                      b_e_n_ho_wo_global_iterator_hacks);

             a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);

@@ -263,13 +269,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         __syncthreads();

-        index_t b_block_data_begin = 0;
-
-#if 1
         if constexpr(HasMainKBlockLoop)
         {
-            FloatAB* p_b_thread_even = p_b_thread_double;
-            FloatAB* p_b_thread_odd  = p_b_thread_double + b_thread_space_size;
+            index_t e_block_data_begin = 0;

             // LDS double buffer: main body
             // use Do-While loop instead of For loop to simplify control flow

@@ -283,16 +285,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                          p_b_global,
                                          b_e_n_ho_wo_thread_desc,
                                          make_tuple(I0, I0, I0, I0),
-                                         p_b_thread_odd,
+                                         b_thread_odd_buf,
                                          b_e_n_ho_wo_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(
-                    p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
-                    p_b_thread_even,
-                    p_c_thread);
-
-                b_block_data_begin += EPerBlock;
+                // TODO: @Zhang Jing: blockwise gemm should be able to move slice window
+                blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
+
+                blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));

                 b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
                                                          b_thread_slice_copy_step);

@@ -301,18 +301,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                          p_b_global,
                                          b_e_n_ho_wo_thread_desc,
                                          make_tuple(I0, I0, I0, I0),
-                                         p_b_thread_even,
+                                         b_thread_even_buf,
                                          b_e_n_ho_wo_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(
-                    p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
-                    p_b_thread_odd,
-                    p_c_thread);
-
-                b_block_data_begin += EPerBlock;
-            } while(b_block_data_begin < E - 2 * EPerBlock);
+                blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
+
+                blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
+
+                e_block_data_begin += 2 * EPerBlock;
+            } while(e_block_data_begin < E - 2 * EPerBlock);
         }

         // LDS double buffer: tail

@@ -325,34 +324,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                      p_b_global,
                                      b_e_n_ho_wo_thread_desc,
                                      make_tuple(I0, I0, I0, I0),
-                                     p_b_thread_double + b_thread_space_size,
+                                     b_thread_odd_buf,
                                      b_e_n_ho_wo_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(
-                p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
-                p_b_thread_double,
-                p_c_thread);
-
-            b_block_data_begin += EPerBlock;
+            blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
+
+            blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(
-                p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
-                p_b_thread_double + b_thread_space_size,
-                p_c_thread);
+            blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
         }
         else // if has 1 iteration left
         {
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(
-                p_a_block + a_e_k_block_desc.CalculateOffset(make_tuple(b_block_data_begin, 0)),
-                p_b_thread_double,
-                p_c_thread);
+            blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
         }
-#endif

-#if 1
         // output: register to global memory
         {
             // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor

@@ -380,12 +368,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                           k_thread_data_on_global,
                           0,
                           ho_thread_data_on_global,
                           wo_thread_data_on_global))
                 .Run(c_k_n_ho_wo_thread_desc,
                      make_tuple(I0, I0, I0, I0),
-                     p_c_thread,
+                     c_thread_buf,
                      c_k_n_ho_wo_global_desc,
                      p_c_global,
                      c_k_n_ho_wo_global_tensor_iterator_hacks);
         }
-#endif
     }

     // pass tensor descriptor by reference
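The do-while structure in this kernel consumes two E-blocks per trip with an even/odd pair of register buffers for B, plus a two-step tail. The following host-side sketch mirrors only that control flow, assuming E is a multiple of 2 * EPerBlock and at least two E-blocks exist; the buffers, "load" and "gemm" are stand-ins for illustration.

// host sketch of the even/odd register-buffer loop with a two-step tail
#include <cstdio>

int main()
{
    const int E = 8, EPerBlock = 1;          // 8 E-blocks in total
    int processed = 0;

    int b_even = 0, b_odd = 0;               // stand-ins for b_thread_even_buf / b_thread_odd_buf
    auto load = [&](int& buf, int e) { buf = e; };
    auto gemm = [&](int buf) { std::printf("gemm on E-block %d\n", buf); ++processed; };

    int e = 0;
    load(b_even, e);                         // preload first block

    if(E / EPerBlock > 2)                    // main body, two E-blocks per trip
    {
        do
        {
            load(b_odd, e + EPerBlock);      // prefetch next block
            gemm(b_even);                    // compute on current block

            load(b_even, e + 2 * EPerBlock); // prefetch the one after
            gemm(b_odd);

            e += 2 * EPerBlock;
        } while(e < E - 2 * EPerBlock);
    }

    // tail: the last two E-blocks
    load(b_odd, e + EPerBlock);
    gemm(b_even);
    gemm(b_odd);

    std::printf("processed %d of %d E-blocks\n", processed, E / EPerBlock);
}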
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp (new file, 0 → 100644)

#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"

namespace ck {

// Assume:
//   1. Desc is known at compile-time
//   2. Buffer is StaticBuffer
//   3. OriginIdx is known at compile-time
//   4. use #-iterator
template <typename Data,
          typename Desc,
          typename SliceLengths,
          typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceSet_v1
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    template <typename OriginIdx, typename Buffer>
    __device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const
    {
        static_assert(Desc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");

        static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");

        static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value,
                      "wrong! OriginIdx need to be known at compile-time");

        // Desc is known at compile-time
        constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{};

        // OriginIdx is known at compile-time
        constexpr auto origin_idx = to_multi_index(OriginIdx{});

        static_ford<SliceLengths>{}([&](auto access_idx) {
            constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx);

            constexpr bool is_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);

            constexpr index_t offset = coord.GetOffset();

            if constexpr(is_valid)
            {
                buf(Number<offset>{}) = initial_value;
            }
        });
    }
};

} // namespace ck
#endif
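This new helper fills a compile-time-described thread tensor with an initial value using a fully unrolled, compile-time index walk, which is what lets the accumulator buffer stay in registers. A standalone analogue using std::index_sequence instead of the library's Number/static_ford machinery is sketched below; the names are illustrative only.

// standalone analogue: compile-time unrolled fill of a statically sized buffer
#include <array>
#include <cstdio>
#include <utility>

template <typename Data, std::size_t... Is>
constexpr void fill_impl(std::array<Data, sizeof...(Is)>& buf,
                         const Data& value,
                         std::index_sequence<Is...>)
{
    ((buf[Is] = value), ...);   // every offset is a compile-time constant
}

template <typename Data, std::size_t N>
constexpr void fill_static(std::array<Data, N>& buf, const Data& value)
{
    fill_impl(buf, value, std::make_index_sequence<N>{});
}

int main()
{
    std::array<float, 6> c_thread_buf{};   // per-thread accumulator tile
    fill_static(c_thread_buf, 0.0f);       // analogous to Run(desc, origin, buf, FloatAcc{0})

    for(float v : c_thread_buf) std::printf("%g ", v);
    std::printf("\n");
}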
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
View file @
d075adf1
...
...
@@ -7,6 +7,15 @@
namespace
ck
{
// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument
// instead
// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same
// tensor coordinate instead
// 3. Don't use a pointer to VGPR buffer, use vector instead
namespace
detail
{
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor
template
<
index_t
VectorDim
,
index_t
ScalarPerVector
>
...
...
@@ -26,12 +35,17 @@ struct lambda_scalar_step_in_vector
return
(
i
==
VectorDim
)
?
1
:
0
;
}
};
// this version is less likely to have scratch memory issue, due to:
// 1. It does not keep reference to tensor descriptor
// 2. It does not construct new tensor coordinate for this->Run()
// Assume src_slice_origin_idx is 0
// TODO: support non-zero src_slice_oring_idx
}
// namespace detail
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time
// 2. SrcBuffer is StaticBuffer
// 3. SrcSliceOrginIdx is known at compile-time
// 2. dst:
// 1. DstDesc is not known at compile-time
// 2. DstBuffer is DynamicBuffer
// 3. DstSliceOrginIdx is not known at compile time
template
<
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
...
...
@@ -69,10 +83,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
dst_slice_origin_coord_
=
make_dynamic_tensor_coordinate
(
dst_desc
,
dst_slice_origin_idx
);
}
template
<
typename
SrcSliceOriginIdx
,
typename
DstIteratorHacks
>
template
<
typename
SrcSliceOriginIdx
,
typename
SrcBuffer
,
typename
DstIteratorHacks
>
__device__
void
Run
(
const
SrcDesc
&
,
const
SrcSliceOriginIdx
&
,
const
Src
Data
*
p_src
,
const
Src
Buffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstData
*
p_dst
,
const
DstIteratorHacks
&
dst_iterator_hacks
)
...
...
@@ -84,9 +98,15 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
SrcSliceOriginIdx
>>>::
value
,
"wrong! SrcSliceOrigin need to known at compile-time"
);
static_assert
(
SrcBuffer
::
IsStaticBuffer
(),
"wrong! SrcBuffer need to be StaticBuffer"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
SrcBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
SrcData
>>>::
value
,
"wrong! SrcBuffer data type is wrong"
);
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr
auto
src_desc
=
remove_cv_t
<
remove_reference_t
<
SrcDesc
>>
{};
constexpr
auto
src_slice_origin_idx
=
SrcSliceOriginIdx
{};
constexpr
auto
src_slice_origin_idx
=
to_multi_index
(
SrcSliceOriginIdx
{}
)
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -94,10 +114,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_scalar_step_in_vector
=
generate_sequence
(
lambda_scalar_step_in_vector
<
DstVectorDim
>
{},
Number
<
nDim
>
{});
generate_sequence
(
detail
::
lambda_scalar_step_in_vector
<
DstVectorDim
>
{},
Number
<
nDim
>
{});
constexpr
auto
access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
...
...
@@ -178,12 +198,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
typename
vector_type_maker
<
DstData
,
DstScalarPerVector
>::
type
::
type
;
static_for
<
0
,
DstScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
to_multi_index
(
src_slice_origin_idx
)
+
dst_data_idx
+
i
*
dst_scalar_step_in_vector
);
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_slice_origin_idx
+
dst_data_idx
+
i
*
dst_scalar_step_in_vector
);
dst_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
{}(
p_
src
[
Number
<
src_offset
>
{}]);
type_convert
<
DstData
>
{}(
src
_buf
[
Number
<
src_offset
>
{}]);
});
const
bool
is_dst_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
...
...
@@ -284,7 +303,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
...
...
@@ -359,10 +378,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
DstCoord
dst_slice_origin_coord_
;
};
// namespace ck
// this version is less likely to have scratch memory issue, due to:
// 1. It does not keep reference to tensor descriptor
// 2. It does not construct new tensor coordinate for this->Run()
// Assume dst_slice_origin_idx is 0
// Assume:
// 1. src_desc is not known at compile-time
// 2. dst_desc is known at compile-time
// 3. src_slice_origin_idx is not known at compile-time
// 4. dst_slice_origin_idx is known at compile-time and it's 0
template
<
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
...
...
@@ -399,12 +419,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
src_slice_origin_coord_
=
make_dynamic_tensor_coordinate
(
src_desc
,
src_slice_origin_idx
);
}
template
<
typename
DstSliceOriginIdx
,
typename
SrcIteratorHacks
>
template
<
typename
DstBuffer
,
typename
DstSliceOriginIdx
,
typename
SrcIteratorHacks
>
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcData
*
p_src
,
const
DstDesc
&
,
const
DstSliceOriginIdx
&
,
Dst
Data
*
p_dst
,
Dst
Buffer
&
dst_buf
,
const
SrcIteratorHacks
&
src_iterator_hacks
)
{
static_assert
(
DstDesc
::
IsKnownAtCompileTime
(),
...
...
@@ -414,6 +434,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
DstSliceOriginIdx
>>>::
value
,
"wrong! DstSliceOrigin need to known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
DstBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
DstData
>>>::
value
&&
"wrong! inconsistent type"
);
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr
auto
dst_desc
=
remove_cv_t
<
remove_reference_t
<
DstDesc
>>
{};
constexpr
auto
dst_slice_origin_idx
=
DstSliceOriginIdx
{};
...
...
@@ -424,10 +448,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_scalar_step_in_vector
=
generate_sequence
(
lambda_scalar_step_in_vector
<
SrcVectorDim
>
{},
Number
<
nDim
>
{});
generate_sequence
(
detail
::
lambda_scalar_step_in_vector
<
SrcVectorDim
>
{},
Number
<
nDim
>
{});
constexpr
auto
access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
...
...
@@ -541,7 +565,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
dst_desc
.
CalculateOffset
(
to_multi_index
(
dst_slice_origin_idx
)
+
src_data_idx
+
i
*
src_scalar_step_in_vector
);
p_
dst
[
Number
<
dst_offset
>
{}
]
=
src_vector
.
template
AsType
<
SrcData
>()[
i
];
dst
_buf
(
Number
<
dst_offset
>
{}
)
=
src_vector
.
template
AsType
<
SrcData
>()[
i
];
});
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
...
...
@@ -590,7 +614,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
}
}
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcData
*
p_src
,
DstData
*
p_dst
)
template
<
typename
DstBuffer
,
typename
DstSliceOriginIdx
>
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcData
*
p_src
,
const
DstDesc
&
,
const
DstSliceOriginIdx
&
,
DstBuffer
&
dst_buf
)
{
constexpr
index_t
ntransform_src
=
SrcDesc
::
GetNumOfTransform
();
...
...
@@ -600,7 +629,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
make_tuple
(
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}),
generate_tuple
([
&
](
auto
)
{
return
zeros
;
},
Number
<
nDim
>
{}));
Run
(
src_desc
,
p_src
,
p_dst
,
src_iterator_hacks
);
Run
(
src_desc
,
p_src
,
DstDesc
{},
DstSliceOriginIdx
{},
dst_buf
,
src_iterator_hacks
);
}
__device__
static
constexpr
auto
GetSrcCoordinateResetStep
()
...
...
@@ -610,7 +639,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
...
...
@@ -685,12 +714,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
SrcCoord
src_slice_origin_coord_
;
};
// namespace ck
// this version does following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions
// 1. It does not keep reference to tensor descriptor
// 2. It does not construct new tensor coordinate for this->Run()
// 3. It does not use pointer for VGPR thread buffer
// 4. It calculate offset for thread buffer directly, instead of moving the coordinate
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. src_slice_origin and dst_slice_origin are not known at compile-time,
// 3. Use thread buffer
template
<
typename
SliceLengths
,
InMemoryDataOperation
DstInMemOp
,
typename
SrcData
,
...
...
@@ -737,6 +764,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
static_assert
(
DstAddressSpace
==
AddressSpace
::
Global
or
DstAddressSpace
==
AddressSpace
::
Lds
,
"wrong!"
);
// TODO: fix this
static_assert
(
is_same
<
SrcData
,
DstData
>::
value
,
"wrong! current implementation assume SrcData and DstData are same type"
);
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
...
...
@@ -760,10 +791,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_scalar_step_in_vector
=
generate_sequence
(
lambda_scalar_step_in_vector
<
SrcVectorDim
>
{},
Number
<
nDim
>
{});
generate_sequence
(
detail
::
lambda_scalar_step_in_vector
<
SrcVectorDim
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
...
...
@@ -838,11 +869,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
return
src_data_idx
;
}();
// copy data
typename
vector_type_maker
<
SrcData
,
SrcScalarPerVector
>
::
type
src_vector
;
// copy data
from src_buf to src_tmp_vector
vector_type_maker
_t
<
SrcData
,
SrcScalarPerVector
>
src
_tmp
_vector
;
using
src_vector_t
=
typename
vector_type_maker
<
SrcData
,
SrcScalarPerVector
>::
type
::
type
;
using
src_vector_t
=
typename
decltype
(
src_tmp_vector
)
::
type
;
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_desc
,
src_slice_origin_coord_
);
...
...
@@ -850,14 +880,14 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
if
constexpr
(
SrcAddressSpace
==
AddressSpace
::
Global
)
{
#if CK_USE_AMD_BUFFER_ADDRESSING
src_vector
.
template
AsType
<
src_vector_t
>()(
Number
<
0
>
{})
=
src_
tmp_
vector
.
template
AsType
<
src_vector_t
>()(
Number
<
0
>
{})
=
amd_buffer_load_v2
<
SrcData
,
SrcScalarPerVector
>
(
p_src
,
src_slice_origin_coord_
.
GetOffset
(),
is_src_valid
,
src_desc
.
GetElementSpaceSize
());
#else
src_vector
.
template
AsType
<
src_vector_t
>()(
Number
<
0
>
{})
=
src_
tmp_
vector
.
template
AsType
<
src_vector_t
>()(
Number
<
0
>
{})
=
is_src_valid
?
*
reinterpret_cast
<
const
src_vector_t
*>
(
&
p_src
[
src_slice_origin_coord_
.
GetOffset
()])
:
src_vector_t
{
0
};
...
...
@@ -865,17 +895,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
}
else
{
src_vector
.
template
AsType
<
src_vector_t
>()(
Number
<
0
>
{})
=
src_
tmp_
vector
.
template
AsType
<
src_vector_t
>()(
Number
<
0
>
{})
=
is_src_valid
?
*
reinterpret_cast
<
const
src_vector_t
*>
(
&
p_src
[
src_slice_origin_coord_
.
GetOffset
()])
:
src_vector_t
{
0
};
}
// copy data from src_tmp_vector to buffer_
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
buffer_offset
=
buffer_desc_
.
CalculateOffset
(
src_data_idx
+
i
*
src_scalar_step_in_vector
);
buffer_
(
Number
<
buffer_offset
>
{})
=
src_vector
.
template
AsType
<
SrcData
>()[
i
];
buffer_
(
Number
<
buffer_offset
>
{})
=
src_
tmp_
vector
.
template
AsType
<
SrcData
>()[
i
];
});
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
...
...
@@ -937,10 +968,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        // src scalar per access on each dim
        // TODO: don't use this
        constexpr auto dst_scalar_per_access = generate_sequence(
            lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_scalar_step_in_vector =
            generate_sequence(lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
...
...
@@ -1026,20 +1057,21 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                      DstInMemOp == InMemoryDataOperation::Set,
                      "wrong! hardcoded for ds_write");

        typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
        vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;

        // copy data from buffer_ to dst_tmp_vector
        static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
            constexpr index_t buffer_offset =
                buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);

            dst_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
            dst_tmp_vector.template AsType<DstData>()(i) = buffer_[Number<buffer_offset>{}];
        });

        using DstVectorType = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
        using dst_vector_t = typename decltype(dst_tmp_vector)::type;

        *reinterpret_cast<DstVectorType*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
            dst_vector.template AsType<DstVectorType>()[Number<0>{}];
        // copy data from dst_tmp_vector to dst_buf
        *reinterpret_cast<dst_vector_t*>(p_dst + dst_slice_origin_coord_.GetOffset()) =
            dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}];

        constexpr auto move_on_dim = [&]() constexpr
        {
...
...
@@ -1123,7 +1155,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
...
...
@@ -1185,7 +1217,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto dst_scalar_per_access = generate_sequence(
            lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
...
...
@@ -1274,7 +1306,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
            move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
        }

        // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
        __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                           const Index& dst_slice_origin_step_idx)
...
...
@@ -1297,11 +1328,203 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

        StaticallyIndexedArray<SrcData, buffer_size_> buffer_;
        StaticBuffer<SrcData, buffer_size_> buffer_;

        SrcCoord src_slice_origin_coord_;
        DstCoord dst_slice_origin_coord_;
    };
// Assume:
// 1. src:
// 1. SrcDesc is known at compile-time
// 2. SrcBuffer is DynamicBuffer
// 3. src_ref_idx is known at run-time
// 4. SrcRefToOriginDisplacement is known at compile-time
// 5. use #-iterator
// 2. dst:
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation
// 3. vector access on src
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t SrcVectorDim,
          index_t SrcScalarPerVector,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace,
          index_t SrcScalarStrideInVector,
          typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v4
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));

    using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));

    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4(const Index& src_ref_idx)
        : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");
    }

    template <typename SrcRefToOriginDisplacement,
              typename DstOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcRefToOriginDisplacement&,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstOriginIdx&,
                        DstBuffer& dst_buf) const
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");

        static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
                              remove_cv_t<remove_reference_t<SrcData>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
                                  remove_cv_t<remove_reference_t<DstData>>>::value,
                      "wrong! SrcBuffer or DstBuffer data type is wrong");

        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");

        static_assert(
            is_known_at_compile_time<remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
                is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
            "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
            "at compile-time");

        // SrcDesc and DstDesc are known at compile-time
        constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
        constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};

        // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time
        constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
        constexpr auto dst_origin_idx             = to_multi_index(DstOriginIdx{});

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // scalar per access of each dim
        constexpr auto src_scalar_per_access = generate_sequence_v2(
            [&](auto i) constexpr {
                if constexpr(i == SrcVectorDim)
                {
                    return Number<SrcScalarPerVector>{};
                }
                else
                {
                    return Number<1>{};
                }
            },
            Number<nDim>{});

        // scalar step (if steping on SrcVectorDim) of each dim
        constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
            [&](auto i) constexpr {
                if constexpr(i == SrcVectorDim)
                {
                    return Number<1>{};
                }
                else
                {
                    return Number<0>{};
                }
            },
            Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
#if 0
            // TODO: unable to compile
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
                src_scalar_per_access;
#else
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
#endif

            // src coordinate
            constexpr auto src_ref_to_data_disp_idx =
                src_ref_to_origin_disp_idx + data_to_origin_disp_idx;

            constexpr auto src_ref_to_data_disp_coord_iterator =
                make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);

            auto src_data_coord = src_ref_coord_;

            move_dynamic_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);

            // copy data from src_buf into src_tmp_buffer
            vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;

            using src_vector_t = typename decltype(src_tmp_vector)::type;

            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                src_desc, src_data_coord);

            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                is_src_valid ? src_buf.template Get<src_vector_t>(src_data_coord.GetOffset())
                             : src_vector_t{0};

            // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
            // DstData)
            vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

            // TODO: if SrcData and DstData are vetor type, then static_cast may not compile
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                dst_tmp_vector.template AsType<DstData>()(i) =
                    type_convert<DstData>{}(src_tmp_vector.template AsType<SrcData>()[i]);
            });

            // copy data from dst_tmp_vector into dst_buf
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                constexpr index_t dst_offset = dst_desc.CalculateOffset(
                    dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

                dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
            });
        });
    }

    template <typename SrcSliceMoveStepIdx>
    __device__ void MoveSrcSliceWindow(const SrcDesc&,
                                       const SrcSliceMoveStepIdx& src_slice_move_step_idx)
    {
        constexpr auto src_desc = SrcDesc{};

        const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
            src_desc, to_multi_index(src_slice_move_step_idx));

        move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
    }

    private:
    SrcCoord src_ref_coord_;
};

} // namespace ck
#endif
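For readers skimming the new ThreadwiseDynamicTensorSliceTransfer_v4 above: per access position it loads SrcScalarPerVector contiguous source scalars (or zeros when the source coordinate is invalid), converts them to DstData, and writes them into the static destination buffer at compile-time offsets. Below is a minimal host-side sketch of that inner step; it is not part of the commit, and all names in it are hypothetical.

    #include <array>
    #include <cstddef>

    // Simplified scalar model of one access of the v4 transfer: read a small contiguous
    // vector from src (zero-filled if the source coordinate is invalid), cast element-wise,
    // and store into a statically sized destination buffer.
    template <typename SrcData, typename DstData, std::size_t ScalarPerVector, std::size_t DstSize>
    void transfer_one_access(const SrcData* src, std::size_t src_offset, bool is_src_valid,
                             std::array<DstData, DstSize>& dst, std::size_t dst_offset)
    {
        for(std::size_t i = 0; i < ScalarPerVector; ++i)
        {
            const SrcData v = is_src_valid ? src[src_offset + i] : SrcData{0};
            dst[dst_offset + i] = static_cast<DstData>(v); // type_convert<DstData> in the real kernel
        }
    }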
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
View file @ d075adf1
...
...
@@ -6,100 +6,52 @@
namespace ck
{

template <typename Float, typename Desc>
__device__ void threadwise_matrix_set_zero_v2(Desc, Float* __restrict__ p_thread)
{
    static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr auto desc = Desc{};

    constexpr auto M = desc.GetLength(I0);
    constexpr auto N = desc.GetLength(I1);

    static_for<0, M, 1>{}([&](auto i) {
        static_for<0, N, 1>{}([&](auto j) {
            constexpr auto offset = desc.CalculateOffset(make_tuple(i, j));

            p_thread[offset] = Float(0);
        });
    });
}

template <typename SrcDesc,
          typename DstDesc,
          index_t NSliceRow,
          index_t NSliceCol,
          index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy_v2
{
    template <typename Data>
    __device__ static void Run(const Data* p_src, Data* p_dst)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;

        static_for<0, NSliceRow, 1>{}([&](auto i) {
            static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
                constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
                constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));

                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
                    *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
            });
        });
    }
};

// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
template <typename ADesc,
// Assume:
//   1. ADesc, BDesc, CDesc are known at compile-time
//   2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ADesc,
          typename BDesc,
          typename CDesc,
          typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                                      CDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v1
struct ThreadwiseGemm_km_kn_mn_v1r1
{
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    template <typename ABuffer,
              typename AOriginIdx,
              typename BBuffer,
              typename BOriginIdx,
              typename CBuffer,
              typename COriginIdx>
    __device__ static void Run(const ABuffer& a_buf,
                               AOriginIdx,
                               const BBuffer& b_buf,
                               BOriginIdx,
                               CBuffer& c_buf,
                               COriginIdx)
    {
        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                          CDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto M = CDesc{}.GetLength(I0);
        constexpr auto N = CDesc{}.GetLength(I1);
        constexpr auto K = ADesc{}.GetLength(I0);

        static_for<0, K, 1>{}([&](auto k) {
            static_for<0, M, 1>{}([&](auto m) {
                static_for<0, N, 1>{}([&](auto n) {
                    constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
                    constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(k, n));
                    constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(m, n));

                    p_c[c_offset] += inner_product_with_conversion<FloatC>{}(p_a[a_offset],
                                                                             p_b[b_offset]);
                });
            });
        });
    }

        static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
                          is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
                          is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
                      "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");

#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                          CDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");
        static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
                              remove_cv_t<remove_reference_t<FloatA>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
                                  remove_cv_t<remove_reference_t<FloatB>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
                                  remove_cv_t<remove_reference_t<FloatC>>>::value &&
                      "wrong! inconsistent type");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
...
...
@@ -110,61 +62,81 @@ struct ThreadwiseGemm_km_kn_mn_v1
        constexpr auto N = CDesc{}.GetLength(I1);
        constexpr auto K = ADesc{}.GetLength(I0);

        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

        static_for<0, K, 1>{}([&](auto k) {
            static_for<0, M, 1>{}([&](auto m) {
                constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
                constexpr index_t a_offset =
                    ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));

#if 0
                if constexpr(N == 2)
                {
                    constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
                    constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));
                    constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
                    constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));

                    amd_assembly_outer_product_1x2(p_a[a_offset],
                                                   p_b[b_offset_0],
                                                   p_b[b_offset_1],
                                                   p_c[c_offset_0],
                                                   p_c[c_offset_1]);
                    constexpr index_t b_offset_0 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
                    constexpr index_t b_offset_1 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
                    constexpr index_t c_offset_0 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
                    constexpr index_t c_offset_1 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));

                    amd_assembly_outer_product_1x2(a_buf[Number<a_offset>{}],
                                                   b_buf[Number<b_offset_0>{}],
                                                   b_buf[Number<b_offset_1>{}],
                                                   c_buf(Number<c_offset_0>{}),
                                                   c_buf(Number<c_offset_1>{}));
                }
                else if constexpr(N == 4)
                {
                    constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
                    constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));
                    constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(k, I2));
                    constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(k, I3));
                    constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
                    constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));
                    constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(m, I2));
                    constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(m, I3));

                    amd_assembly_outer_product_1x4(p_a[a_offset],
                                                   p_b[b_offset_0],
                                                   p_b[b_offset_1],
                                                   p_b[b_offset_2],
                                                   p_b[b_offset_3],
                                                   p_c[c_offset_0],
                                                   p_c[c_offset_1],
                                                   p_c[c_offset_2],
                                                   p_c[c_offset_3]);
                    constexpr index_t b_offset_0 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
                    constexpr index_t b_offset_1 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
                    constexpr index_t b_offset_2 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
                    constexpr index_t b_offset_3 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
                    constexpr index_t c_offset_0 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
                    constexpr index_t c_offset_1 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
                    constexpr index_t c_offset_2 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
                    constexpr index_t c_offset_3 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));

                    amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
                                                   b_buf[Number<b_offset_0>{}],
                                                   b_buf[Number<b_offset_1>{}],
                                                   b_buf[Number<b_offset_2>{}],
                                                   b_buf[Number<b_offset_3>{}],
                                                   c_buf(Number<c_offset_0>{}),
                                                   c_buf(Number<c_offset_1>{}),
                                                   c_buf(Number<c_offset_2>{}),
                                                   c_buf(Number<c_offset_3>{}));
                }
            });
        });
    }
                else
#endif
                {
                    static_for<0, N, 1>{}([&](auto n) {
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
        Run_amd_asm(p_a, p_b, p_c);
#else
        Run_source(p_a, p_b, p_c);
#endif
                        constexpr index_t b_offset =
                            BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
                        constexpr index_t c_offset =
                            CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));

                        amd_assembly_inner_product(a_buf[Number<a_offset>{}],
                                                   b_buf[Number<b_offset>{}],
                                                   c_buf(Number<c_offset>{}));
                    });
                }
            });
        });
    }
};
...
...
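The comment at the top of this file states the contract C[M, N] += transpose(A[K, M]) * B[K, N]; the buffer-based Run above unrolls it with static_for and compile-time offsets. A plain scalar reference of the same computation, shown only as an illustration and not taken from the commit:

    // C[m][n] += sum over k of A[k][m] * B[k][n]; A is stored K-major, i.e. transposed w.r.t. C's M
    template <int M, int N, int K>
    void threadwise_gemm_ref(const float (&a)[K][M], const float (&b)[K][N], float (&c)[M][N])
    {
        for(int k = 0; k < K; ++k)
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    c[m][n] += a[k][m] * b[k][n];
    }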
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp
View file @ d075adf1
...
...
@@ -6,35 +6,15 @@
namespace ck
{

template <typename Float, typename Desc>
__device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread)
{
    static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");

    constexpr auto I0 = Number<0>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto desc = Desc{};

    constexpr auto K = desc.GetLength(I0);
    constexpr auto H = desc.GetLength(I2);
    constexpr auto W = desc.GetLength(I3);

    static_for<0, K, 1>{}([&](auto i) {
        static_for<0, H, 1>{}([&](auto j) {
            static_for<0, W, 1>{}([&](auto k) {
                constexpr auto offset = desc.CalculateOffset(make_tuple(i, 0, j, k));

                p_thread[offset] = Float(0);
            });
        });
    });
}

// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
template <typename ADesc,
// Assume:
//   1. ADesc, BDesc, CDesc are known at compile-time
//   2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ADesc,
          typename BDesc,
          typename CDesc,
          index_t H,
...
...
@@ -44,13 +24,37 @@ template <typename ADesc,
                                  bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v3
{
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    template <typename ABuffer,
              typename AOriginIdx,
              typename BBuffer,
              typename BOriginIdx,
              typename CBuffer,
              typename COriginIdx>
    __device__ static void Run(const ABuffer& a_buf,
                               AOriginIdx,
                               const BBuffer& b_buf,
                               BOriginIdx,
                               CBuffer& c_buf,
                               COriginIdx)
    {
        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                          CDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
                          is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
                          is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
                      "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");

        static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
                              remove_cv_t<remove_reference_t<FloatA>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
                                  remove_cv_t<remove_reference_t<FloatB>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
                                  remove_cv_t<remove_reference_t<FloatC>>>::value &&
                      "wrong! inconsistent type");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
...
...
@@ -59,79 +63,100 @@ struct ThreadwiseGemm_km_kn_mn_v3
        constexpr auto E = ADesc{}.GetLength(I0);
        constexpr auto K = ADesc{}.GetLength(I1);

        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

        static_for<0, E, 1>{}([&](auto e) {
            static_for<0, K, 1>{}([&](auto k) {
                constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(e, k));
                constexpr index_t a_offset =
                    ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k));

                if constexpr(H == 2 && W == 2)
                {
                    constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0));
                    constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 1));
                    constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0));
                    constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 1));
                    constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0));
                    constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 1));
                    constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0));
                    constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 1));

                    amd_assembly_outer_product_1x4(p_a[a_offset],
                                                   p_b[b_offset_0],
                                                   p_b[b_offset_1],
                                                   p_b[b_offset_2],
                                                   p_b[b_offset_3],
                                                   p_c[c_offset_0],
                                                   p_c[c_offset_1],
                                                   p_c[c_offset_2],
                                                   p_c[c_offset_3]);
                    constexpr index_t b_offset_0 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
                    constexpr index_t b_offset_1 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1));
                    constexpr index_t b_offset_2 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
                    constexpr index_t b_offset_3 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1));
                    constexpr index_t c_offset_0 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
                    constexpr index_t c_offset_1 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1));
                    constexpr index_t c_offset_2 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
                    constexpr index_t c_offset_3 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1));

                    amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
                                                   b_buf[Number<b_offset_0>{}],
                                                   b_buf[Number<b_offset_1>{}],
                                                   b_buf[Number<b_offset_2>{}],
                                                   b_buf[Number<b_offset_3>{}],
                                                   c_buf(Number<c_offset_0>{}),
                                                   c_buf(Number<c_offset_1>{}),
                                                   c_buf(Number<c_offset_2>{}),
                                                   c_buf(Number<c_offset_3>{}));
                }
                else if constexpr(H == 4 && W == 1)
                {
                    constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(e, 0, 0, 0));
                    constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(e, 0, 1, 0));
                    constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(e, 0, 2, 0));
                    constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(e, 0, 3, 0));
                    constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(k, 0, 0, 0));
                    constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(k, 0, 1, 0));
                    constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(k, 0, 2, 0));
                    constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(k, 0, 3, 0));

                    amd_assembly_outer_product_1x4(p_a[a_offset],
                                                   p_b[b_offset_0],
                                                   p_b[b_offset_1],
                                                   p_b[b_offset_2],
                                                   p_b[b_offset_3],
                                                   p_c[c_offset_0],
                                                   p_c[c_offset_1],
                                                   p_c[c_offset_2],
                                                   p_c[c_offset_3]);
                    constexpr index_t b_offset_0 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
                    constexpr index_t b_offset_1 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
                    constexpr index_t b_offset_2 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0));
                    constexpr index_t b_offset_3 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0));
                    constexpr index_t c_offset_0 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
                    constexpr index_t c_offset_1 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
                    constexpr index_t c_offset_2 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0));
                    constexpr index_t c_offset_3 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0));

                    amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
                                                   b_buf[Number<b_offset_0>{}],
                                                   b_buf[Number<b_offset_1>{}],
                                                   b_buf[Number<b_offset_2>{}],
                                                   b_buf[Number<b_offset_3>{}],
                                                   c_buf(Number<c_offset_0>{}),
                                                   c_buf(Number<c_offset_1>{}),
                                                   c_buf(Number<c_offset_2>{}),
                                                   c_buf(Number<c_offset_3>{}));
                }
                else
                {
                    static_for<0, H, 1>{}([&](auto h) {
                        static_for<0, W, 1>{}([&](auto w) {
                            constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(e, 0, h, w));
                            constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(k, 0, h, w));

                            p_c[c_offset] += inner_product_with_conversion<FloatC>{}(p_a[a_offset],
                                                                                     p_b[b_offset]);
                            constexpr index_t b_offset =
                                BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w));
                            constexpr index_t c_offset =
                                CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w));
#if 0
                            c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
                                a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
#else
                            amd_assembly_inner_product(a_buf[Number<a_offset>{}],
                                                       b_buf[Number<b_offset>{}],
                                                       c_buf(Number<c_offset>{}));
#endif
                        });
                    });
                }
            });
        });
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        Run_source(p_a, p_b, p_c);
    }
};

} // namespace ck
...
...
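The v3 variant above indexes A as (e, k), B as (e, 0, h, w) and C as (k, 0, h, w): each thread accumulates a K x H x W register tile while reducing over E. A scalar reference of that loop nest, given only as an illustration and not taken from the commit:

    template <int E, int K, int H, int W>
    void threadwise_gemm_v3_ref(const float (&a)[E][K],
                                const float (&b)[E][1][H][W],
                                float (&c)[K][1][H][W])
    {
        // c[k][0][h][w] += a[e][k] * b[e][0][h][w], reduced over e
        for(int e = 0; e < E; ++e)
            for(int k = 0; k < K; ++k)
                for(int h = 0; h < H; ++h)
                    for(int w = 0; w < W; ++w)
                        c[k][0][h][w] += a[e][k] * b[e][0][h][w];
    }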
composable_kernel/include/utility/amd_inline_asm.hpp
View file @ d075adf1
...
...
@@ -5,6 +5,75 @@
namespace ck
{

// c += inner_product(a, b)
__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_V_FMAC_F32
    asm volatile("\n \
            v_fmac_f32 %0, %1, %2 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    asm volatile("\n \
            v_mac_f32 %0, %1, %2 \n \
            "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#endif
}

__device__ void amd_assembly_inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if 1
    asm volatile("\n \
            v_dot4_i32_i8 %0, %1, %2, %0 \n \
            "
                 : "=v"(c)
                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
}

__device__ void amd_assembly_inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
                               vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
                               c);

    amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
                               vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
                               c);
}

__device__ void amd_assembly_inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
                               c);

    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
                               c);

    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
                               c);

    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
                               c);
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
...
...
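The new inner-product overloads above each map to a single GCN instruction (v_fmac_f32 or v_mac_f32 for fp32, v_dot4_i32_i8 for packed int8). Their portable C++ meaning, shown here only as an illustration and not as code from the commit, is:

    #include <cstdint>

    // fp32: c += a * b (what v_fmac_f32 / v_mac_f32 compute)
    inline void inner_product_ref(float a, float b, float& c) { c += a * b; }

    // packed int8: four int8 x int8 products accumulated into one int32 (what v_dot4_i32_i8 computes)
    inline void inner_product_ref(const int8_t (&a)[4], const int8_t (&b)[4], int32_t& c)
    {
        for(int i = 0; i < 4; ++i)
            c += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
    }

The wider int8x8 and int8x16 overloads simply chain the 4-wide dot product over their sub-vectors, as the code above shows.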
composable_kernel/include/utility/buffer.hpp 0 → 100644
View file @ d075adf1
#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck
{

template <typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<T, N>{};
}

template <typename T>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;

    __host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ constexpr const auto Get(index_t i) const
    {
        return *reinterpret_cast<const X*>(&p_data_[i]);
    }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ void Set(index_t i, const X& x)
    {
        *reinterpret_cast<X*>(&p_data_[i]) = x;
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <typename T>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
{
    return DynamicBuffer<T>{p};
}

} // namespace ck
#endif
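A quick usage sketch of the two buffer types introduced in this new header, given only as an illustration (it assumes the ck headers are included and relies on the accessors shown above):

    __host__ __device__ void buffer_usage_sketch()
    {
        using namespace ck;

        float p[8] = {0};

        // DynamicBuffer wraps a raw pointer; Get/Set move whole vectors of the same scalar type
        auto dyn = make_dynamic_buffer(p);
        dyn.Set(0, float4_t{1, 2, 3, 4});   // writes p[0..3]
        float4_t v = dyn.Get<float4_t>(0);  // reads them back as one vector

        // StaticBuffer is a compile-time indexed register array
        auto sta = make_static_buffer<float>(Number<4>{});
        sta(Number<0>{}) = v[0];            // element access with a compile-time index
    }

The distinction matters for the transfer code above: DstBuffer must satisfy IsStaticBuffer() so every destination offset can be resolved at compile time, while source data can live behind a runtime pointer.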
composable_kernel/include/utility/common_header.hpp
View file @ d075adf1
...
...
@@ -7,6 +7,7 @@
#include "statically_indexed_array.hpp"
#include "container_element_picker.hpp"
#include "float_type.hpp"
#include "buffer.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional3.hpp"
...
...
composable_kernel/include/utility/config.amd.hpp.in
View file @ d075adf1
...
...
@@ -14,11 +14,11 @@
#define CK_DEVICE_BACKEND_AMD 1

// GPU ID
#if 1
#if 0
#define CK_AMD_GPU_GFX906 1
#elif 0
#define CK_AMD_GPU_GFX908 1
#elif 0
#elif 1
#define CK_AMD_GPU_GFX1030 1
#endif
...
...
@@ -28,7 +28,7 @@
#endif

// launch bounds
#define CK_USE_LAUNCH_BOUNDS 0
#define CK_USE_LAUNCH_BOUNDS 1

#ifdef CK_USE_LAUNCH_BOUNDS
#define CK_MAX_THREAD_PER_BLOCK 256
...
...
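The hunk above flips the build toward gfx1030 and turns launch bounds on. How CK_USE_LAUNCH_BOUNDS and CK_MAX_THREAD_PER_BLOCK are consumed is not shown in this diff; the following is only an assumed illustration of the common pattern, with the macro MY_KERNEL_BOUNDS and the kernel name being hypothetical:

    // Assumed illustration only: gate the HIP/CUDA __launch_bounds__ attribute on the config macro.
    #if CK_USE_LAUNCH_BOUNDS
    #define MY_KERNEL_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK)
    #else
    #define MY_KERNEL_BOUNDS
    #endif

    __global__ void MY_KERNEL_BOUNDS my_kernel(float* p) { p[threadIdx.x] = 0.0f; }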
composable_kernel/include/utility/float_type.amd.hpp.in
View file @ d075adf1
#ifndef CK_FLOAT_TYPE_AMD_HPP
#define CK_FLOAT_TYPE_AMD_HPP
#include "statically_indexed_array.hpp"
namespace ck {
using half_t = _Float16;
...
...
@@ -43,6 +45,15 @@ struct vector_type_maker<vector_type<T, N1>, N0>
using type = vector_type<T, N0 * N1>;
};
template <typename T, index_t N>
using vector_type_maker_t = typename vector_type_maker<T, N>::type;
template <typename T, index_t N>
__host__ __device__ constexpr auto make_vector_type(Number<N>)
{
return typename vector_type_maker<T, N>::type{};
}
// scalar_type
template <typename TV>
struct scalar_type;
...
...
@@ -403,32 +414,249 @@ struct vector_type<T, 16>
}
};
template <typename T>
struct vector_type<T, 32>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
using type = d32_t;
union
{
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
};
template <typename T>
struct vector_type<T, 64>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
using type = d64_t;
union
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;
// fp16
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;
// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
using ushort16_t = typename vector_type<ushort, 16>::type;
using ushort32_t = typename vector_type<ushort, 32>::type;
using ushort64_t = typename vector_type<ushort, 64>::type;
// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;
// i8
using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// data type conversion
template <typename T>
...
...
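The 32- and 64-wide specializations added above follow the same union-of-StaticallyIndexedArray pattern as the narrower ones, so a wide register vector can be viewed in place as smaller hardware vectors. A small sketch, given only as an illustration of the accessors shown above:

    __host__ __device__ void vector_type_sketch()
    {
        using namespace ck;

        vector_type<float, 32> v;                              // zero-initialized 32-wide vector
        auto& quads = v.AsType<float4_t>();                    // view as StaticallyIndexedArray<float4_t, 8>
        quads(Number<0>{}) = float4_t{1, 2, 3, 4};             // write the first four lanes
        float32_t whole = v.AsType<float32_t>()[Number<0>{}];  // read the whole vector back
        (void)whole;
    }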
composable_kernel/include/utility/sequence_helper.hpp
View file @ d075adf1
...
...
@@ -5,11 +5,26 @@
namespace ck
{

template <index_t... Is>
__host__ __device__ constexpr auto make_sequence(Number<Is>...)
{
    return Sequence<Is...>{};
}

// F returns index_t
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence(F, Number<N>)
{
    return typename sequence_gen<N, F>::type{};
}

// F returns Number<>
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
{
    return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}

} // namespace ck
#endif
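generate_sequence_v2 is what the new transfer code above uses to build per-dimension constants such as src_scalar_per_access: the functor is called with a Number<> for each dimension index and returns a Number<>. A sketch of that use, given only as an illustration and mirroring the pattern in ThreadwiseDynamicTensorSliceTransfer_v4:

    template <ck::index_t nDim, ck::index_t VectorDim, ck::index_t ScalarPerVector>
    constexpr auto make_scalar_per_access()
    {
        using namespace ck;

        // yields Sequence<1, ..., ScalarPerVector, ..., 1>: vector access only along VectorDim
        return generate_sequence_v2(
            [](auto i) {
                if constexpr(i == VectorDim)
                    return Number<ScalarPerVector>{};
                else
                    return Number<1>{};
            },
            Number<nDim>{});
    }

    // make_scalar_per_access<4, 3, 4>() is Sequence<1, 1, 1, 4>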
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp
View file @ d075adf1
...
...
@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
    constexpr auto C0 = C / Number<InWeiVectorSize>{};
    constexpr auto C1 = Number<InWeiVectorSize>{};

#if 1
#if 0
    // run-time variables
    constexpr auto in_n_hi_wi_c0_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
...
...
@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

#if 0
#if 1
    // cdata = 16, BlockSize = 64, 16x64x4
    constexpr index_t BlockSize = 64;
...
...
driver/src/conv_driver.cpp
View file @ d075adf1
...
...
@@ -64,7 +64,7 @@ int main(int argc, char* argv[])
    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
    constexpr index_t N  = 1;
    constexpr index_t C  = 16;
    constexpr index_t HI = 1080;
...
...
@@ -630,7 +630,7 @@
    print_array("ConvStrides", to_multi_index(ConvStrides{}));
    print_array("ConvDilations", to_multi_index(ConvDilations{}));

#if 1
#if 0
    using in_data_t = float;
    constexpr index_t in_vector_size = 1;
    using acc_data_t = float;
...
...
@@ -724,23 +724,22 @@ int main(int argc, char* argv[])
        LeftPads{},
        RightPads{},
        nrepeat);
#elif 1
#elif 0
    device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
                                                                         in_vector_size,
                                                                         acc_data_t,
                                                                         out_data_t>(
        in_nchw_desc,
        in_nchw,
        wei_kcyx_desc,
        wei_kcyx,
        out_nkhw_desc,
        out_nkhw_device,
        ConvStrides{},
        ConvDilations{},
        LeftPads{},
        RightPads{},
        nrepeat);
                                                                         out_data_t>(in_nchw_desc,
                                                                                     in_nchw,
                                                                                     wei_kcyx_desc,
                                                                                     wei_kcyx,
                                                                                     out_nkhw_desc,
                                                                                     out_nkhw_device,
                                                                                     ConvStrides{},
                                                                                     ConvDilations{},
                                                                                     LeftPads{},
                                                                                     RightPads{},
                                                                                     nrepeat);
#elif 0
    device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
                                                                         in_vector_size,
...
...