gaoqiong / composable_kernel

Commit 35d68cf8, authored Apr 21, 2021 by Chao Liu

replacing array with vector for tensor data
Parent: 712babe4

Showing 4 changed files with 253 additions and 46 deletions (+253 -46)
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp                          +64  -40
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp                      +7   -5
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp   +1   -1
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp                         +181 -0
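The net effect of the commit: the per-thread C accumulator is no longer handed around as a raw FloatC* / FloatAcc* pointer with sub-tile offsets folded into the pointer; instead a buffer object is passed together with compile-time origin indices, which is what allows the element type to become vectorized data later. A minimal standalone sketch of that calling-convention change, assuming made-up Desc and Buf types (not CK's real make_dynamic_buffer / StaticBuffer machinery):

// Hypothetical illustration only -- Desc, Buf, and the sizes are not CK's types.
#include <array>
#include <cstddef>

// A 2D row-major descriptor with compile-time extents.
template <std::size_t Rows, std::size_t Cols>
struct Desc
{
    static constexpr std::size_t CalculateOffset(std::size_t r, std::size_t c) { return r * Cols + c; }
};

// A small register-like buffer wrapping a plain array.
template <typename T, std::size_t N>
struct Buf
{
    std::array<T, N> data{};
    T& operator()(std::size_t i) { return data[i]; }
};

// Old style: the caller bakes the sub-tile origin into the pointer it passes.
inline void accumulate_old(float* p_c, float v)
{
    p_c[0] += v; // p_c already points at the sub-tile's origin
}

// New style: the caller passes the whole buffer plus an origin, and the callee
// computes the offset from the descriptor.
template <typename CBuf>
inline void accumulate_new(CBuf& c_buf, std::size_t row0, std::size_t col0, float v)
{
    c_buf(Desc<4, 4>::CalculateOffset(row0, col0)) += v;
}

int main()
{
    float p_c[16] = {};    // old: bare accumulator array
    Buf<float, 16> c_buf;  // new: buffer object

    accumulate_old(p_c + Desc<4, 4>::CalculateOffset(2, 2), 1.0f); // offset folded into the pointer
    accumulate_new(c_buf, 2, 2, 1.0f);                             // offset passed as indices
}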
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
@@ -503,8 +503,10 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                        level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
     }
 
-    __device__ void
-    Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
+    template <typename CThreadBuffer>
+    __device__ void
+    Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, CThreadBuffer c_thread_buf) const
     {
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
@@ -549,12 +551,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         auto a_thread_buf = make_dynamic_buffer<FloatA>(p_a_thread);
         auto b_thread_buf = make_dynamic_buffer<FloatB>(p_b_thread);
 
-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<FloatA,
+        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1r1<FloatA,
                                                                      FloatB,
                                                                      FloatC,
                                                                      decltype(a_thread_sub_mtx),
                                                                      decltype(b_thread_sub_mtx),
                                                                      decltype(c_thread_sub_mtx)>{};
 
         // read A_sub_0
         a_thread_copy_.Run(BlockMatrixA{},
@@ -589,13 +591,20 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            a_thread_buf);
 
         // C_sub_00 += transpose(A_sub_0) * B_sub_0
-        threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            c_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}));
 
         // C_sub_01 += transpose(A_sub_0) * B_sub_1
-        threadwise_gemm.Run(p_a_thread,
-                            p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                            p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                            c_thread_buf,
+                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
 
         // loop over rest of k
         static_for<KPerThreadLoop, K, KPerThreadLoop>{}([&](auto k) {
@@ -608,10 +617,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                a_thread_buf);
 
             // C_sub_10 += transpose(A_sub_1) * B_sub_0
-            threadwise_gemm.Run(p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                                p_b_thread,
-                                p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                                b_thread_buf,
+                                make_tuple(Number<0>{}, Number<0>{}),
+                                c_thread_buf,
+                                make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
 
             // read B_sub_0
             b_thread_copy_.Run(BlockMatrixB{},
@@ -622,11 +633,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                                b_thread_buf);
 
             // C_sub_11 += transpose(A_sub_1) * B_sub_1
-            threadwise_gemm.Run(p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                                p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                                p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
+            threadwise_gemm.Run(a_thread_buf,
+                                make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                                b_thread_buf,
+                                make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                                c_thread_buf,
+                                make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
 
             // read B_sub_1
             b_thread_copy_.Run(BlockMatrixB{},
@@ -645,30 +657,42 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
                            a_thread_buf);
 
         // C_sub_00 += transpose(A_sub_0) * B_sub_0
-        threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            c_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}));
 
         // C_sub_01 += transpose(A_sub_0) * B_sub_1
-        threadwise_gemm.Run(p_a_thread,
-                            p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                            p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(0, NPerThreadSubC)));
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                            c_thread_buf,
+                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}));
         });
 
         // C_sub_10 += transpose(A_sub_1) * B_sub_0
-        threadwise_gemm.Run(p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                            p_b_thread,
-                            p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<0>{}),
+                            c_thread_buf,
+                            make_tuple(Number<MPerThreadSubC>{}, Number<0>{}));
 
         // C_sub_11 += transpose(A_sub_1) * B_sub_1
-        threadwise_gemm.Run(p_a_thread + a_thread_mtx_desc_.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                            p_b_thread + b_thread_mtx_desc_.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                            p_c_thread + c_thread_mtx_desc.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
+        threadwise_gemm.Run(a_thread_buf,
+                            make_tuple(Number<0>{}, Number<MPerThreadSubC>{}),
+                            b_thread_buf,
+                            make_tuple(Number<0>{}, Number<NPerThreadSubC>{}),
+                            c_thread_buf,
+                            make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}));
     }
 
-    __device__ void
-    Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
+    template <typename CThreadBuffer>
+    __device__ void
+    Run(const FloatA* p_a_block, const FloatB* p_b_block, CThreadBuffer c_thread_buf) const
     {
 #if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
         constexpr auto I0 = Number<0>{};
@@ -682,14 +706,14 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
         if constexpr(MRepeat == 2 && NRepeat == 2)
         {
-            Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
+            Run_pipelined_2x2(p_a_block, p_b_block, c_thread_buf);
         }
         else
         {
-            Run_naive(p_a_block, p_b_block, p_c_thread);
+            Run_naive(p_a_block, p_b_block, c_thread_buf);
        }
 #else
-        Run_naive(p_a_block, p_b_block, p_c_thread);
+        Run_naive(p_a_block, p_b_block, c_thread_buf);
 #endif
     }
 };
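For orientation: Run_pipelined_2x2 updates the thread's C tile as four sub-tiles whose origins are (0, 0), (0, NPerThreadSubC), (MPerThreadSubC, 0) and (MPerThreadSubC, NPerThreadSubC); after this commit those origins travel as Number<> tuples rather than being added to p_c_thread. A self-contained sketch of the same 2x2 sub-tile update with plain scalars and illustrative sizes (not CK code):

#include <array>
#include <cstddef>

// Illustrative sizes only; CK derives these from MPerThreadSubC, NPerThreadSubC, MRepeat, NRepeat.
constexpr std::size_t MSub = 2, NSub = 2, M = 2 * MSub, N = 2 * NSub, K = 4;

// One sub-tile: C[M0.., N0..] += transpose(A[.., M0..]) * B[.., N0..]
template <std::size_t M0, std::size_t N0>
void gemm_sub(const std::array<float, K * M>& a, // A stored K x M
              const std::array<float, K * N>& b, // B stored K x N
              std::array<float, M * N>& c)       // C stored M x N
{
    for (std::size_t k = 0; k < K; ++k)
        for (std::size_t m = 0; m < MSub; ++m)
            for (std::size_t n = 0; n < NSub; ++n)
                c[(M0 + m) * N + (N0 + n)] += a[k * M + M0 + m] * b[k * N + N0 + n];
}

void run_2x2(const std::array<float, K * M>& a, const std::array<float, K * N>& b,
             std::array<float, M * N>& c)
{
    gemm_sub<0, 0>(a, b, c);       // C_sub_00
    gemm_sub<0, NSub>(a, b, c);    // C_sub_01
    gemm_sub<MSub, 0>(a, b, c);    // C_sub_10
    gemm_sub<MSub, NSub>(a, b, c); // C_sub_11
}

int main()
{
    std::array<float, K * M> a{};
    std::array<float, K * N> b{};
    std::array<float, M * N> c{};
    a.fill(1.0f);
    b.fill(1.0f);
    run_2x2(a, b, c); // every C element becomes K = 4
}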
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
@@ -732,6 +732,8 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         // register allocation for output
         FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
 
+        auto c_thread_buf = make_dynamic_buffer<FloatAcc>(p_c_thread);
+
         // zero out threadwise output
         threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
@@ -789,7 +791,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                                      b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
 
             // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
+            blockwise_gemm.Run(p_a_block_even, p_b_block_even, c_thread_buf);
 
             // LDS double buffer: store next data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
@@ -812,7 +814,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                                      b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
 
             // LDS double buffer: GEMM on current data
-            blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
+            blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, c_thread_buf);
 
             // LDS double buffer: store next data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
@@ -839,7 +841,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);
 
             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
 
             // LDS double buffer: store last data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
@@ -850,14 +852,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double + a_block_space_size,
-                               p_b_block_double + b_block_space_size,
-                               p_c_thread);
+            blockwise_gemm.Run(p_a_block_double + a_block_space_size,
+                               p_b_block_double + b_block_space_size,
+                               c_thread_buf);
         }
         else // if has 1 iteration left
         {
             __syncthreads();
 
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            blockwise_gemm.Run(p_a_block_double, p_b_block_double, c_thread_buf);
         }
 
         // output: register to global memory
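All of the hunks in this file sit inside the LDS double-buffering main loop: the GEMM consumes the tiles currently resident in one half of LDS while the blockwise copies stage the next tiles into the other half, and the last one or two iterations are drained separately. A rough host-side sketch of that ping-pong structure, with toy data standing in for LDS and __syncthreads (illustrative only):

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

int main()
{
    const std::size_t chunk = 4, iters = 3;
    std::vector<float> global(chunk * iters, 1.0f); // ~ p_a_global / p_b_global
    std::vector<float> lds[2]{std::vector<float>(chunk), std::vector<float>(chunk)};
    float acc = 0.0f;                               // ~ c_thread_buf

    auto stage = [&](std::size_t it, std::size_t slot) { // ~ blockwise_copy.RunRead/RunWrite
        std::copy(global.begin() + it * chunk, global.begin() + (it + 1) * chunk, lds[slot].begin());
    };
    auto gemm = [&](std::size_t slot) {                  // ~ blockwise_gemm.Run(...)
        acc = std::accumulate(lds[slot].begin(), lds[slot].end(), acc);
    };

    stage(0, 0);
    std::size_t cur = 0;
    for (std::size_t it = 0; it + 1 < iters; ++it)
    {
        stage(it + 1, cur ^ 1); // prefetch the next chunk into the idle buffer
        gemm(cur);              // compute on the current buffer
        cur ^= 1;               // swap roles (even/odd halves)
    }
    gemm(cur);                  // tail iteration: nothing left to prefetch
    // acc == chunk * iters == 12
}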
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
@@ -1370,7 +1370,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
                                        const SrcData* p_src,
                                        const DstDesc&,
                                        const DstRefToOriginDisplacement&,
-                                       DstBuffer dst_buf) const
+                                       DstBuffer& dst_buf) const
     {
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! SrcDesc and DstDesc need to known at compile-time");
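The single change in this file passes DstBuffer by reference, so the transfer's stores land in the caller's buffer rather than in a by-value copy that is thrown away. A trivial illustration of the difference (hypothetical Buf type, not CK code):

#include <array>

struct Buf { std::array<float, 4> data{}; };

void fill_by_value(Buf b)  { b.data[0] = 1.0f; } // writes a copy; the caller sees nothing
void fill_by_ref(Buf& b)   { b.data[0] = 1.0f; } // writes through to the caller's buffer

int main()
{
    Buf x;
    fill_by_value(x); // x.data[0] is still 0
    fill_by_ref(x);   // x.data[0] is now 1
}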
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
@@ -168,5 +168,186 @@ struct ThreadwiseGemm_km_kn_mn_v1
     }
 };
 
+// C[M, N] += transpose(A[K, M]) * B[K, N]
+//   Element of matrix can be vectorized data
+// Assume:
+//   1. ADesc, BDesc, CDesc are known at compile-time
+//   2. ABuffer, BBuffer, CBuffer are static buffer
+//   3. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
+template <typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename ADesc,
+          typename BDesc,
+          typename CDesc,
+          typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
+                                      CDesc::IsKnownAtCompileTime(),
+                                  bool>::type = false>
+struct ThreadwiseGemm_km_kn_mn_v1r1
+{
+    template <typename ABuffer,
+              typename AOriginIdx,
+              typename BBuffer,
+              typename BOriginIdx,
+              typename CBuffer,
+              typename COriginIdx>
+    __device__ static void Run_source(
+        const ABuffer& a_buf, AOriginIdx, const BBuffer& b_buf, BOriginIdx, CBuffer& c_buf, COriginIdx)
+    {
+        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
+                          CDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+
+        constexpr auto M = CDesc{}.GetLength(I0);
+        constexpr auto N = CDesc{}.GetLength(I1);
+        constexpr auto K = ADesc{}.GetLength(I0);
+
+        constexpr auto a_origin_idx = AOriginIdx{};
+        constexpr auto b_origin_idx = BOriginIdx{};
+        constexpr auto c_origin_idx = COriginIdx{};
+
+        static_for<0, K, 1>{}([&](auto k) {
+            static_for<0, M, 1>{}([&](auto m) {
+                static_for<0, N, 1>{}([&](auto n) {
+                    constexpr auto a_offset = ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
+                    constexpr auto b_offset = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
+                    constexpr auto c_offset = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));
+
+                    c_buf.template AsType<FloatC>()(c_offset) += inner_product_with_conversion<FloatC>{}(
+                        a_buf.template AsType<FloatA>()[a_offset], b_buf.template AsType<FloatB>()[b_offset]);
+                });
+            });
+        });
+    }
+
+#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
+    template <typename ABuffer,
+              typename AOriginIdx,
+              typename BBuffer,
+              typename BOriginIdx,
+              typename CBuffer,
+              typename COriginIdx>
+    __device__ static void Run_amd_asm(
+        const ABuffer& a_buf, AOriginIdx, const BBuffer& b_buf, BOriginIdx, CBuffer& c_buf, COriginIdx)
+    {
+        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
+                          CDesc::IsKnownAtCompileTime(),
+                      "wrong! Desc should be known at compile-time");
+
+        static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
+                          is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
+                          is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
+                      "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
+
+        constexpr auto I0 = Number<0>{};
+        constexpr auto I1 = Number<1>{};
+        constexpr auto I2 = Number<2>{};
+        constexpr auto I3 = Number<3>{};
+
+        constexpr auto M = CDesc{}.GetLength(I0);
+        constexpr auto N = CDesc{}.GetLength(I1);
+        constexpr auto K = ADesc{}.GetLength(I0);
+
+        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
+        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
+        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
+
+        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
+
+        static_for<0, K, 1>{}([&](auto k) {
+            static_for<0, M, 1>{}([&](auto m) {
+                constexpr auto a_offset = ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));
+
+                if constexpr(N == 2)
+                {
+                    constexpr auto b_offset_0 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
+                    constexpr auto b_offset_1 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
+
+                    constexpr auto c_offset_0 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
+                    constexpr auto c_offset_1 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
+
+                    amd_assembly_outer_product_1x2(a_buf.template AsType<FloatA>()[a_offset],
+                                                   b_buf.template AsType<FloatB>()[b_offset_0],
+                                                   b_buf.template AsType<FloatB>()[b_offset_1],
+                                                   c_buf.template AsType<FloatC>()(c_offset_0),
+                                                   c_buf.template AsType<FloatC>()(c_offset_1));
+                }
+                else if constexpr(N == 4)
+                {
+                    constexpr auto b_offset_0 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
+                    constexpr auto b_offset_1 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
+                    constexpr auto b_offset_2 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
+                    constexpr auto b_offset_3 = BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
+
+                    constexpr auto c_offset_0 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
+                    constexpr auto c_offset_1 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
+                    constexpr auto c_offset_2 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
+                    constexpr auto c_offset_3 = CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));
+
+                    amd_assembly_outer_product_1x4(a_buf.template AsType<FloatA>()[a_offset],
+                                                   b_buf.template AsType<FloatB>()[b_offset_0],
+                                                   b_buf.template AsType<FloatB>()[b_offset_1],
+                                                   b_buf.template AsType<FloatB>()[b_offset_2],
+                                                   b_buf.template AsType<FloatB>()[b_offset_3],
+                                                   c_buf.template AsType<FloatC>()(c_offset_0),
+                                                   c_buf.template AsType<FloatC>()(c_offset_1),
+                                                   c_buf.template AsType<FloatC>()(c_offset_2),
+                                                   c_buf.template AsType<FloatC>()(c_offset_3));
+                }
+            });
+        });
+    }
+#endif
+
+    template <typename ABuffer,
+              typename AOriginIdx,
+              typename BBuffer,
+              typename BOriginIdx,
+              typename CBuffer,
+              typename COriginIdx>
+    __device__ static void Run(
+        const ABuffer& a_buf, AOriginIdx, const BBuffer& b_buf, BOriginIdx, CBuffer& c_buf, COriginIdx)
+    {
+#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
+        Run_amd_asm(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
+#else
+        Run_source(a_buf, AOriginIdx{}, b_buf, BOriginIdx{}, c_buf, COriginIdx{});
+#endif
+    }
+};
+
 } // namespace ck
 #endif
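Run_source above is a plain triple loop over (k, m, n) whose extents and offsets are all resolved at compile time; the AsType / inner_product_with_conversion machinery only generalizes it to vectorized elements. A scalar C++ reference of the same math, C[M, N] += transpose(A[K, M]) * B[K, N], with illustrative sizes:

#include <array>
#include <cstddef>

// Reference semantics of ThreadwiseGemm_km_kn_mn_v1r1::Run_source (scalar case).
// A is stored K x M, B is K x N, C is M x N, all row-major; sizes are illustrative.
template <std::size_t K, std::size_t M, std::size_t N>
void gemm_km_kn_mn(const std::array<float, K * M>& a,
                   const std::array<float, K * N>& b,
                   std::array<float, M * N>& c)
{
    for (std::size_t k = 0; k < K; ++k)
        for (std::size_t m = 0; m < M; ++m)
            for (std::size_t n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n]; // C[m,n] += A[k,m] * B[k,n]
}

int main()
{
    std::array<float, 2 * 3> a{}; // K=2, M=3
    std::array<float, 2 * 4> b{}; // K=2, N=4
    std::array<float, 3 * 4> c{}; // M=3, N=4
    a.fill(1.0f);
    b.fill(1.0f);
    gemm_km_kn_mn<2, 3, 4>(a, b, c); // every C[m,n] becomes 2
}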