gaoqiong / composable_kernel · Commits · dea22d0e

Commit dea22d0e, authored Feb 22, 2021 by Chao Liu
Parent: 55599afd

    added IsKnownAtCompileTime() in multi-index transform and tensor descriptor
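In short, each dynamic multi-index transform and DynamicTensorDescriptor gains a static constexpr IsKnownAtCompileTime() query, and the blockwise/threadwise GEMM helpers use it to require descriptors whose lengths are compile-time constants. A minimal sketch of that convention follows; the Toy* types and require_static_desc are hypothetical stand-ins for illustration only, not composable_kernel code:

    #include <cstdio>

    // Hypothetical stand-ins for a compile-time-known and a runtime-only descriptor.
    struct ToyStaticDesc  { static constexpr bool IsKnownAtCompileTime() { return true;  } };
    struct ToyDynamicDesc { static constexpr bool IsKnownAtCompileTime() { return false; } };

    template <typename Desc>
    void require_static_desc(const Desc&)
    {
        // Same guard style the diff adds to the blockwise/threadwise GEMM helpers.
        static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
    }

    int main()
    {
        require_static_desc(ToyStaticDesc{});      // compiles
        // require_static_desc(ToyDynamicDesc{});  // would fail the static_assert
        std::printf("ok\n");
    }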
Showing 12 changed files with 678 additions and 531 deletions (+678 / -531).
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp   (+49, -0)
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp   (+13, -0)
composable_kernel/include/tensor_operation/blockwise_gemm.hpp   (+0, -355)
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp   (+370, -0)
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp   (+5, -3)
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp   (+10, -4)
composable_kernel/include/tensor_operation/threadwise_gemm.hpp   (+0, -155)
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp   (+184, -0)
composable_kernel/include/utility/container_helper.hpp   (+13, -0)
composable_kernel/include/utility/tuple_helper.hpp   (+18, -0)
composable_kernel/include/utility/type.hpp   (+15, -0)
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp   (+1, -14)
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp

@@ -74,6 +74,11 @@ struct DynamicPassThrough
        return true;
    }

+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<UpLengths>::value;
+    }
+
    __host__ __device__ void Print() const
    {
        printf("{");

@@ -160,6 +165,13 @@ struct DynamicPad
               (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_));
    }

+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<UpLengths>::value &&
+               is_known_at_compile_time<LeftPad>::value &&
+               is_known_at_compile_time<RightPad>::value;
+    }
+
    __host__ __device__ void Print() const
    {
        printf("{");

@@ -243,6 +255,12 @@ struct DynamicLeftPad
        return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_);
    }

+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<UpLengths>::value &&
+               is_known_at_compile_time<LeftPad>::value;
+    }
+
    __host__ __device__ void Print() const
    {
        printf("{");

@@ -328,6 +346,13 @@ struct DynamicRightPad
        return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_);
    }

+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<UpLengths>::value &&
+               is_known_at_compile_time<LowLength>::value &&
+               is_known_at_compile_time<RightPad>::value;
+    }
+
    __host__ __device__ void Print() const
    {
        printf("{");

@@ -424,6 +449,12 @@ struct DynamicEmbed
        return true;
    }

+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<UpLengths>::value &&
+               is_known_at_compile_time<Coefficients>::value;
+    }
+
    __host__ __device__ void Print() const
    {
        printf("{");

@@ -930,6 +961,13 @@ struct DynamicMerge
        return true;
    }

+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<LowLengths>::value &&
+               is_known_at_compile_time<LowLengthsScan>::value &&
+               is_known_at_compile_time<UpLengths>::value;
+    }
+
    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)

@@ -1033,6 +1071,12 @@ struct DynamicUnMerge
        return true;
    }

+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<UpLengths>::value &&
+               is_known_at_compile_time<UpLengthsScan>::value;
+    }
+
    __host__ __device__ void Print() const
    {
        printf("{");

@@ -1097,6 +1141,11 @@ struct DynamicFreeze
        return true;
    }

+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return is_known_at_compile_time<LowerIndex>::value;
+    }
+
    __host__ __device__ void Print() const
    {
        printf("DynamicFreeze");
composable_kernel/include/tensor_description/dynamic_tensor_descriptor.hpp

@@ -201,6 +201,19 @@ struct DynamicTensorDescriptor
        return VisibleDimensionIds{};
    }

+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        bool is_known = true;
+
+        static_for<0, Transforms::Size(), 1>{}([&](auto i) {
+            is_known &=
+                remove_cv_t<remove_reference_t<decltype(Transforms{}[i])>>::IsKnownAtCompileTime();
+        });
+
+        return is_known && is_known_at_compile_time<ElementSize>::value &&
+               is_known_at_compile_time<ElementSpaceSize>::value;
+    }
+
    __host__ __device__ void Print() const
    {
        printf("{");
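The descriptor-level query above folds the per-transform queries together with static_for over the transform tuple and then also requires ElementSize and ElementSpaceSize to be compile-time constants. As a standalone illustration of the same reduction (plain C++17 over std::tuple, not CK's Tuple/static_for), one could write:

    #include <tuple>

    // Illustrative only: fold IsKnownAtCompileTime() over a tuple of transform objects.
    template <typename... Transforms>
    constexpr bool all_transforms_known(const std::tuple<Transforms...>&)
    {
        return (Transforms::IsKnownAtCompileTime() && ...);
    }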
composable_kernel/include/tensor_operation/blockwise_gemm.hpp

@@ -330,360 +330,5 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
    }
};

-// blockwise GEMM: C += transpose(A) * B
-// A and B are visable to the whole block, C is distributed among each thread
-// If following number are power of 2, index calculation shall be greatly reduced:
-//   MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
-//   MLevel1ThreadCluster, NLevel1ThreadCluster
-template <index_t BlockSize,
-          typename BlockMatrixA,
-          typename BlockMatrixB,
-          typename ThreadMatrixC,
-          index_t MPerThreadSubC,
-          index_t NPerThreadSubC,
-          index_t KPerThreadLoop,
-          index_t MLevel0ThreadCluster,
-          index_t NLevel0ThreadCluster,
-          index_t MLevel1ThreadCluster,
-          index_t NLevel1ThreadCluster,
-          index_t ThreadGemmADataPerRead_M,
-          index_t ThreadGemmBDataPerRead_N>
-struct BlockwiseGemm_km_kn_m0m1n0n1_v1
-{
-    struct MatrixIndex
-    {
-        index_t row;
-        index_t col;
-    };
-
-    index_t mMyThreadOffsetA;
-    index_t mMyThreadOffsetB;
-
-    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1()
-    {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-
-        constexpr index_t ThreadPerLevel1Cluster =
-            MLevel0ThreadCluster * NLevel0ThreadCluster * MLevel1ThreadCluster * NLevel1ThreadCluster;
-
-        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");
-
-        static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
-                      "wrong! K dimension not consistent\n");
-
-        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
-        constexpr index_t N = BlockMatrixB{}.GetLength(I1);
-
-        static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
-                          N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
-                      "wrong! Cannot evenly divide work among\n");
-
-        static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
-                          ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
-                      "wrong! ThreadMatrixC lengths is wrong");
-
-        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-        mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
-        mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
-    }
-
-    __device__ static constexpr auto GetThreadMatrixCLengths()
-    {
-        constexpr auto I1 = Number<1>{};
-
-        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
-        constexpr index_t N = BlockMatrixB{}.GetLength(I1);
-
-        constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
-        constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);
-
-        return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
-    }
-
-    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
-    {
-        constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;
-
-        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
-        index_t level1_m_id = level1_id / NLevel1ThreadCluster;
-        index_t level1_n_id = level1_id % NLevel1ThreadCluster;
-
-        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
-        index_t level0_m_id = level0_id / NLevel0ThreadCluster;
-        index_t level0_n_id = level0_id % NLevel0ThreadCluster;
-
-        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
-        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;
-
-        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
-                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
-    }
-
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ void Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
-    {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-
-        constexpr auto a_block_mtx  = BlockMatrixA{};
-        constexpr auto b_block_mtx  = BlockMatrixB{};
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-        constexpr auto K = a_block_mtx.GetLength(I0);
-
-        constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
-        constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
-
-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
-
-        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
-        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
-
-        // thread A, B for GEMM
-        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            Number<KPerThreadLoop>{}, Number<MPerThread>{});
-
-        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            Number<KPerThreadLoop>{}, Number<NPerThread>{});
-
-        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
-
-        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
-            decltype(a_thread_mtx), KPerThreadLoop, MPerThreadSubC, ThreadGemmADataPerRead_M>{};
-
-        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
-            decltype(b_thread_mtx), KPerThreadLoop, NPerThreadSubC, ThreadGemmBDataPerRead_N>{};
-
-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
-            decltype(b_thread_mtx), decltype(c_thread_mtx)>{};
-
-#pragma unroll
-        // loop over k
-        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
-        {
-#pragma unroll
-            // read A
-            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
-            {
-                a_thread_copy.Run(
-                    p_a_block +
-                        a_block_mtx.CalculateOffset(make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
-                        mMyThreadOffsetA,
-                    p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, m_repeat * MPerThreadSubC)));
-            }
-
-#pragma unroll
-            // read B
-            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
-            {
-                b_thread_copy.Run(
-                    p_b_block +
-                        b_block_mtx.CalculateOffset(make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
-                        mMyThreadOffsetB,
-                    p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, n_repeat * NPerThreadSubC)));
-            }
-
-            // C += A * B
-            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
-        }
-    }
-
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ void Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
-    {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-
-        constexpr auto a_block_mtx  = BlockMatrixA{};
-        constexpr auto b_block_mtx  = BlockMatrixB{};
-        constexpr auto c_thread_mtx = ThreadMatrixC{};
-
-        constexpr auto K = a_block_mtx.GetLength(I0);
-
-        constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
-        constexpr auto NPerThread = c_thread_mtx.GetLength(I1);
-
-        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
-        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
-
-        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
-        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
-
-        static_assert(MRepeat == 2 && NRepeat == 2, "wrong! inline asm cannot deal with this GEMM config yet");
-
-        // thread A, B
-        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));
-
-        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));
-
-        // thread A-sub, B-sub
-        constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
-            make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
-            make_tuple(Number<MPerThread>{}, Number<1>{}));
-
-        constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
-            make_tuple(Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}),
-            make_tuple(Number<NPerThread>{}, Number<1>{}));
-
-        constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
-            make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
-            make_tuple(Number<NPerThread>{}, Number<1>{}));
-
-        FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
-        FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];
-
-        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
-            decltype(a_thread_mtx), KPerThreadLoop, MPerThreadSubC, ThreadGemmADataPerRead_M>{};
-
-        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
-            decltype(b_thread_mtx), KPerThreadLoop, NPerThreadSubC, ThreadGemmBDataPerRead_N>{};
-
-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
-            decltype(b_thread_sub_mtx), decltype(c_thread_sub_mtx)>{};
-
-        const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
-        const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
-
-        // read A_sub_0
-        a_thread_copy.Run(p_a_block_off, p_a_thread);
-
-        // read B_sub_0
-        b_thread_copy.Run(p_b_block_off, p_b_thread);
-
-        // read B_sub_1
-        b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
-                          p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
-
-        // read A_sub_1
-        a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
-                          p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
-
-        // C_sub_00 += transpose(A_sub_0) * B_sub_0
-        threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
-
-        // C_sub_01 += transpose(A_sub_0) * B_sub_1
-        threadwise_gemm.Run(p_a_thread,
-                            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
-
-#pragma unroll
-        // loop over rest of k
-        for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
-        {
-            // read A_sub_0
-            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)), p_a_thread);
-
-            // C_sub_10 += transpose(A_sub_1) * B_sub_0
-            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                                p_b_thread,
-                                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
-
-            // read B_sub_0
-            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)), p_b_thread);
-
-            // C_sub_11 += transpose(A_sub_1) * B_sub_1
-            threadwise_gemm.Run(
-                p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
-
-            // read B_sub_1
-            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
-                              p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
-
-            // read A_sub_1
-            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
-                              p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));
-
-            // C_sub_00 += transpose(A_sub_0) * B_sub_0
-            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
-
-            // C_sub_01 += transpose(A_sub_0) * B_sub_1
-            threadwise_gemm.Run(p_a_thread,
-                                p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-                                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
-        }
-
-        // C_sub_10 += transpose(A_sub_1) * B_sub_0
-        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-                            p_b_thread,
-                            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));
-
-        // C_sub_11 += transpose(A_sub_1) * B_sub_1
-        threadwise_gemm.Run(
-            p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
-            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
-            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
-    }
-
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
-    {
-#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-
-        constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
-        constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);
-
-        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
-        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
-
-        if constexpr(MRepeat == 2 && NRepeat == 2)
-        {
-            Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
-        }
-        else
-        {
-            Run_naive(p_a_block, p_b_block, p_c_thread);
-        }
-#else
-        Run_naive(p_a_block, p_b_block, p_c_thread);
-#endif
-    }
-};
-
} // namespace ck
#endif
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp   (new file, mode 100644)

#ifndef CK_BLOCKWISE_GEMM_V2_HPP
#define CK_BLOCKWISE_GEMM_V2_HPP

#include "common_header.hpp"
#include "threadwise_gemm_v2.hpp"

namespace ck {

// blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
// A and B are visable to the whole block, C is distributed among each thread
// If following number are power of 2, index calculation shall be greatly reduced:
//   MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
//   MLevel1ThreadCluster, NLevel1ThreadCluster
template <index_t BlockSize,
          typename BlockMatrixA,
          typename BlockMatrixB,
          typename ThreadMatrixC,
          index_t MPerThreadSubC,
          index_t NPerThreadSubC,
          index_t KPerThreadLoop,
          index_t MLevel0ThreadCluster,
          index_t NLevel0ThreadCluster,
          index_t MLevel1ThreadCluster,
          index_t NLevel1ThreadCluster,
          index_t ThreadGemmADataPerRead_M,
          index_t ThreadGemmBDataPerRead_N>
struct BlockwiseGemm_km_kn_m0m1n0n1_v1
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    index_t mMyThreadOffsetA;
    index_t mMyThreadOffsetB;

    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v1()
    {
        static_assert(BlockMatrixA::IsKnownAtCompileTime() && BlockMatrixB::IsKnownAtCompileTime() &&
                          ThreadMatrixC::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t ThreadPerLevel1Cluster =
            MLevel0ThreadCluster * NLevel0ThreadCluster * MLevel1ThreadCluster * NLevel1ThreadCluster;

        static_assert(BlockSize == ThreadPerLevel1Cluster, "wrong! wrong blocksize\n");

        static_assert(BlockMatrixA{}.GetLength(I0) == BlockMatrixB{}.GetLength(I0),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
        constexpr index_t N = BlockMatrixB{}.GetLength(I1);

        static_assert(M % (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster) == 0 &&
                          N % (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster) == 0,
                      "wrong! Cannot evenly divide work among\n");

        static_assert(ThreadMatrixC{}.GetLength(I0) == GetThreadMatrixCLengths()[I0] &&
                          ThreadMatrixC{}.GetLength(I1) == GetThreadMatrixCLengths()[I1],
                      "wrong! ThreadMatrixC lengths is wrong");

        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        mMyThreadOffsetA = BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.row));
        mMyThreadOffsetB = BlockMatrixB{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.col));
    }

    __device__ static constexpr auto GetThreadMatrixCLengths()
    {
        constexpr auto I1 = Number<1>{};

        constexpr index_t M = BlockMatrixA{}.GetLength(I1); // A is transposed
        constexpr index_t N = BlockMatrixB{}.GetLength(I1);

        constexpr index_t MRepeat = M / (MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster);
        constexpr index_t NRepeat = N / (NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster);

        return Sequence<MRepeat * MPerThreadSubC, NRepeat * NPerThreadSubC>{};
    }

    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t thread_id)
    {
        constexpr index_t ThreadPerLevel0Cluster = MLevel0ThreadCluster * NLevel0ThreadCluster;

        index_t level1_id   = thread_id / ThreadPerLevel0Cluster;
        index_t level1_m_id = level1_id / NLevel1ThreadCluster;
        index_t level1_n_id = level1_id % NLevel1ThreadCluster;

        index_t level0_id   = thread_id % ThreadPerLevel0Cluster;
        index_t level0_m_id = level0_id / NLevel0ThreadCluster;
        index_t level0_n_id = level0_id % NLevel0ThreadCluster;

        constexpr index_t MPerLevel0Cluster = MPerThreadSubC * MLevel0ThreadCluster;
        constexpr index_t NPerLevel0Cluster = NPerThreadSubC * NLevel0ThreadCluster;

        return MatrixIndex{level1_m_id * MPerLevel0Cluster + level0_m_id * MPerThreadSubC,
                           level1_n_id * NPerLevel0Cluster + level0_n_id * NPerThreadSubC};
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr auto K = a_block_mtx.GetLength(I0);

        constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
        constexpr auto NPerThread = c_thread_mtx.GetLength(I1);

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        // thread A, B for GEMM
        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
            Number<KPerThreadLoop>{}, Number<MPerThread>{});

        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
            Number<KPerThreadLoop>{}, Number<NPerThread>{});

        FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpace()];

        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
            decltype(a_thread_mtx), KPerThreadLoop, MPerThreadSubC, ThreadGemmADataPerRead_M>{};

        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
            decltype(b_thread_mtx), KPerThreadLoop, NPerThreadSubC, ThreadGemmBDataPerRead_N>{};

        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_mtx),
            decltype(b_thread_mtx), decltype(c_thread_mtx)>{};

#pragma unroll
        // loop over k
        for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
        {
#pragma unroll
            // read A
            for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
            {
                a_thread_copy.Run(
                    p_a_block +
                        a_block_mtx.CalculateOffset(make_tuple(k_begin, m_repeat * MPerLevel1Cluster)) +
                        mMyThreadOffsetA,
                    p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, m_repeat * MPerThreadSubC)));
            }

#pragma unroll
            // read B
            for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
            {
                b_thread_copy.Run(
                    p_b_block +
                        b_block_mtx.CalculateOffset(make_tuple(k_begin, n_repeat * NPerLevel1Cluster)) +
                        mMyThreadOffsetB,
                    p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, n_repeat * NPerThreadSubC)));
            }

            // C += A * B
            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
        }
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto a_block_mtx  = BlockMatrixA{};
        constexpr auto b_block_mtx  = BlockMatrixB{};
        constexpr auto c_thread_mtx = ThreadMatrixC{};

        constexpr auto K = a_block_mtx.GetLength(I0);

        constexpr auto MPerThread = c_thread_mtx.GetLength(I0);
        constexpr auto NPerThread = c_thread_mtx.GetLength(I1);

        constexpr index_t MPerLevel1Cluster = MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
        constexpr index_t NPerLevel1Cluster = NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        static_assert(MRepeat == 2 && NRepeat == 2, "wrong! inline asm cannot deal with this GEMM config yet");

        // thread A, B
        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<KPerThreadLoop>{}, Number<MPerThread>{}));

        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<KPerThreadLoop>{}, Number<NPerThread>{}));

        // thread A-sub, B-sub
        constexpr auto a_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
            make_tuple(Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}),
            make_tuple(Number<MPerThread>{}, Number<1>{}));

        constexpr auto b_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
            make_tuple(Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}),
            make_tuple(Number<NPerThread>{}, Number<1>{}));

        constexpr auto c_thread_sub_mtx = make_dynamic_naive_tensor_descriptor_v2(
            make_tuple(Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{}),
            make_tuple(Number<NPerThread>{}, Number<1>{}));

        FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
        FloatB p_b_thread[b_thread_mtx.GetElementSpaceSize()];

        constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixA,
            decltype(a_thread_mtx), KPerThreadLoop, MPerThreadSubC, ThreadGemmADataPerRead_M>{};

        constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy_v2<BlockMatrixB,
            decltype(b_thread_mtx), KPerThreadLoop, NPerThreadSubC, ThreadGemmBDataPerRead_N>{};

        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v1<decltype(a_thread_sub_mtx),
            decltype(b_thread_sub_mtx), decltype(c_thread_sub_mtx)>{};

        const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
        const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;

        // read A_sub_0
        a_thread_copy.Run(p_a_block_off, p_a_thread);

        // read B_sub_0
        b_thread_copy.Run(p_b_block_off, p_b_thread);

        // read B_sub_1
        b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(0, NPerLevel1Cluster)),
                          p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));

        // read A_sub_1
        a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(0, MPerLevel1Cluster)),
                          p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_gemm.Run(p_a_thread,
                            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
                            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));

#pragma unroll
        // loop over rest of k
        for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
        {
            // read A_sub_0
            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, 0)), p_a_thread);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
                                p_b_thread,
                                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));

            // read B_sub_0
            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, 0)), p_b_thread);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_gemm.Run(
                p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
                p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));

            // read B_sub_1
            b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(make_tuple(k, NPerLevel1Cluster)),
                              p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));

            // read A_sub_1
            a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(make_tuple(k, MPerLevel1Cluster)),
                              p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)));

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_gemm.Run(p_a_thread,
                                p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
                                p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)));
        }

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
                            p_b_thread,
                            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, 0)));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_gemm.Run(
            p_a_thread + a_thread_mtx.CalculateOffset(make_tuple(0, MPerThreadSubC)),
            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(0, NPerThreadSubC)),
            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(MPerThreadSubC, NPerThreadSubC)));
    }

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
    {
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr index_t MPerThread = ThreadMatrixC{}.GetLength(I0);
        constexpr index_t NPerThread = ThreadMatrixC{}.GetLength(I1);

        constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
        constexpr index_t NRepeat = NPerThread / NPerThreadSubC;

        if constexpr(MRepeat == 2 && NRepeat == 2)
        {
            Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
        }
        else
        {
            Run_naive(p_a_block, p_b_block, p_c_thread);
        }
#else
        Run_naive(p_a_block, p_b_block, p_c_thread);
#endif
    }
};

} // namespace ck
#endif
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp

@@ -2,13 +2,12 @@
#define CK_GRIDWISE_DYNAMIC_GEMM_HPP

#include "common_header.hpp"
+#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
-#include "tensor_descriptor_helper.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
-#include "ConstantMatrixDescriptor.hpp"
-#include "blockwise_gemm.hpp"
+#include "blockwise_gemm_v2.hpp"

namespace ck {

@@ -397,6 +396,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
        // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
        constexpr auto c_m0_m1_n0_n1_global_tensor_iterator_hacks = CGlobalIteratorHacks{};

+        constexpr auto tmp = make_unmerge_transform(make_tuple(
+            Number<MRepeat>{}, Number<MPerThread>{}, Number<NRepeat>{}, Number<NPerThread>{}));
+
        ThreadwiseDynamicTensorSliceTransfer_v1r3<AccFloat,
                                                  Float,
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp

@@ -44,7 +44,8 @@ template <typename SrcData,
          AddressSpace DstAddressSpace,
          InMemoryDataOperation DstInMemOp,
          index_t DstScalarStrideInVector,
-          bool DstResetCoordinateAfterRun>
+          bool DstResetCoordinateAfterRun,
+          typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v1r3
{
    static constexpr index_t nDim = SliceLengths::Size();

@@ -59,6 +60,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
                                               const DstDesc& dst_desc,
                                               const Index& dst_slice_origin_idx)
        : dst_slice_origin_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx))
    {
+        static_assert(SrcDesc::IsKnownAtCompileTime(),
+                      "wrong! SrcDesc need to known at compile-time");
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)

@@ -342,7 +345,8 @@ template <typename SrcData,
          AddressSpace SrcAddressSpace,
          AddressSpace DstAddressSpace,
          index_t SrcScalarStrideInVector,
-          bool SrcResetCoordinateAfterRun>
+          bool SrcResetCoordinateAfterRun,
+          typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v2
{
    static constexpr index_t nDim = SliceLengths::Size();

@@ -357,6 +361,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
                                             const Index& src_slice_origin_idx)
        : src_slice_origin_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx))
    {
+        static_assert(DstDesc::IsKnownAtCompileTime(),
+                      "wrong! SrcDesc need to known at compile-time");
    }

    __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)

@@ -1233,9 +1239,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
    private:
    static constexpr auto buffer_desc_ =
-        make_dynamic_naive_tensor_descriptor_packed_v2(to_multi_index(SliceLengths{}));
+        make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{}));

-    static constexpr index_t buffer_size_ = buffer_desc_.GetElementSpaceSize();
+    static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

    StaticallyIndexedArray<SrcData, buffer_size_> buffer_;
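The enable_if constraint added to the two transfer class templates above is the usual defaulted-template-parameter SFINAE idiom: the class is only instantiable when the descriptor reports itself as compile-time known. A reduced standalone sketch of that idiom (ToyDesc and Transfer are hypothetical names used only for illustration):

    #include <type_traits>

    struct ToyDesc { static constexpr bool IsKnownAtCompileTime() { return true; } };

    template <typename SrcDesc,
              typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
    struct Transfer
    {
        // Only instantiable when SrcDesc::IsKnownAtCompileTime() is true.
    };

    static_assert(sizeof(Transfer<ToyDesc>) >= 1, "illustration only");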
composable_kernel/include/tensor_operation/threadwise_gemm.hpp

@@ -161,160 +161,5 @@ struct ThreadwiseGemmTransANormalBNormalC
    }
};

-template <typename Float, class Matrix>
-__device__ void threadwise_matrix_set_zero_v2(Matrix, Float* __restrict__ p_thread)
-{
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-
-    constexpr auto M = Matrix{}.GetLength(I0);
-    constexpr auto N = Matrix{}.GetLength(I1);
-
-    static_for<0, M, 1>{}([&](auto i) {
-        static_for<0, N, 1>{}([&](auto j) {
-            constexpr auto offset = Matrix{}.CalculateOffset(make_tuple(i, j));
-
-            p_thread[offset] = Float(0);
-        });
-    });
-}
-
-template <typename SrcMatrix,
-          typename DstMatrix,
-          index_t NSliceRow,
-          index_t NSliceCol,
-          index_t DataPerAccess>
-struct ThreadwiseMatrixSliceCopy_v2
-{
-    template <typename Data>
-    __device__ static void Run(const Data* p_src, Data* p_dst)
-    {
-        using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;
-
-        static_for<0, NSliceRow, 1>{}([&](auto i) {
-            static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
-                constexpr auto src_offset = SrcMatrix{}.CalculateOffset(make_tuple(i, j));
-                constexpr auto dst_offset = DstMatrix{}.CalculateOffset(make_tuple(i, j));
-
-                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
-                    *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
-            });
-        });
-    }
-};
-
-// C += transpose(A) * B
-//   Element of matrix can be vectorized data
-template <typename MatrixA, typename MatrixB, typename MatrixC>
-struct ThreadwiseGemm_km_kn_mn_v1
-{
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
-    {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-
-        constexpr index_t M = MatrixC{}[I0];
-        constexpr index_t N = MatrixC{}[I1];
-        constexpr index_t K = MatrixA{}[I0];
-
-        static_for<0, K, 1>{}([&](auto k) {
-            static_for<0, M, 1>{}([&](auto m) {
-                static_for<0, N, 1>{}([&](auto n) {
-                    const index_t a_offset = MatrixA{}.CalculateOffset(make_tuple(k, m)); // A is transposed
-                    const index_t b_offset = MatrixB{}.CalculateOffset(make_tuple(k, n));
-                    const index_t c_offset = MatrixC{}.CalculateOffset(make_tuple(m, n));
-
-                    p_c[c_offset] +=
-                        inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
-                });
-            });
-        });
-    }
-
-#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
-    {
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-
-        constexpr index_t M = MatrixC{}.GetLength(I0);
-        constexpr index_t N = MatrixC{}.GetLength(I1);
-        constexpr index_t K = MatrixA{}.GetLength(I0);
-
-        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
-
-        static_for<0, K, 1>{}([&](auto k) {
-            static_for<0, M, 1>{}([&](auto m) {
-                constexpr auto a_offset = MatrixA{}.CalculateOffset(make_tuple(k, m));
-
-                if constexpr(N == 2)
-                {
-                    constexpr auto b_offset_0 = MatrixB{}.CalculateOffset(make_tuple(k, I0));
-                    constexpr auto b_offset_1 = MatrixB{}.CalculateOffset(make_tuple(k, I1));
-
-                    constexpr auto c_offset_0 = MatrixC{}.CalculateOffset(make_tuple(m, I0));
-                    constexpr auto c_offset_1 = MatrixC{}.CalculateOffset(make_tuple(m, I1));
-
-                    amd_assembly_outer_product_1x2(p_a[a_offset], p_b[b_offset_0], p_b[b_offset_1],
-                                                   p_c[c_offset_0], p_c[c_offset_1]);
-                }
-                else if constexpr(N == 4)
-                {
-                    constexpr auto b_offset_0 = MatrixB{}.CalculateOffset(make_tuple(k, I0));
-                    constexpr auto b_offset_1 = MatrixB{}.CalculateOffset(make_tuple(k, I1));
-                    constexpr auto b_offset_2 = MatrixB{}.CalculateOffset(make_tuple(k, I2));
-                    constexpr auto b_offset_3 = MatrixB{}.CalculateOffset(make_tuple(k, I3));
-
-                    constexpr auto c_offset_0 = MatrixC{}.CalculateOffset(make_tuple(m, I0));
-                    constexpr auto c_offset_1 = MatrixC{}.CalculateOffset(make_tuple(m, I1));
-                    constexpr auto c_offset_2 = MatrixC{}.CalculateOffset(make_tuple(m, I2));
-                    constexpr auto c_offset_3 = MatrixC{}.CalculateOffset(make_tuple(m, I3));
-
-                    amd_assembly_outer_product_1x4(p_a[a_offset],
-                                                   p_b[b_offset_0], p_b[b_offset_1],
-                                                   p_b[b_offset_2], p_b[b_offset_3],
-                                                   p_c[c_offset_0], p_c[c_offset_1],
-                                                   p_c[c_offset_2], p_c[c_offset_3]);
-                }
-            });
-        });
-    }
-#endif
-
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
-    {
-#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-        constexpr bool has_amd_asm =
-            is_same<FloatC, float>{} &&
-            ((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
-             (is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
-             (is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));
-
-        if constexpr(has_amd_asm)
-        {
-            Run_amd_asm(p_a, p_b, p_c);
-        }
-        else
-        {
-            Run_source(p_a, p_b, p_c);
-        }
-#else
-        Run_source(p_a, p_b, p_c);
-#endif
-    }
-};
-
} // namespace ck
#endif
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp   (new file, mode 100644)

#ifndef CK_THREADWISE_GEMM_V2_HPP
#define CK_THREADWISE_GEMM_V2_HPP

#include "common_header.hpp"
#include "math.hpp"

namespace ck {

template <typename Float, typename Desc>
__device__ void threadwise_matrix_set_zero_v2(Desc, Float* __restrict__ p_thread)
{
    static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    constexpr auto desc = Desc{};

    constexpr auto M = desc.GetLength(I0);
    constexpr auto N = desc.GetLength(I1);

    static_for<0, M, 1>{}([&](auto i) {
        static_for<0, N, 1>{}([&](auto j) {
            constexpr auto offset = desc.CalculateOffset(make_tuple(i, j));

            p_thread[offset] = Float(0);
        });
    });
}

template <typename SrcDesc,
          typename DstDesc,
          index_t NSliceRow,
          index_t NSliceCol,
          index_t DataPerAccess>
struct ThreadwiseMatrixSliceCopy_v2
{
    template <typename Data>
    __device__ static void Run(const Data* p_src, Data* p_dst)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;

        static_for<0, NSliceRow, 1>{}([&](auto i) {
            static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
                constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
                constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));

                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
                    *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
            });
        });
    }
};

// C[M, N] += transpose(A[K, M]) * B[K, N]
//   Element of matrix can be vectorized data
template <typename ADesc,
          typename BDesc,
          typename CDesc,
          typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                                      CDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v1
{
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                          CDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};

        constexpr auto M = CDesc{}[I0];
        constexpr auto N = CDesc{}[I1];
        constexpr auto K = ADesc{}[I0];

        static_for<0, K, 1>{}([&](auto k) {
            static_for<0, M, 1>{}([&](auto m) {
                static_for<0, N, 1>{}([&](auto n) {
                    constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
                    constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(k, n));
                    constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(m, n));

                    p_c[c_offset] +=
                        inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
                });
            });
        });
    }

#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                          CDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto M = CDesc{}.GetLength(I0);
        constexpr auto N = CDesc{}.GetLength(I1);
        constexpr auto K = ADesc{}.GetLength(I0);

        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");

        static_for<0, K, 1>{}([&](auto k) {
            static_for<0, M, 1>{}([&](auto m) {
                constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));

                if constexpr(N == 2)
                {
                    constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
                    constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));

                    constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
                    constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));

                    amd_assembly_outer_product_1x2(p_a[a_offset], p_b[b_offset_0], p_b[b_offset_1],
                                                   p_c[c_offset_0], p_c[c_offset_1]);
                }
                else if constexpr(N == 4)
                {
                    constexpr auto b_offset_0 = BDesc{}.CalculateOffset(make_tuple(k, I0));
                    constexpr auto b_offset_1 = BDesc{}.CalculateOffset(make_tuple(k, I1));
                    constexpr auto b_offset_2 = BDesc{}.CalculateOffset(make_tuple(k, I2));
                    constexpr auto b_offset_3 = BDesc{}.CalculateOffset(make_tuple(k, I3));

                    constexpr auto c_offset_0 = CDesc{}.CalculateOffset(make_tuple(m, I0));
                    constexpr auto c_offset_1 = CDesc{}.CalculateOffset(make_tuple(m, I1));
                    constexpr auto c_offset_2 = CDesc{}.CalculateOffset(make_tuple(m, I2));
                    constexpr auto c_offset_3 = CDesc{}.CalculateOffset(make_tuple(m, I3));

                    amd_assembly_outer_product_1x4(p_a[a_offset],
                                                   p_b[b_offset_0], p_b[b_offset_1],
                                                   p_b[b_offset_2], p_b[b_offset_3],
                                                   p_c[c_offset_0], p_c[c_offset_1],
                                                   p_c[c_offset_2], p_c[c_offset_3]);
                }
            });
        });
    }
#endif

    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
        constexpr bool has_amd_asm =
            is_same<FloatC, float>{} &&
            ((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
             (is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
             (is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));

        if constexpr(has_amd_asm)
        {
            Run_amd_asm(p_a, p_b, p_c);
        }
        else
        {
            Run_source(p_a, p_b, p_c);
        }
#else
        Run_source(p_a, p_b, p_c);
#endif
    }
};

} // namespace ck
#endif
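Run_source above accumulates through inner_product_with_conversion<FloatC>{}(a, b), which, judging from its name and from the comment that matrix elements may be vectorized data, presumably converts the operands to the accumulator type, multiplies, and reduces across vector lanes. A scalar-only standalone analogue, just to make the accumulation explicit (illustrative, not CK's implementation):

    // Scalar-only analogue of the accumulation used in Run_source (assumption, for illustration).
    template <typename AccT>
    struct inner_product_with_conversion_like
    {
        template <typename TA, typename TB>
        AccT operator()(TA a, TB b) const
        {
            // Convert both operands to the accumulator type, then multiply.
            return static_cast<AccT>(a) * static_cast<AccT>(b);
        }
    };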
composable_kernel/include/utility/container_helper.hpp

@@ -279,5 +279,18 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
    static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}

+template <index_t... Is>
+__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
+{
+    using Seq = Sequence<Is...>;
+
+    return generate_tuple(
+        [&](auto i) {
+            constexpr index_t tmp = Seq::At(i);
+            return Number<tmp>{};
+        },
+        Seq::Size());
+}
+
} // namespace ck
#endif
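sequence_to_tuple_of_number converts a Sequence<Is...> of integers into a Tuple of Number<> constants, which is what make_dynamic_naive_tensor_descriptor_packed_v2 now receives in ThreadwiseDynamicTensorSliceTransfer_v3. A rough standalone analogue using only the standard library (the local Sequence template below is a stand-in, not CK's type):

    #include <tuple>
    #include <type_traits>

    template <int... Is>
    struct Sequence {};

    // Sequence<2, 3, 4>  ->  tuple<integral_constant<int, 2>, integral_constant<int, 3>, ...>
    template <int... Is>
    constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
    {
        return std::make_tuple(std::integral_constant<int, Is>{}...);
    }

    static_assert(
        std::tuple_size<decltype(sequence_to_tuple_of_number(Sequence<2, 3, 4>{}))>::value == 3, "");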
composable_kernel/include/utility/tuple_helper.hpp

@@ -6,6 +6,24 @@
namespace ck {

+template <typename... Ts>
+struct is_known_at_compile_time<Tuple<Ts...>>
+{
+    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
+    {
+        return container_reduce(
+            Tuple<Ts...>{},
+            [](auto x, bool r) {
+                return is_known_at_compile_time<remove_cv_t<remove_reference_t<decltype(x)>>>::value & r;
+            },
+            true);
+    }
+
+    static constexpr bool value = IsKnownAtCompileTime();
+};
+
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
{
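Together with the type.hpp specializations below, this makes a Tuple compile-time known exactly when every element is (Number<> lengths qualify, plain index_t lengths do not). A standalone equivalent over std::tuple using std::conjunction, for illustration only:

    #include <tuple>
    #include <type_traits>

    template <typename T>
    struct is_known_at_compile_time : std::false_type {};

    template <typename T, T X>
    struct is_known_at_compile_time<std::integral_constant<T, X>> : std::true_type {};

    // A tuple is compile-time known only if every element is.
    template <typename... Ts>
    struct is_known_at_compile_time<std::tuple<Ts...>>
        : std::conjunction<is_known_at_compile_time<Ts>...> {};

    static_assert(is_known_at_compile_time<std::tuple<std::integral_constant<int, 4>>>::value, "");
    static_assert(!is_known_at_compile_time<std::tuple<int>>::value, "");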
composable_kernel/include/utility/type.hpp

@@ -27,5 +27,20 @@ constexpr std::remove_reference_t<T>&& move(T&& t) noexcept
    return static_cast<typename std::remove_reference<T>::type&&>(t);
}

+template <typename T>
+struct is_known_at_compile_time;
+
+template <>
+struct is_known_at_compile_time<index_t>
+{
+    static constexpr bool value = false;
+};
+
+template <typename T, T X>
+struct is_known_at_compile_time<integral_constant<T, X>>
+{
+    static constexpr bool value = true;
+};
+
} // namespace ck
#endif
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp

@@ -3,19 +3,6 @@
#include "host_tensor.hpp"
#include "driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"

-template <typename T>
-__host__ __device__ constexpr auto sequence_to_tuple_of_number(const T& x)
-{
-    using namespace ck;
-
-    return generate_tuple(
-        [&](auto i) {
-            constexpr index_t tmp = T::At(i);
-            return Number<tmp>{};
-        },
-        T::Size());
-}
-
template <class T,
          class InDesc,
          class WeiDesc,

@@ -269,7 +256,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
    constexpr auto conv_driver =
#if 1
        DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_pad
-#elif 0
+#elif 1
        DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_no_pad
#elif 1
        DriverDynamicConvolutionForwardImplicitGemm_v4r4_nchw_kcyx_nkhw_1x1