gaoqiong / composable_kernel · Commits

Commit 90b3ccac, authored Feb 21, 2021 by Chao Liu

    recovering code

Parent: a7c587ee

Showing 5 changed files with 832 additions and 110 deletions (+832, -110)
composable_kernel/include/tensor_operation/blockwise_gemm.hpp                             +361   -22
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp                       +25   -19
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp   +330    -6
composable_kernel/include/tensor_operation/threadwise_gemm.hpp                            +106   -61
composable_kernel/include/utility/math.hpp                                                 +10    -2
composable_kernel/include/tensor_operation/blockwise_gemm.hpp

@@ -95,28 +95,6 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

-#if 0
-    __device__ static MatrixIndex GetDistanceFromBeginOfThreadMatrixC(index_t m_in_c,
-                                                                      index_t n_in_c)
-    {
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-        constexpr index_t MPerLevel1Cluster =
-            MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
-        constexpr index_t NPerLevel1Cluster =
-            NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
-
-        index_t m_repeat = m_in_c / MPerThreadSubC;
-        index_t n_repeat = n_in_c / NPerThreadSubC;
-
-        index_t m_in_sub_c = m_in_c % MPerThreadSubC;
-        index_t n_in_sub_c = n_in_c % NPerThreadSubC;
-
-        return MatrixIndex{m_repeat * MPerLevel1Cluster + m_in_sub_c,
-                           n_repeat * NPerLevel1Cluster + n_in_sub_c};
-    }
-#endif
-
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void
    Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const

...

@@ -352,5 +330,366 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
        }
    }
};

// blockwise GEMM: C += transpose(A) * B
// A and B are visable to the whole block, C is distributed among each thread
// If following number are power of 2, index calculation shall be greatly reduced:
//    MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
//    MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize,
          typename BlockMatrixA,
          typename BlockMatrixB,
          typename ThreadMatrixC,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t KPerThreadLoop,
          index_t MLevel0ThreadCluster,
          index_t NLevel0ThreadCluster,
          index_t MLevel1ThreadCluster,
          index_t NLevel1ThreadCluster,
          index_t ThreadGemmADataPerRead_M,
          index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemm_km_kn_m0m1n0n1_v1
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    index_t mMyThreadOffsetA;
    index_t mMyThreadOffsetB;

    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1()
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t ThreadPerLevel1Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster *
                                                   MLevel1ThreadCluster * NLevel1ThreadCluster;

        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

        static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
        constexpr index_t N = BlockMatrixB{}.GetLength(I1);

        static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
                          N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
                      "wrong! Cannot evenly divide work among\n");

#if 0
        static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
                          ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
                      "wrong! ThreadMatrixC lengths is wrong");
#else
        constexpr auto tmp0 = GetThreadMatrixCLengths()[I0];
        constexpr auto tmp1 = GetThreadMatrixCLengths()[I1];

        static_assert(tmp0 == 8, "wrong!");
        static_assert(tmp1 == 8, "wrong!");

        static_assert(tmp0 == Number<8>{}, "wrong!");
        static_assert(tmp1 == Number<8>{}, "wrong!");

        static_assert(ThreadMatrixC{}.GetLength(I0) == tmp0 &&
                          ThreadMatrixC{}.GetLength(I1) == tmp1,
                      "wrong! ThreadMatrixC lengths is wrong");
#endif

        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
        mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
    }

    __device__ static constexpr auto GetThreadMatrixCLengths()
    {
        constexpr auto I1 = Number<1>{};

        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
        constexpr index_t N = BlockMatrixB{}.GetLength(I1);

        constexpr index_t MRepeat =
            M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
        constexpr index_t NRepeat =
            N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);

        static_assert(M == 128, "wrong!");
        static_assert(MPerThreadSubC == 4, "wrong!");
        static_assert(MRepeat == 2, "wrong!");
        static_assert(NRepeat == 2, "wrong!");
        static_assert(NPerThreadSubC == 4, "wrong!");

        return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
    }

    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
    {
        constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;

        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1ThreadCluster;
        index_t level1_n_id = level1_id % NLevel1ThreadCluster;

        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0ThreadCluster;
        index_t level0_n_id = level0_id % NLevel0ThreadCluster;

        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;

        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void
    Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr index_t K = a_block_mtx[I0];

        constexpr index_t MPerThread = c_thread_mtx[I0];
        constexpr index_t NPerThread = c_thread_mtx[I1];

        constexpr index_t MPerLevel1Cluster =
            MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
        constexpr index_t NPerLevel1Cluster =
            NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        // thread A, B for GEMM
        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
            Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
            Number<KPerThreadLoop>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
                                                                 decltype(a_thread_mtx),
                                                                 KPerThreadLoop,
                                                                 MPerThreadSubC,
                                                                 ThreadGemmADataPerRead_M>{};

        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
                                                                 decltype(b_thread_mtx),
                                                                 KPerThreadLoop,
                                                                 NPerThreadSubC,
                                                                 ThreadGemmBDataPerRead_N>{};

        constexpr auto threadwise_gemm =
            ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
                                               decltype(b_thread_mtx),
                                               decltype(c_thread_mtx)>{};

#pragma unroll
        // loop over k
        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
        {
#pragma unroll
            // read A
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
            {
                a_thread_copy.Run(
                    p_a_block +
                        a_block_mtx.CalculateOffset(
                            make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
                        mMyThreadOffsetA,
                    p_a_thread +
                        a_thread_mtx.CalculateOffset(make_tuple(0, m_repeat * MPerThreadSubC)));
            }

#pragma unroll
            // read B
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                b_thread_copy.Run(
                    p_b_block +
                        b_block_mtx.CalculateOffset(
                            make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
                        mMyThreadOffsetB,
                    p_b_thread +
                        b_thread_mtx.CalculateOffset(make_tuple(0, n_repeat * NPerThreadSubC)));
            }

            // C += A * B
            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
        }
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void
    Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr auto K = a_block_mtx.GetLength(I0);

        constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
        constexpr auto NPerThread = c_thread_mtx.GetLength(I1);

        constexpr index_t MPerLevel1Cluster =
            MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
        constexpr index_t NPerLevel1Cluster =
            NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        static_assert(MRepeat == 2 && NRepeat == 2,
                      "wrong! inline asm cannot deal with this GEMM config yet");

        // thread A, B
        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
            Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
            Number<KPerThreadLoop>{}, Number<NPerThread>{});

        // thread A-sub, B-sub
        constexpr auto a_thread_sub_mtx = a_thread_mtx.MakeSubMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{});

        constexpr auto b_thread_sub_mtx = b_thread_mtx.MakeSubMatrixDescriptor(
            Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{});

        // thread C-sub
        constexpr auto c_thread_sub_mtx = ThreadMatrixC::MakeSubMatrixDescriptor(
            Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
                                                                 decltype(a_thread_mtx),
                                                                 KPerThreadLoop,
                                                                 MPerThreadSubC,
                                                                 ThreadGemmADataPerRead_M>{};

        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
                                                                 decltype(b_thread_mtx),
                                                                 KPerThreadLoop,
                                                                 NPerThreadSubC,
                                                                 ThreadGemmBDataPerRead_N>{};

        constexpr auto threadwise_gemm =
            ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),
                                               decltype(b_thread_sub_mtx),
                                               decltype(c_thread_sub_mtx)>{};

        const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
        const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;

        // read A_sub_0
        a_thread_copy.Run(p_a_block_off, p_a_thread);

        // read B_sub_0
        b_thread_copy.Run(p_b_block_off, p_b_thread);

        // read B_sub_1
        b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(0, NPerLevel1Cluster),
                          p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));

        // read A_sub_1
        a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(0, MPerLevel1Cluster),
                          p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_gemm.Run(p_a_thread,
                            p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
                            p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));

#pragma unroll
        // loop over rest of k
        for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
        {
            // read A_sub_0
            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, 0), p_a_thread);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
                                p_b_thread,
                                p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));

            // read B_sub_0
            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, 0), p_b_thread);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
                                p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
                                p_c_thread +
                                    ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));

            // read B_sub_1
            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, NPerLevel1Cluster),
                              p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));

            // read A_sub_1
            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, MPerLevel1Cluster),
                              p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_gemm.Run(p_a_thread,
                                p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
                                p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
        }

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
                            p_b_thread,
                            p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
                            p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
                            p_c_thread +
                                ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
    {
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
        constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        if constexpr(MRepeat == 2 && NRepeat == 2)
        {
            Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
        }
        else
        {
            Run_naive(p_a_block, p_b_block, p_c_thread);
        }
#else
        Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
    }
};

} // namespace ck
#endif
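To make the two-level thread-cluster mapping in the new BlockwiseGemm_km_kn_m0m1n0n1_v1 concrete, the standalone host-side sketch below reproduces the arithmetic of GetBeginOfThreadMatrixC. The cluster and sub-tile sizes are illustrative assumptions, not values taken from this commit.

    // Host-side sketch of the GetBeginOfThreadMatrixC index decomposition.
    // All sizes below are example values chosen for illustration only.
    #include <cstdio>

    int main()
    {
        constexpr int MPerThreadSubC       = 4, NPerThreadSubC       = 4;
        constexpr int MLevel0ThreadCluster = 2, NLevel0ThreadCluster = 2;
        constexpr int MLevel1ThreadCluster = 4, NLevel1ThreadCluster = 4;

        constexpr int ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
        constexpr int MPerLevel0Cluster      = MPerThreadSubC * MLevel0ThreadCluster;
        constexpr int NPerLevel0Cluster      = NPerThreadSubC * NLevel0ThreadCluster;

        constexpr int BlockSize =
            ThreadPerLevel0Cluster * MLevel1ThreadCluster * NLevel1ThreadCluster; // 64 threads

        for(int tid = 0; tid < BlockSize; ++tid)
        {
            // same decomposition as GetBeginOfThreadMatrixC
            const int level1_id   = tid / ThreadPerLevel0Cluster;
            const int level1_m_id = level1_id / NLevel1ThreadCluster;
            const int level1_n_id = level1_id % NLevel1ThreadCluster;

            const int level0_id   = tid % ThreadPerLevel0Cluster;
            const int level0_m_id = level0_id / NLevel0ThreadCluster;
            const int level0_n_id = level0_id % NLevel0ThreadCluster;

            const int row = level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC;
            const int col = level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC;

            std::printf("thread %2d -> C sub-tile origin (%2d, %2d)\n", tid, row, col);
        }
    }

Each thread then owns MRepeat x NRepeat such sub-tiles; Run_naive reaches the extra copies by offsetting the block reads by m_repeat * MPerLevel1Cluster and n_repeat * NPerLevel1Cluster, which is exactly what its two inner loops do.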
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp

@@ -130,12 +130,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
        // A matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
-        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_multi_index(KPerBlock, MPerBlock), max_lds_align);
+        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<MPerBlock>{}), Number<max_lds_align>{});

        // B matrix in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
-        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
-            make_multi_index(KPerBlock, NPerBlock), max_lds_align);
+        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
+            make_tuple(Number<KPerBlock>{}, Number<NPerBlock>{}), Number<max_lds_align>{});

        // A matrix blockwise copy
        auto a_blockwise_copy =

...

@@ -201,6 +201,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
        // b_mtx[KPerBlocl, NPerBlock] is in LDS
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //   register
+#if 0
        constexpr index_t a_k_m_block_mtx_stride =
            a_k_m_block_desc.CalculateOffset(make_multi_index(1, 0)) -
            a_k_m_block_desc.CalculateOffset(make_multi_index(0, 0));

...

@@ -212,6 +213,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
            Number<KPerBlock>{}, Number<MPerBlock>{}, Number<a_k_m_block_mtx_stride>{});

        constexpr auto b_k_n_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<KPerBlock>{}, Number<NPerBlock>{}, Number<b_k_n_block_mtx_stride>{});
+#endif

        // sanity check
        static_assert(MPerBlock % (MPerThread * MLevel0Cluster * MLevel1Cluster) == 0 &&

...

@@ -223,14 +225,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
+#if 0
        constexpr auto c_m0m1_n0n1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{});
+#else
+        constexpr auto c_m0m1_n0n1_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
+            make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));
+#endif

-        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
-            BlockSize,
-            decltype(a_k_m_block_mtx_desc),
-            decltype(b_k_n_block_mtx_desc),
-            decltype(c_m0m1_n0n1_thread_mtx_desc),
-            MPerThread,
-            NPerThread,
-            KPerThread,
+        const auto blockwise_gemm =
+            BlockwiseGemm_km_kn_m0m1n0n1_v1<BlockSize,
+                                            decltype(a_k_m_block_desc),
+                                            decltype(b_k_n_block_desc),
+                                            decltype(c_m0m1_n0n1_thread_desc),
+                                            MPerThread,
+                                            NPerThread,
+                                            KPerThread,

...

@@ -252,10 +259,10 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
        Float* p_b_block_double = p_shared_block + 2 * a_block_space_size;

        // register allocation for output
-        AccFloat p_c_thread[c_m0m1_n0n1_thread_mtx_desc.GetElementSpace()];
+        AccFloat p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];

        // zero out threadwise output
-        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_mtx_desc, p_c_thread);
+        threadwise_matrix_set_zero(c_m0m1_n0n1_thread_desc, p_c_thread);

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);

...

@@ -422,7 +429,6 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
            AddressSpace::Global,
            CGlobalMemoryDataOperation,
            1,
-            true,
            true>(c_m0_m1_n0_n1_global_desc,
                  make_multi_index(m_thread_data_on_global / M1,
                                   m_thread_data_on_global % M1,

...
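The descriptor changes above move the LDS tile shapes from runtime multi-indices to compile-time Number<> lengths, while still going through the "aligned" descriptor maker ("be careful of LDS alignment"). The sketch below shows one plausible reading of what that alignment does to the K stride of the [KPerBlock, MPerBlock] tile; the rounding rule and all numeric values are assumptions for illustration, not code taken from the library.

    // Sketch of the stride an "aligned" 2-D naive descriptor plausibly implies.
    // This is an interpretation of make_dynamic_naive_tensor_descriptor_aligned_v2,
    // written as an assumption for illustration, not a quote of the library.
    #include <cstdio>

    constexpr int round_up(int x, int align) { return ((x + align - 1) / align) * align; }

    int main()
    {
        constexpr int KPerBlock     = 8;   // example values, not from the commit
        constexpr int MPerBlock     = 127; // deliberately not a multiple of the alignment
        constexpr int max_lds_align = 4;   // e.g. a float4 access width

        // row-major [K, M] tile: the K stride is M rounded up to the alignment,
        // so every row of the tile starts on an aligned LDS address
        constexpr int m_stride = 1;
        constexpr int k_stride = round_up(MPerBlock, max_lds_align);

        std::printf("a_k_m_block: lengths [%d, %d], strides [%d, %d], element space %d\n",
                    KPerBlock, MPerBlock, k_stride, m_stride, KPerBlock * k_stride);
    }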
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp

@@ -44,11 +44,11 @@ template <typename SrcData,
          AddressSpace DstAddressSpace,
          InMemoryDataOperation DstInMemOp,
          index_t DstScalarStrideInVector,
-          bool SrcResetCoordinateAfterRun,
          bool DstResetCoordinateAfterRun>
struct ThreadwiseDynamicTensorSliceTransfer_v1r3
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));

...

@@ -61,10 +61,12 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
    {
    }

+#if 0
    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v1r3()
        : ThreadwiseDynamicTensorSliceTransfer_v1r3(DstDesc{}, make_zero_multi_index<nDim>())
    {
    }
+#endif

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
    {

...

@@ -297,7 +299,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
            return forward_sweep;
        }();

-        // calculate dst data index after last iteration in RunWrite(), if it has not being reset by
+        // calculate dst data index after last iteration in Run(), if it has not being reset by
        // RunWrite()
        constexpr auto dst_data_idx = [&]() {
            Index ordered_idx;

...

@@ -328,7 +330,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
    {
-        // if dst coord was not reset by RunWrite(), then need to adjust the step here
+        // if dst coord was not reset by Run(), then need to adjust the step here
        const auto adjusted_step_idx =
            DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

...

@@ -344,6 +346,326 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
    DstCoord dst_slice_origin_coord_;
};

// this version is less likely to have scratch memory issue, due to:
// 1. It does not keep reference to tensor descriptor
// 2. It does not construct new tensor coordinate for this->Run()
// Assume dst_slice_origin_idx is 0
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t SrcVectorDim,
          index_t SrcScalarPerVector,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace,
          index_t SrcScalarStrideInVector,
          bool SrcResetCoordinateAfterRun>
struct ThreadwiseDynamicTensorSliceTransfer_v2
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));

    using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));

    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2(const SrcDesc& src_desc,
                                                                 const Index& src_slice_origin_idx)
        : src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
    {
    }

#if 0
    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v2()
        : ThreadwiseDynamicTensorSliceTransfer_v1r3(SrcDesc{}, make_zero_multi_index<nDim>())
    {
    }
#endif

    __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {
        src_slice_origin_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
    }

    template <typename SrcIteratorHacks>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcData* p_src,
                        DstData* p_dst,
                        const SrcIteratorHacks& src_iterator_hacks)
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        // Comments: dst_desc is constexpr
        constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto src_scalar_step_in_vector =
            generate_sequence(lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        // make forward iterators
        const auto src_forward_iterators = generate_tuple(
            [&](auto i) {
                Index forward_step;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
                });

                return make_dynamic_tensor_coordinate_iterator(
                    src_desc, forward_step, src_iterator_hacks[I0][i]);
            },
            Number<nDim>{});

        // make backward iterators
        const auto src_backward_iterators = generate_tuple(
            [&](auto i) {
                Index backward_step;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
                });

                return make_dynamic_tensor_coordinate_iterator(
                    src_desc, backward_step, src_iterator_hacks[I1][i]);
            },
            Number<nDim>{});

        // loop over tensor and copy
        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
            // judge move forward or move backward
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep;

                forward_sweep(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_access_idx[I0];

                    static_for<0, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
                    });

                    forward_sweep(i) = tmp % 2 == 0;
                });

                return forward_sweep;
            }();

            // calculate src data index
            constexpr auto src_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i]
                                         ? ordered_access_idx[i]
                                         : ordered_access_lengths[i] - 1 - ordered_access_idx[i];
                });

                auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
                                    src_scalar_per_access;

                return src_data_idx;
            }();

            // copy data
            // hardcoding for buffer_store
            // TODO refactor transfer_data() to encapsulate this
            static_assert(DstAddressSpace == AddressSpace::Vgpr, "wrong! hardcode for ds_read");

            vector_type<SrcData, SrcScalarPerVector> src_vector;

            using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;

            if constexpr(SrcAddressSpace == AddressSpace::Global)
            {
                src_vector.Vector() = amd_buffer_load<SrcData, SrcScalarPerVector>(
                    p_src, src_slice_origin_coord_.GetOffset(), true, src_desc.GetElementSpaceSize());

                const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                    src_desc, src_slice_origin_coord_);

                src_vector.Vector() = is_valid ? src_vector.Vector() : src_vector_t{0};
            }
            else
            {
                const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                    src_desc, src_slice_origin_coord_);

                src_vector.Vector() = is_valid ? *reinterpret_cast<const src_vector_t*>(
                                                     &p_src[src_slice_origin_coord_.GetOffset()])
                                               : src_vector_t{0};
            }

            // this is hardcoded for dst that has compile-time tensor descriptor
            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                // assume dst_slice_origin_idx is 0
                // TODO: support non-zero dst_slice_oring_idx
                constexpr index_t dst_offset =
                    dst_desc.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);

                p_dst[Number<dst_offset>{}] = src_vector[i];
            });

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
                    });
                });

                return move_on_dim;
            }
            ();

            // move
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_dynamic_tensor_coordinate(src_desc,
                                                       src_slice_origin_coord_,
                                                       src_forward_iterators[dim_access_order[i]]);
                    }
                    else
                    {
                        move_dynamic_tensor_coordinate(src_desc,
                                                       src_slice_origin_coord_,
                                                       src_backward_iterators[dim_access_order[i]]);
                    }
                }
            });
        });

        // move src coordinate back to slice origin (or not)
        if constexpr(SrcResetCoordinateAfterRun)
        {
            const auto src_reset_iterator =
                make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());

            move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, src_reset_iterator);
        }
    }

    __device__ void Run(const SrcDesc& src_desc, const SrcData* p_src, DstData* p_dst)
    {
        constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};

        constexpr auto src_iterator_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

        Run(src_desc, p_src, p_dst, src_iterator_hacks);
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
    {
        constexpr auto I0 = Number<0>{};

        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto src_scalar_per_access = generate_sequence(
            lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});

        constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep;

            forward_sweep(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_access_lengths[I0] - 1;

                static_for<0, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
                });

                forward_sweep(i) = tmp % 2 == 0;
            });

            return forward_sweep;
        }();

        // calculate src data index after last iteration in Run(), if it has not being reset by
        // RunWrite()
        constexpr auto src_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
            });

            auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
                                src_scalar_per_access;

            return src_data_idx;
        }();

        //
        constexpr auto reset_src_data_step = [&]() {
            Index reset_src_data_step;

            static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; });

            return reset_src_data_step;
        }();

        return reset_src_data_step;
    }

    // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if src coord was not reset by Run(), then need to adjust the step here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step =
            make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);

        move_dynamic_tensor_coordinate(src_desc, src_slice_origin_coord_, adjusted_step);
    }

    private:
    SrcCoord src_slice_origin_coord_;
};

// this version does following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions
// 1. It does not keep reference to tensor descriptor

...

@@ -398,11 +720,13 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
                      "wrong!");
    }

+#if 0
    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3()
        : ThreadwiseDynamicTensorSliceTransfer_v3(
              SrcDesc{}, make_zero_multi_index<nDim>(), DstDesc{}, make_zero_multi_index<nDim>())
    {
    }
+#endif

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
    {

...

@@ -512,7 +836,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        vector_type<SrcData, SrcScalarPerVector> src_vector;

-        using SrcVectorType = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;
+        using src_vector_t = typename vector_type<SrcData, SrcScalarPerVector>::MemoryType;

#if 1
        src_vector.Vector() = amd_buffer_load<SrcData, SrcScalarPerVector>(

...

@@ -521,7 +845,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
        const bool is_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
            src_desc, src_slice_origin_coord_);

-        src_vector.Vector() = is_valid ? src_vector.Vector() : SrcVectorType{0};
+        src_vector.Vector() = is_valid ? src_vector.Vector() : src_vector_t{0};

        static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
            constexpr index_t buffer_offset =

...
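The new ThreadwiseDynamicTensorSliceTransfer_v2 walks its access grid in a snake (zig-zag) order: dimension 0 always sweeps forward, and each inner dimension flips direction based on the parity computed in the forward_sweep lambda. The sketch below illustrates the intent of that ordering for an assumed 3x4 access grid (it shows the simplest snake walk, not the exact parity formula used above).

    // Zig-zag ("snake") traversal sketch: consecutive accesses differ by one
    // step along a single dimension. The grid shape is an example value.
    #include <cstdio>

    int main()
    {
        constexpr int L0 = 3, L1 = 4; // ordered access lengths (example)

        for(int i0 = 0; i0 < L0; ++i0)
        {
            for(int i1 = 0; i1 < L1; ++i1)
            {
                // dim 0 always goes forward; dim 1 reverses on odd dim-0 indices
                const bool forward1 = (i0 % 2 == 0);

                const int idx0 = i0;
                const int idx1 = forward1 ? i1 : (L1 - 1 - i1);

                std::printf("access (%d,%d) -> data index (%d,%d)\n", i0, i1, idx0, idx1);
            }
        }
    }

Because neighbouring accesses differ by a single step, Run() only ever applies one of the precomputed src_forward_iterators / src_backward_iterators per dimension and never rebuilds a tensor coordinate from scratch, which is the point of the "less likely to have scratch memory issue" comment.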
composable_kernel/include/tensor_operation/threadwise_gemm.hpp

@@ -10,6 +10,7 @@ namespace ck {
template <typename Float, class Matrix>
__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
{
+#if 0
    for(index_t i = 0; i < Matrix::NRow(); ++i)
    {
        for(index_t j = 0; j < Matrix::NCol(); ++j)

...

@@ -18,6 +19,21 @@ __device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
            p_thread[id] = Float(0);
        }
    }
+#else
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+
+    constexpr auto M = Matrix{}.GetLength(I0);
+    constexpr auto N = Matrix{}.GetLength(I1);
+
+    static_for<0, M, 1>{}([&](auto i) {
+        static_for<0, N, 1>{}([&](auto j) {
+            constexpr auto offset = Matrix{}.CalculateOffset(make_tuple(i, j));
+
+            p_thread[offset] = Float(0);
+        });
+    });
+#endif
}

template <typename SrcMatrix,

...

@@ -32,6 +48,7 @@ struct ThreadwiseMatrixSliceCopy
        static_assert(SrcMatrix::RowStride() % DataPerAccess == 0 &&
                          DstMatrix::RowStride() % DataPerAccess == 0,
                      "wrong! wrong alignment");
        static_assert(NSliceCol % DataPerAccess == 0,
                      "wrong! should be NSliceCol % DataPerAccess == 0");
    }

...

@@ -41,6 +58,7 @@ struct ThreadwiseMatrixSliceCopy
    {
        using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;

+#if 0
        for(index_t i = 0; i < NSliceRow; ++i)
        {
            for(index_t j = 0; j < NSliceCol; j += DataPerAccess)

...

@@ -52,6 +70,17 @@ struct ThreadwiseMatrixSliceCopy
                    *reinterpret_cast<const vector_t*>(&p_src[src_index]);
            }
        }
+#else
+        static_for<0, NSliceRow, 1>{}([&](auto i) {
+            static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
+                constexpr auto src_offset = SrcMatrix{}.CalculateOffset(make_tuple(i, j));
+                constexpr auto dst_offset = DstMatrix{}.CalculateOffset(make_tuple(i, j));
+
+                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
+                    *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
+            });
+        });
+#endif
    }
};

...

@@ -62,85 +91,95 @@ struct ThreadwiseGemmTransANormalBNormalC
{
    __device__ constexpr ThreadwiseGemmTransANormalBNormalC()
    {
+#if 0
        static_assert(MatrixA::NRow() == MatrixB::NRow() && MatrixA::NCol() == MatrixC::NRow() &&
                          MatrixB::NCol() == MatrixC::NCol(),
                      "wrong!");
+#endif
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
-        constexpr index_t M = MatrixC::NRow();
-        constexpr index_t N = MatrixC::NCol();
-        constexpr index_t K = MatrixA::NRow(); // A is transposed
-
-        for(index_t k = 0; k < K; ++k)
-        {
-            for(index_t m = 0; m < M; ++m)
-            {
-                for(index_t n = 0; n < N; ++n)
-                {
-                    const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
-                    const index_t bindex = MatrixB::CalculateOffset(k, n);
-                    const index_t cindex = MatrixC::CalculateOffset(m, n);
-
-                    p_c[cindex] +=
-                        inner_product_with_conversion<FloatC>{}(p_a[aindex], p_b[bindex]);
-                }
-            }
-        }
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        constexpr index_t M = MatrixC{}[I0];
+        constexpr index_t N = MatrixC{}[I1];
+        constexpr index_t K = MatrixA{}[I0];
+
+        static_for<0, K, 1>{}([&](auto k) {
+            static_for<0, M, 1>{}([&](auto m) {
+                static_for<0, N, 1>{}([&](auto n) {
+                    const index_t a_offset =
+                        MatrixA{}.CalculateOffset(make_tuple(k, m)); // A is transposed
+                    const index_t b_offset = MatrixB{}.CalculateOffset(make_tuple(k, n));
+                    const index_t c_offset = MatrixC{}.CalculateOffset(make_tuple(m, n));
+
+                    p_c[c_offset] +=
+                        inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
+                });
+            });
+        });
    }

#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
-        constexpr index_t M = MatrixC::NRow();
-        constexpr index_t N = MatrixC::NCol();
-        constexpr index_t K = MatrixA::NRow(); // A is transposed
-
-        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
-
-        for(index_t k = 0; k < K; ++k)
-        {
-            for(index_t m = 0; m < M; ++m)
-            {
-                const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
-
-                static_if<N == 2>{}([&](auto) {
-                    const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
-                    const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
-
-                    const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
-                    const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
-
-                    amd_assembly_outer_product_1x2(
-                        p_a[aindex], p_b[bindex_0], p_b[bindex_1], p_c[cindex_0], p_c[cindex_1]);
-                });
-
-                static_if<N == 4>{}([&](auto) {
-                    const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
-                    const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
-                    const index_t bindex_2 = MatrixB::CalculateOffset(k, 2);
-                    const index_t bindex_3 = MatrixB::CalculateOffset(k, 3);
-
-                    const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
-                    const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
-                    const index_t cindex_2 = MatrixC::CalculateOffset(m, 2);
-                    const index_t cindex_3 = MatrixC::CalculateOffset(m, 3);
-
-                    amd_assembly_outer_product_1x4(p_a[aindex],
-                                                   p_b[bindex_0],
-                                                   p_b[bindex_1],
-                                                   p_b[bindex_2],
-                                                   p_b[bindex_3],
-                                                   p_c[cindex_0],
-                                                   p_c[cindex_1],
-                                                   p_c[cindex_2],
-                                                   p_c[cindex_3]);
-                });
-            }
-        }
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr index_t M = MatrixC{}[I0];
+        constexpr index_t N = MatrixC{}[I1];
+        constexpr index_t K = MatrixA{}[I0];
+
+        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
+
+        static_for<0, K, 1>{}([&](auto k) {
+            static_for<0, M, 1>{}([&](auto m) {
+                constexpr auto a_offset = MatrixA{}.CalculateOffset(make_tuple(k, m));
+
+                if constexpr(N == 2)
+                {
+                    constexpr auto b_offset_0 = MatrixB{}.CalculateOffset(make_tuple(k, I0));
+                    constexpr auto b_offset_1 = MatrixB{}.CalculateOffset(make_tuple(k, I1));
+
+                    constexpr auto c_offset_0 = MatrixC{}.CalculateOffset(make_tuple(m, I0));
+                    constexpr auto c_offset_1 = MatrixC{}.CalculateOffset(make_tuple(m, I1));
+
+                    amd_assembly_outer_product_1x2(p_a[a_offset],
+                                                   p_b[b_offset_0],
+                                                   p_b[b_offset_1],
+                                                   p_c[c_offset_0],
+                                                   p_c[c_offset_1]);
+                }
+                else if constexpr(N == 4)
+                {
+                    constexpr auto b_offset_0 = MatrixB{}.CalculateOffset(make_tuple(k, I0));
+                    constexpr auto b_offset_1 = MatrixB{}.CalculateOffset(make_tuple(k, I1));
+                    constexpr auto b_offset_2 = MatrixB{}.CalculateOffset(make_tuple(k, I2));
+                    constexpr auto b_offset_3 = MatrixB{}.CalculateOffset(make_tuple(k, I3));
+
+                    constexpr auto c_offset_0 = MatrixC{}.CalculateOffset(make_tuple(m, I0));
+                    constexpr auto c_offset_1 = MatrixC{}.CalculateOffset(make_tuple(m, I1));
+                    constexpr auto c_offset_2 = MatrixC{}.CalculateOffset(make_tuple(m, I2));
+                    constexpr auto c_offset_3 = MatrixC{}.CalculateOffset(make_tuple(m, I3));
+
+                    amd_assembly_outer_product_1x4(p_a[a_offset],
+                                                   p_b[b_offset_0],
+                                                   p_b[b_offset_1],
+                                                   p_b[b_offset_2],
+                                                   p_b[b_offset_3],
+                                                   p_c[c_offset_0],
+                                                   p_c[c_offset_1],
+                                                   p_c[c_offset_2],
+                                                   p_c[c_offset_3]);
+                }
+            });
+        });
    }
#endif

...

@@ -153,8 +192,14 @@ struct ThreadwiseGemmTransANormalBNormalC
            (is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
            (is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));

-        static_if<has_amd_asm>{}([&](auto fwd) { Run_amd_asm(p_a, p_b, fwd(p_c)); })
-            .Else([&](auto) { Run_source(p_a, p_b, p_c); });
+        if constexpr(has_amd_asm)
+        {
+            Run_amd_asm(p_a, p_b, p_c);
+        }
+        else
+        {
+            Run_source(p_a, p_b, p_c);
+        }
#else
        Run_source(p_a, p_b, p_c);
#endif

...
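For reference, the index roles in the rewritten Run_source (A stored K x M, i.e. transposed; B stored K x N; C accumulated as M x N) reduce to the following plain C++ loop nest; the matrix sizes and values are made-up examples.

    // Reference loop nest for C += transpose(A) * B with the same index roles
    // as a_offset / b_offset / c_offset in Run_source. Sizes are examples.
    #include <cstdio>

    int main()
    {
        constexpr int K = 2, M = 3, N = 4;

        const float a[K][M] = {{1, 2, 3}, {4, 5, 6}};       // A is K x M ("A is transposed")
        const float b[K][N] = {{1, 0, 0, 1}, {0, 1, 1, 0}}; // B is K x N
        float       c[M][N] = {};                           // C is M x N, accumulated

        for(int k = 0; k < K; ++k)
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                    c[m][n] += a[k][m] * b[k][n];

        for(int m = 0; m < M; ++m)
        {
            for(int n = 0; n < N; ++n)
                std::printf("%5.1f ", c[m][n]);
            std::printf("\n");
        }
    }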
composable_kernel/include/utility/math.hpp

@@ -114,8 +114,8 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
}

// greatest common divisor, aka highest common factor
-template <typename X, typename Y>
-__host__ __device__ constexpr auto gcd(X x, Y y)
+template <typename T>
+__host__ __device__ constexpr T gcd(T x, T y)
{
    if(x == y || x == 0)
    {

...

@@ -135,6 +135,14 @@ __host__ __device__ constexpr auto gcd(X x, Y y)
    }
}

+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
+{
+    constexpr auto r = gcd(X, Y);
+
+    return Number<r>{};
+}
+
template <typename X, typename... Ys>
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
{

...
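The two-argument gcd is now constrained to a single type T, and the added Number<X>, Number<Y> overload keeps compile-time integers in the Number domain so the result can feed further compile-time arithmetic. The self-contained sketch below mimics how the three overloads compose; Number<> and the Euclidean body are simplified stand-ins, not copies of the library code.

    // How the gcd overloads compose: value gcd, Number gcd, variadic fold.
    // Number<> here is a minimal stand-in for the library type.
    #include <cstdio>

    using index_t = int;

    template <index_t N>
    struct Number
    {
        static constexpr index_t value = N;
    };

    template <typename T>
    constexpr T gcd(T x, T y)
    {
        // simple Euclidean variant, standing in for the library implementation
        return (x == y || x == 0) ? y : (y == 0 ? x : (x > y ? gcd(x % y, y) : gcd(x, y % x)));
    }

    template <index_t X, index_t Y>
    constexpr auto gcd(Number<X>, Number<Y>)
    {
        constexpr auto r = gcd(X, Y);
        return Number<r>{};
    }

    template <typename X, typename... Ys>
    constexpr auto gcd(X x, Ys... ys)
    {
        return gcd(x, gcd(ys...));
    }

    int main()
    {
        constexpr auto g  = gcd(12, 18, 30);                // plain integers -> 6
        constexpr auto gn = gcd(Number<8>{}, Number<12>{}); // stays a compile-time Number

        static_assert(decltype(gn)::value == 4, "computed at compile time");

        std::printf("gcd(12, 18, 30) = %d\n", g);
    }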