Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dfbe7e20
Commit
dfbe7e20
authored
May 17, 2021
by
Jing Zhang
Browse files
added tuning params
parent
b3a4d179
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
66 additions
and
101 deletions
+66
-101
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+0
-5
composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp
...e_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp
+10
-18
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+8
-8
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
...include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
+34
-45
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
...tion_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+9
-20
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+5
-5
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
dfbe7e20
...
@@ -13,8 +13,6 @@ namespace ck {
...
@@ -13,8 +13,6 @@ namespace ck {
// GemmK = C * Y * X
// GemmK = C * Y * X
template
<
index_t
GemmMPerBlock
,
template
<
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmM1
,
index_t
GemmN1
,
typename
...
Wei
,
typename
...
Wei
,
typename
...
In
,
typename
...
In
,
typename
...
Out
,
typename
...
Out
,
...
@@ -108,9 +106,6 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
...
@@ -108,9 +106,6 @@ transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad(
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
);
const
auto
GemmM0
=
GemmM
/
Number
<
GemmM1
>
{};
const
auto
GemmN0
=
GemmN
/
Number
<
GemmN1
>
{};
const
auto
out_m0_m1_m2_n_global_desc
=
transform_dynamic_tensor_descriptor
(
const
auto
out_m0_m1_m2_n_global_desc
=
transform_dynamic_tensor_descriptor
(
out_gemmm_gemmn_global_desc
,
out_gemmm_gemmn_global_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM
/
8
,
2
,
4
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
GemmM
/
8
,
2
,
4
)),
...
...
composable_kernel/include/driver/driver_dynamic_gemm_xdlops_v1.hpp
View file @
dfbe7e20
...
@@ -21,13 +21,9 @@ template <index_t BlockSize,
...
@@ -21,13 +21,9 @@ template <index_t BlockSize,
index_t
MPerBlock
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
MPerThread
,
index_t
MPerWave
,
index_t
NPerThread
,
index_t
NPerWave
,
index_t
KPerThread
,
index_t
KPerWave
,
index_t
MLevel0Cluster
,
index_t
NLevel0Cluster
,
index_t
MLevel1Cluster
,
index_t
NLevel1Cluster
,
typename
ABlockTransferThreadSliceLengths_K_M
,
typename
ABlockTransferThreadSliceLengths_K_M
,
typename
ABlockTransferThreadClusterLengths_K_M
,
typename
ABlockTransferThreadClusterLengths_K_M
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferThreadClusterArrangeOrder
,
...
@@ -81,10 +77,7 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
...
@@ -81,10 +77,7 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
}
constexpr
auto
M1
=
Number
<
MPerThread
*
MLevel0Cluster
*
MLevel1Cluster
>
{};
if
(
!
(
MPerBlock
%
MPerWave
==
0
&&
NPerBlock
%
NPerWave
==
0
))
constexpr
auto
N1
=
Number
<
NPerThread
*
NLevel0Cluster
*
NLevel1Cluster
>
{};
if
(
!
(
MPerBlock
%
M1
==
0
&&
NPerBlock
%
N1
==
0
))
{
{
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
throw
std
::
runtime_error
(
"wrong! GEMM size no divisible"
);
}
}
...
@@ -103,13 +96,9 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
...
@@ -103,13 +96,9 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
MPerThread
,
MPerWave
,
NPerThread
,
NPerWave
,
KPerThread
,
KPerWave
,
MLevel0Cluster
,
NLevel0Cluster
,
MLevel1Cluster
,
NLevel1Cluster
,
ABlockTransferThreadSliceLengths_K_M
,
ABlockTransferThreadSliceLengths_K_M
,
ABlockTransferThreadClusterLengths_K_M
,
ABlockTransferThreadClusterLengths_K_M
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
...
@@ -141,6 +130,9 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
...
@@ -141,6 +130,9 @@ __host__ float launch_kernel_dynamic_gemm_xdlops_v1(const FloatAB* p_a_global,
const
bool
has_double_tail_k_block_loop
=
(
K
/
KPerBlock
)
%
2
==
0
;
const
bool
has_double_tail_k_block_loop
=
(
K
/
KPerBlock
)
%
2
==
0
;
std
::
cerr
<<
"has_main_k_block_loop = "
<<
has_main_k_block_loop
<<
" has_double_tail_k_block_loop = "
<<
has_double_tail_k_block_loop
<<
std
::
endl
;
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
float
ave_time
=
0
;
float
ave_time
=
0
;
...
...
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
View file @
dfbe7e20
...
@@ -15,9 +15,7 @@ template <index_t BlockSize,
...
@@ -15,9 +15,7 @@ template <index_t BlockSize,
class
BBlockDesc
,
class
BBlockDesc
,
index_t
MPerWave
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
NPerWave
,
index_t
KPerWave
,
index_t
KPerWave
>
index_t
MWaves
,
index_t
NWaves
>
struct
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
struct
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
{
{
...
@@ -32,6 +30,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -32,6 +30,11 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static
constexpr
index_t
WaveSize
=
64
;
static
constexpr
index_t
WaveSize
=
64
;
static
constexpr
index_t
MPerBlock
=
ABlockDesc
{}.
GetLength
(
I1
);
// A is transposed
static
constexpr
index_t
NPerBlock
=
BBlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
MWaves
=
MPerBlock
/
MPerWave
;
static
constexpr
index_t
NWaves
=
NPerBlock
/
NPerWave
;
__device__
constexpr
auto
GetOutputLayout
()
const
{
return
XdlopsGemm
.
GetOutputLayout
();
}
__device__
constexpr
auto
GetOutputLayout
()
const
{
return
XdlopsGemm
.
GetOutputLayout
();
}
__device__
constexpr
auto
GetNumBlks
()
const
__device__
constexpr
auto
GetNumBlks
()
const
...
@@ -90,11 +93,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
...
@@ -90,11 +93,8 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
static_assert
(
ABlockDesc
{}.
GetLength
(
I0
)
==
BBlockDesc
{}.
GetLength
(
I0
),
static_assert
(
ABlockDesc
{}.
GetLength
(
I0
)
==
BBlockDesc
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
"wrong! K dimension not consistent"
);
constexpr
index_t
M
=
ABlockDesc
{}.
GetLength
(
I1
);
// A is transposed
static_assert
(
MPerWave
*
MWaves
==
MPerBlock
,
"GemmMWaves * MPerWave != M"
);
constexpr
index_t
N
=
BBlockDesc
{}.
GetLength
(
I1
);
static_assert
(
NPerWave
*
NWaves
==
NPerBlock
,
"GemmNWaves * NPerWave != N"
);
static_assert
(
MPerWave
*
MWaves
==
M
,
"GemmMWaves * MPerWave != M"
);
static_assert
(
NPerWave
*
NWaves
==
N
,
"GemmNWaves * NPerWave != N"
);
static_assert
(
BlockSize
==
MWaves
*
NWaves
*
WaveSize
,
static_assert
(
BlockSize
==
MWaves
*
NWaves
*
WaveSize
,
"BlockSize != MWaves * NWaves * WaveSize
\n
"
);
"BlockSize != MWaves * NWaves * WaveSize
\n
"
);
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_xdlops.hpp
View file @
dfbe7e20
...
@@ -108,13 +108,9 @@ template <index_t BlockSize,
...
@@ -108,13 +108,9 @@ template <index_t BlockSize,
index_t
MPerBlock
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
MPerThread
,
index_t
MPerWave
,
index_t
NPerThread
,
index_t
NPerWave
,
index_t
KPerThread
,
index_t
KPerWave
,
index_t
MLevel0Cluster
,
index_t
NLevel0Cluster
,
index_t
MLevel1Cluster
,
index_t
NLevel1Cluster
,
typename
ABlockTransferThreadSliceLengths_K_M
,
typename
ABlockTransferThreadSliceLengths_K_M
,
typename
ABlockTransferThreadClusterLengths_K_M
,
typename
ABlockTransferThreadClusterLengths_K_M
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferThreadClusterArrangeOrder
,
...
@@ -144,9 +140,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -144,9 +140,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
{
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
>
{},
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
>
{},
Number
<
BBlockTransferDstScalarPerVector_N
>
{},
Number
<
BBlockTransferDstScalarPerVector_N
>
{});
Number
<
MPerThread
>
{},
Number
<
NPerThread
>
{});
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
...
@@ -209,9 +203,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -209,9 +203,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// lds max alignment
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
>
{},
constexpr
auto
max_lds_align
=
math
::
lcm
(
Number
<
ABlockTransferDstScalarPerVector_M
>
{},
Number
<
BBlockTransferDstScalarPerVector_N
>
{}
,
Number
<
BBlockTransferDstScalarPerVector_N
>
{}
);
Number
<
MPerThread
>
{},
//
Number<MPerThread>{},
Number
<
NPerThread
>
{});
//
Number<NPerThread>{});
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
...
@@ -284,30 +278,28 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -284,30 +278,28 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// register
// sanity check
// sanity check
static_assert
(
MPerBlock
%
(
MPerThread
*
MLevel0Cluster
*
MLevel1Cluster
)
==
0
&&
static_assert
(
MPerBlock
%
MPerWave
==
0
&&
NPerBlock
%
NPerWave
==
0
,
"wrong!"
);
NPerBlock
%
(
NPerThread
*
NLevel0Cluster
*
NLevel1Cluster
)
==
0
,
"wrong!"
);
// constexpr index_t MRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
// constexpr index_t NRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);
constexpr
index_t
MRepeat
=
MPerBlock
/
(
MPerThread
*
MLevel0Cluster
*
MLevel1Cluster
);
constexpr
index_t
NRepeat
=
NPerBlock
/
(
NPerThread
*
NLevel0Cluster
*
NLevel1Cluster
);
// constexpr auto a_k_m0_m1_block_desc = transform_dynamic_tensor_descriptor(
// a_k_m_block_desc,
constexpr
auto
a_k_m0_m1_block_desc
=
transform_dynamic_tensor_descriptor
(
// make_tuple(
a_k_m_block_desc
,
// make_pass_through_transform(Number<KPerBlock>{}),
make_tuple
(
// make_unmerge_transform(make_tuple(
make_pass_through_transform
(
Number
<
KPerBlock
>
{}),
// Number<MRepeat>{}, Number<MPerThread * MLevel0Cluster * MLevel1Cluster>{}))),
make_unmerge_transform
(
make_tuple
(
// make_tuple(Sequence<0>{}, Sequence<1>{}),
Number
<
MRepeat
>
{},
Number
<
MPerThread
*
MLevel0Cluster
*
MLevel1Cluster
>
{}))),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{}));
// constexpr auto b_k_n0_n1_block_desc = transform_dynamic_tensor_descriptor(
// b_k_n_block_desc,
constexpr
auto
b_k_n0_n1_block_desc
=
transform_dynamic_tensor_descriptor
(
// make_tuple(
b_k_n_block_desc
,
// make_pass_through_transform(Number<KPerBlock>{}),
make_tuple
(
// make_unmerge_transform(make_tuple(
make_pass_through_transform
(
Number
<
KPerBlock
>
{}),
// Number<NRepeat>{}, Number<NPerThread * NLevel0Cluster * NLevel1Cluster>{}))),
make_unmerge_transform
(
make_tuple
(
// make_tuple(Sequence<0>{}, Sequence<1>{}),
Number
<
NRepeat
>
{},
Number
<
NPerThread
*
NLevel0Cluster
*
NLevel1Cluster
>
{}))),
// make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{}));
// constexpr auto c_m0_m1_n0_n1_thread_desc =
// constexpr auto c_m0_m1_n0_n1_thread_desc =
// make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
// make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
...
@@ -318,12 +310,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -318,12 +310,9 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
FloatAB
,
FloatAB
,
decltype
(
a_k_m_block_desc
),
decltype
(
a_k_m_block_desc
),
decltype
(
b_k_n_block_desc
),
decltype
(
b_k_n_block_desc
),
64
,
// MPerWave,
MPerWave
,
64
,
// NPerWave,
NPerWave
,
1
,
// KPerWave,
KPerWave
>
{};
1
,
// MWaves,
1
// NWaves,
>
{};
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_space_size
=
...
@@ -481,7 +470,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -481,7 +470,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
constexpr
index_t
K1
=
OutputLayout
.
N1
();
constexpr
index_t
K1
=
OutputLayout
.
N1
();
constexpr
index_t
K2
=
OutputLayout
.
M0
();
constexpr
index_t
K2
=
OutputLayout
.
M0
();
static_assert
(
K0
==
4
&&
K1
==
2
&&
K2
==
4
,
""
);
//
static_assert(K0 == 4 && K1 == 2 && K2 == 4, "");
constexpr
auto
c_m0_m1_m2_n_thread_desc
=
constexpr
auto
c_m0_m1_m2_n_thread_desc
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_dynamic_naive_tensor_descriptor_packed_v2
(
...
@@ -490,7 +479,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
...
@@ -490,7 +479,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v1
constexpr
index_t
BlkSize
=
OutputLayout
.
GetBlkSize
();
constexpr
index_t
BlkSize
=
OutputLayout
.
GetBlkSize
();
constexpr
index_t
NumBlks
=
OutputLayout
.
GetNumBlks
();
constexpr
index_t
NumBlks
=
OutputLayout
.
GetNumBlks
();
static_assert
(
BlkSize
==
16
&&
NumBlks
==
4
,
""
);
//
static_assert(BlkSize == 16 && NumBlks == 4, "");
// force unrolling the output loop to get ride of scratches
// force unrolling the output loop to get ride of scratches
static_for
<
0
,
NumBlks
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumBlks
,
1
>
{}([
&
](
auto
i
)
{
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
View file @
dfbe7e20
...
@@ -84,14 +84,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -84,14 +84,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmNPerBlock
=
64
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmKPerWave
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
2
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
32
>
;
using
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
32
>
;
...
@@ -107,14 +102,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -107,14 +102,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
constexpr
index_t
GemmCThreadTransferDstScalarPerVector_GemmN1
=
1
;
constexpr
index_t
GemmM1
=
GemmMPer
Thread
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
GemmM1
=
GemmMPer
Wave
;
constexpr
index_t
GemmN1
=
GemmNPer
Thread
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
constexpr
index_t
GemmN1
=
GemmNPer
Wave
;
const
auto
descs
=
const
auto
descs
=
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
<
GemmMPerBlock
,
transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad
<
GemmMPerBlock
,
GemmNPerBlock
,
GemmNPerBlock
>
(
GemmM1
,
GemmN1
>
(
wei_k_c_y_x_desc
,
wei_k_c_y_x_desc
,
in_n_c_hi_wi_desc
,
in_n_c_hi_wi_desc
,
out_n_k_ho_wo_desc
,
out_n_k_ho_wo_desc
,
...
@@ -138,13 +131,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
...
@@ -138,13 +131,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
GemmMPerBlock
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmKPerBlock
,
GemmMPerThread
,
GemmMPerWave
,
GemmNPerThread
,
GemmNPerWave
,
GemmKPerThread
,
GemmKPerWave
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadSliceLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
GemmABlockTransferThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
...
...
driver/src/conv_driver.cpp
View file @
dfbe7e20
...
@@ -25,11 +25,11 @@ int main(int argc, char* argv[])
...
@@ -25,11 +25,11 @@ int main(int argc, char* argv[])
using
namespace
ck
;
using
namespace
ck
;
#if 1
#if 1
constexpr
index_t
N
=
4
;
constexpr
index_t
N
=
256
;
constexpr
index_t
C
=
1
6
;
constexpr
index_t
C
=
25
6
;
constexpr
index_t
HI
=
4
;
constexpr
index_t
HI
=
16
;
constexpr
index_t
WI
=
4
;
constexpr
index_t
WI
=
16
;
constexpr
index_t
K
=
6
4
;
constexpr
index_t
K
=
25
6
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
X
=
1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment