gaoqiong / composable_kernel — commit 8d15144c

refactor

Authored Apr 20, 2020 by Chao Liu
Parent: 7fde99f4
Showing 12 changed files with 221 additions and 749 deletions.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp  +15 -17
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp  +1 -1
composable_kernel/include/tensor_operation/blockwise_gemm.hpp  +1 -1
composable_kernel/include/tensor_operation/gridwise_gemm.hpp  +1 -2
composable_kernel/include/tensor_operation/xdlops_gemm.hpp  +14 -14
composable_kernel/include/utility/float_type.amd.hpp.in  +18 -18
composable_kernel/include/utility/float_type.nvidia.hpp.in  +15 -17
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp  +58 -58
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp  +0 -305 (deleted)
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp  +84 -84
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp  +0 -225
driver/src/conv_driver.cpp  +14 -7
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp (+15 -17)

@@ -71,13 +71,13 @@ template <index_t GridSize,
           index_t KPerBlock,
           index_t EPerBlock,
           index_t GemmNRepeat,
-          index_t GemmMPerThreadSubC,
-          index_t GemmNPerThreadSubC,
+          index_t GemmMPerThread,
+          index_t GemmNPerThread,
+          index_t GemmKPerThread,
           index_t GemmMLevel0Cluster,
           index_t GemmNLevel0Cluster,
           index_t GemmMLevel1Cluster,
           index_t GemmNLevel1Cluster,
-          index_t GemmKPerThreadLoop,
           index_t GemmDataPerReadA,
           index_t GemmDataPerReadB,
           typename InBlockCopySubLengths_E_N1_B_N2,
@@ -114,11 +114,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // this is a mess
         // TODO: find more elegent way of specifying (or calculating) performance parameters
         constexpr index_t N1 = GemmNRepeat;
-        constexpr index_t N2 = GemmNPerThreadSubC;
+        constexpr index_t N2 = GemmNPerThread;

-        static_assert((N1 * N2 * BPerBlock) % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0, "wrong!");
+        static_assert((N1 * N2 * BPerBlock) % (GemmNPerThread * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0, "wrong!");

         constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
@@ -290,30 +289,29 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
             in_e_n1_b_n2_block_desc.GetStride(I0));

         // sanity check
-        static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0, "wrong!");
+        static_assert(KPerBlock % (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0, "wrong!");

-        constexpr index_t GemmMRepeat = KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
+        constexpr index_t GemmMRepeat = KPerBlock / (GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster);

         // c_thread_mtx definition: this is a mess
         // TODO:: more elegent way of defining c_thread_mtx
         constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
-            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});
+            Number<GemmMRepeat * GemmMPerThread>{}, Number<GemmNRepeat * GemmNPerThread>{});

         const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
             BlockSize,
             decltype(a_e_k_block_mtx_desc),
             decltype(b_e_n1bn2_block_mtx_desc),
             decltype(c_k0k1_n1n2_thread_mtx_desc),
-            GemmMPerThreadSubC,
-            GemmNPerThreadSubC,
+            GemmMPerThread,
+            GemmNPerThread,
+            GemmKPerThread,
             GemmMLevel0Cluster,
             GemmNLevel0Cluster,
             GemmMLevel1Cluster,
             GemmNLevel1Cluster,
-            GemmKPerThreadLoop,
             GemmDataPerReadA,
             GemmDataPerReadB>{};
@@ -432,13 +430,13 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
         // copy output: register to global memory
         {
-            constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
+            constexpr index_t K1 = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster;
             constexpr index_t K0 = K / K1;

             // define output tensor descriptor for threadwise copy
             // thread output tensor, src of threadwise copy
             constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
-                Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
+                Sequence<GemmMRepeat, GemmMPerThread, N1, 1, N2>{});

             // global output tensor
             constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
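The sanity check and GemmMRepeat computation above are the heart of the rename: the per-thread tile (GemmMPerThread) times the two thread-cluster sizes must divide the block tile, and the quotient is how many tile slabs each thread iterates over. A minimal standalone sketch of that arithmetic, using example values taken from one of the driver configurations later in this commit (the constants here are illustrative, not part of the header):

    // Sketch of the constraint behind the renamed static_assert:
    // KPerBlock must be a whole number of (per-thread tile x thread-cluster) slabs.
    constexpr int KPerBlock          = 128; // example value from a driver configuration
    constexpr int GemmMPerThread     = 4;
    constexpr int GemmMLevel0Cluster = 4;
    constexpr int GemmMLevel1Cluster = 4;

    constexpr int MPerSlab = GemmMPerThread * GemmMLevel0Cluster * GemmMLevel1Cluster; // 64
    static_assert(KPerBlock % MPerSlab == 0, "wrong!");  // mirrors the kernel's sanity check
    constexpr int GemmMRepeat = KPerBlock / MPerSlab;    // 2 slabs per thread

    int main() { return GemmMRepeat == 2 ? 0 : 1; }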
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp (+1 -1)

@@ -159,7 +159,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
                                             1,
                                             GemmBBlockCopySrcDataPerRead_GemmN,
                                             GemmBBlockCopyDstDataPerWrite_GemmN,
-                                            Sequence<0, 1, 2, 3>,
+                                            Sequence<2, 3, 0, 1>,
                                             3,
                                             GemmCThreadCopyDstDataPerWrite_GemmN1>{};
composable_kernel/include/tensor_operation/blockwise_gemm.hpp (+1 -1)

@@ -18,11 +18,11 @@ template <index_t BlockSize,
           typename ThreadMatrixC,
           index_t MPerThreadSubC,
           index_t NPerThreadSubC,
+          index_t KPerThreadLoop,
           index_t MLevel0ThreadCluster,
           index_t NLevel0ThreadCluster,
           index_t MLevel1ThreadCluster,
           index_t NLevel1ThreadCluster,
-          index_t KPerThreadLoop,
           index_t ThreadGemmADataPerRead_M,
           index_t ThreadGemmBDataPerRead_N>
 struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
composable_kernel/include/tensor_operation/gridwise_gemm.hpp (+1 -2)

@@ -186,7 +186,6 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
                       "wrong!");

         constexpr index_t GemmMRepeat = MPerBlock / (MPerThread * MLevel0Cluster * MLevel1Cluster);
-
         constexpr index_t GemmNRepeat = NPerBlock / (NPerThread * NLevel0Cluster * NLevel1Cluster);

         // c_thread_mtx definition: this is a mess
@@ -201,11 +200,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
             decltype(c_m0m1_n0n1_thread_mtx_desc),
             MPerThread,
             NPerThread,
+            KPerThread,
             MLevel0Cluster,
             NLevel0Cluster,
             MLevel1Cluster,
             NLevel1Cluster,
-            KPerThread,
             ThreadGemmAThreadCopySrcDataPerRead_M,
             ThreadGemmBThreadCopySrcDataPerRead_N>{};
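Because KPerThreadLoop/KPerThread is a positional template parameter, moving it in blockwise_gemm.hpp only works if every instantiation (the GEMM above and the convolution kernels earlier in this commit) is updated in the same change; otherwise arguments silently bind to the wrong parameters without any compile error. A reduced, hypothetical sketch of that pitfall (names, values, and the struct itself are illustrative, not the repository's types):

    #include <cstddef>
    using index_t = std::size_t;

    // Hypothetical reduced version of the parameter shuffle in this commit:
    // the new list takes K right after the per-thread tile sizes.
    template <index_t MPerThread, index_t NPerThread, index_t KPerThread,
              index_t MLevel0Cluster, index_t NLevel0Cluster>
    struct BlockwiseGemmNew
    {
        static constexpr index_t thread_tile = MPerThread * NPerThread * KPerThread;
    };

    // An instantiation still written against the old order (K after the clusters)
    // compiles, but the values bind to different parameters -- which is why the commit
    // updates the declaration and every call site together.
    static_assert(BlockwiseGemmNew<4, 4, 1, 2, 2>::thread_tile == 16, "new argument order");
    static_assert(BlockwiseGemmNew<4, 4, 2, 2, 1>::thread_tile == 32, "old-order arguments bind differently");

    int main() { return 0; }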
composable_kernel/include/tensor_operation/xdlops_gemm.hpp (+14 -14)

@@ -207,7 +207,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x4f16>
     template <index_t MPerWave, index_t NPerWave>
     __device__ void
-    run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const
+    run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
     {
         static_assert((MPerWave == 64 && NPerWave == 64) || (MPerWave == 32 && NPerWave == 64) ||
                           (MPerWave == 64 && NPerWave == 32),
@@ -239,7 +239,7 @@ struct mfma_info<mfma_instr::mfma_f32_32x32x8f16>
     template <index_t MPerWave, index_t NPerWave>
     __device__ void
-    run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const
+    run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
     {
         static_assert((MPerWave == 32 && NPerWave == 32), "unsupported xdlops gemm");
@@ -269,7 +269,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x16f16>
     template <index_t MPerWave, index_t NPerWave>
     __device__ void
-    run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const
+    run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
     {
         static_assert((MPerWave == 16 && NPerWave == 16), "unsupported xdlops gemm");
@@ -299,7 +299,7 @@ struct mfma_info<mfma_instr::mfma_f32_16x16x4f16>
     template <index_t MPerWave, index_t NPerWave>
     __device__ void
-    run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const
+    run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
     {
         static_assert((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16),
@@ -330,7 +330,7 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x4f16>
     template <index_t MPerWave, index_t NPerWave>
     __device__ void
-    run(Number<MPerWave>, Number<NPerWave>, const half* a, const half* b, float* reg_c) const
+    run(Number<MPerWave>, Number<NPerWave>, const half_t* a, const half_t* b, float* reg_c) const
     {
         static_assert((MPerWave == 4 || MPerWave == 8) && NPerWave == 64,
                       "unsupported xdlops gemm");
@@ -555,55 +555,55 @@ __device__ constexpr auto GetMFMAInfo<float, 4, 64>()
 }

 template <>
-__device__ constexpr auto GetMFMAInfo<half, 64, 64>()
+__device__ constexpr auto GetMFMAInfo<half_t, 64, 64>()
 {
     return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
 }

 template <>
-__device__ constexpr auto GetMFMAInfo<half, 64, 32>()
+__device__ constexpr auto GetMFMAInfo<half_t, 64, 32>()
 {
     return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
 }

 template <>
-__device__ constexpr auto GetMFMAInfo<half, 32, 64>()
+__device__ constexpr auto GetMFMAInfo<half_t, 32, 64>()
 {
     return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
 }

 template <>
-__device__ constexpr auto GetMFMAInfo<half, 32, 32>()
+__device__ constexpr auto GetMFMAInfo<half_t, 32, 32>()
 {
     return mfma_info<mfma_instr::mfma_f32_32x32x8f16>{};
 }

 template <>
-__device__ constexpr auto GetMFMAInfo<half, 16, 16>()
+__device__ constexpr auto GetMFMAInfo<half_t, 16, 16>()
 {
     return mfma_info<mfma_instr::mfma_f32_16x16x16f16>{};
 }

 template <>
-__device__ constexpr auto GetMFMAInfo<half, 16, 64>()
+__device__ constexpr auto GetMFMAInfo<half_t, 16, 64>()
 {
     return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
 }

 template <>
-__device__ constexpr auto GetMFMAInfo<half, 64, 16>()
+__device__ constexpr auto GetMFMAInfo<half_t, 64, 16>()
 {
     return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
 }

 template <>
-__device__ constexpr auto GetMFMAInfo<half, 4, 64>()
+__device__ constexpr auto GetMFMAInfo<half_t, 4, 64>()
 {
     return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
 }

 template <>
-__device__ constexpr auto GetMFMAInfo<half, 8, 64>()
+__device__ constexpr auto GetMFMAInfo<half_t, 8, 64>()
 {
     return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
 }
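The GetMFMAInfo<half_t, MPerWave, NPerWave> specializations above form a fixed lookup from wave-tile shape to MFMA instruction, consistent with the static_asserts in the run() overloads. A self-contained sketch of that mapping as an ordinary constexpr function; the enum and function names are illustrative stand-ins, only the shape-to-instruction pairs come from the diff:

    // Illustrative dispatch table mirroring the GetMFMAInfo<half_t, M, N> specializations.
    enum class mfma_instr_sketch
    {
        mfma_f32_32x32x4f16,
        mfma_f32_32x32x8f16,
        mfma_f32_16x16x16f16,
        mfma_f32_16x16x4f16,
        mfma_f32_4x4x4f16,
        unsupported
    };

    constexpr mfma_instr_sketch pick_fp16_mfma(int MPerWave, int NPerWave)
    {
        if((MPerWave == 64 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 32) ||
           (MPerWave == 32 && NPerWave == 64))
            return mfma_instr_sketch::mfma_f32_32x32x4f16;
        if(MPerWave == 32 && NPerWave == 32)
            return mfma_instr_sketch::mfma_f32_32x32x8f16;
        if(MPerWave == 16 && NPerWave == 16)
            return mfma_instr_sketch::mfma_f32_16x16x16f16;
        if((MPerWave == 16 && NPerWave == 64) || (MPerWave == 64 && NPerWave == 16))
            return mfma_instr_sketch::mfma_f32_16x16x4f16;
        if((MPerWave == 4 || MPerWave == 8) && NPerWave == 64)
            return mfma_instr_sketch::mfma_f32_4x4x4f16;
        return mfma_instr_sketch::unsupported;
    }

    static_assert(pick_fp16_mfma(32, 32) == mfma_instr_sketch::mfma_f32_32x32x8f16, "");
    static_assert(pick_fp16_mfma(8, 64) == mfma_instr_sketch::mfma_f32_4x4x4f16, "");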
composable_kernel/include/utility/float_type.amd.hpp.in (+18 -18)

@@ -84,37 +84,37 @@ struct vector_type<float, 4>
 };

 template <>
-struct vector_type<half, 1>
+struct vector_type<half_t, 1>
 {
-    using MemoryType = half;
+    using MemoryType = half_t;

     template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
     {
         static_assert(I < 1, "wrong");
-        *(reinterpret_cast<half*>(&v) + I) = s;
+        *(reinterpret_cast<half_t*>(&v) + I) = s;
     }
 };

 template <>
-struct vector_type<half, 2>
+struct vector_type<half_t, 2>
 {
     using MemoryType = half2_t;

     union DataType
     {
         MemoryType vector;
-        half scalar[2];
+        half_t scalar[2];
     };

     template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
     {
         static_assert(I < 2, "wrong");
-        *(reinterpret_cast<half*>(&v) + I) = s;
+        *(reinterpret_cast<half_t*>(&v) + I) = s;
     }

-    __host__ __device__ static MemoryType Pack(half s0, half s1)
+    __host__ __device__ static MemoryType Pack(half_t s0, half_t s1)
     {
         DataType data;

         data.scalar[0] = s0;
@@ -124,24 +124,24 @@ struct vector_type<half, 2>
 };

 template <>
-struct vector_type<half, 4>
+struct vector_type<half_t, 4>
 {
     using MemoryType = half4_t;

     union DataType
     {
         MemoryType vector;
-        half scalar[4];
+        half_t scalar[4];
     };

     template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
     {
         static_assert(I < 4, "wrong");
-        *(reinterpret_cast<half*>(&v) + I) = s;
+        *(reinterpret_cast<half_t*>(&v) + I) = s;
     }

-    __host__ __device__ static MemoryType Pack(half s0, half s1, half s2, half s3)
+    __host__ __device__ static MemoryType Pack(half_t s0, half_t s1, half_t s2, half_t s3)
     {
         DataType data;

         data.scalar[0] = s0;
@@ -255,8 +255,8 @@ struct inner_product_with_conversion
     __device__ T operator()(half2_t a, half2_t b) const
     {
-        const half* p_a_half = reinterpret_cast<const half*>(&a);
-        const half* p_b_half = reinterpret_cast<const half*>(&b);
+        const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
+        const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);

         T acc = 0;

         for(index_t v = 0; v < 2; ++v)
@@ -269,8 +269,8 @@ struct inner_product_with_conversion
     __device__ T operator()(half4_t a, half4_t b) const
     {
-        const half* p_a_half = reinterpret_cast<const half*>(&a);
-        const half* p_b_half = reinterpret_cast<const half*>(&b);
+        const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
+        const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);

         T acc = 0;

         for(index_t v = 0; v < 4; ++v)
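vector_type<half_t, N> packs scalars into the backend's N-wide vector type through a union, as the Pack/SetScalar bodies above show. A host-compilable sketch of the same union trick, assuming a clang-style compiler with __fp16 and ext_vector_type (the *_sketch names are stand-ins, not the header's API; the header takes half_t/half2_t from the backend):

    // Sketch of the union-based packing used by vector_type<half_t, 2>::Pack.
    using half_sketch_t = __fp16;
    typedef half_sketch_t half2_sketch_t __attribute__((ext_vector_type(2)));

    struct vector_type_half2_sketch
    {
        using MemoryType = half2_sketch_t;

        union DataType
        {
            MemoryType vector;
            half_sketch_t scalar[2];
        };

        static MemoryType Pack(half_sketch_t s0, half_sketch_t s1)
        {
            DataType data;
            data.scalar[0] = s0; // write the two scalars through the union...
            data.scalar[1] = s1;
            return data.vector;  // ...and read them back as one 2-wide vector,
        }                        // mirroring the header's own union trick
    };

    int main()
    {
        half2_sketch_t v = vector_type_half2_sketch::Pack(half_sketch_t(1), half_sketch_t(2));
        return v[0] == half_sketch_t(1) ? 0 : 1;
    }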
composable_kernel/include/utility/float_type.nvidia.hpp.in (+15 -17)

@@ -14,17 +14,15 @@ using float2_t = float2;
 using float4_t = float4;

 // float
-typedef float float16_t __attribute__((ext_vector_type(16)));
 typedef float float32_t __attribute__((ext_vector_type(32)));

-// float16
 // bfloat16
 typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
 typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
 typedef ushort ushort8_t __attribute__((ext_vector_type(8)));

-// float16
+// fp16
+using half_t = half;
 using half2_t = half2;
 using half4_t = float2;
@@ -93,37 +91,37 @@ struct vector_type<float, 4>
 };

 template <>
-struct vector_type<half, 1>
+struct vector_type<half_t, 1>
 {
-    using MemoryType = half;
+    using MemoryType = half_t;

     template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
     {
         static_assert(I < 1, "wrong");
-        *(reinterpret_cast<half*>(&v) + I) = s;
+        *(reinterpret_cast<half_t*>(&v) + I) = s;
     }
 };

 template <>
-struct vector_type<half, 2>
+struct vector_type<half_t, 2>
 {
     using MemoryType = half2_t;

     union DataType
     {
         MemoryType vector;
-        half scalar[2];
+        half_t scalar[2];
     };

     template <index_t I>
-    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
+    __host__ __device__ static void SetScalar(MemoryType& v, half_t s, Number<I>)
     {
         static_assert(I < 2, "wrong");
-        *(reinterpret_cast<half*>(&v) + I) = s;
+        *(reinterpret_cast<half_t*>(&v) + I) = s;
     }

-    __host__ __device__ static MemoryType Pack(half s0, half s1)
+    __host__ __device__ static MemoryType Pack(half_t s0, half_t s1)
     {
         DataType data;

         data.scalar[0] = s0;
@@ -152,8 +150,8 @@ struct inner_product_with_conversion
     __device__ T operator()(half2_t a, half2_t b) const
     {
-        const half* p_a_half = reinterpret_cast<const half*>(&a);
-        const half* p_b_half = reinterpret_cast<const half*>(&b);
+        const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
+        const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);

         T acc = 0;

         for(index_t v = 0; v < 2; ++v)
@@ -166,8 +164,8 @@ struct inner_product_with_conversion
     __device__ T operator()(half4_t a, half4_t b) const
     {
-        const half* p_a_half = reinterpret_cast<const half*>(&a);
-        const half* p_b_half = reinterpret_cast<const half*>(&b);
+        const half_t* p_a_half = reinterpret_cast<const half_t*>(&a);
+        const half_t* p_b_half = reinterpret_cast<const half_t*>(&b);

         T acc = 0;

         for(index_t v = 0; v < 4; ++v)
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp (+58 -58)

@@ -65,9 +65,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNRepeat = 2;

-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;

     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

The hunks at lines 104, 143, 182, 222, 261, 300, 339, 378, 417, 456, 495, 534, 573, 612, 651, 690 and 729 apply the same three renames (GemmMPerThreadSubC -> GemmMPerThread, GemmNPerThreadSubC -> GemmNPerThread, GemmKPerThreadLoop -> GemmKPerThread) to the remaining tuning-parameter blocks; the numeric values in each block are unchanged (the block at line 729 keeps its GemmMPerThread = 2).

@@ -761,7 +761,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
 #endif

     constexpr index_t N1 = GemmNRepeat;
-    constexpr index_t N2 = GemmNPerThreadSubC;
+    constexpr index_t N2 = GemmNPerThread;

     constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
@@ -788,13 +788,13 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
         KPerBlock,
         EPerBlock,
         GemmNRepeat,
-        GemmMPerThreadSubC,
-        GemmNPerThreadSubC,
+        GemmMPerThread,
+        GemmNPerThread,
+        GemmKPerThread,
         GemmMLevel0Cluster,
         GemmNLevel0Cluster,
         GemmMLevel1Cluster,
         GemmNLevel1Cluster,
-        GemmKPerThreadLoop,
         GemmDataPerReadA,
         GemmDataPerReadB,
         InBlockCopySubLengths_E_N1_B_N2,
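The hunk at line 761 is where the renamed GemmNPerThread feeds the merged output dimension: N1 = GemmNRepeat, N2 = GemmNPerThread, and B = (N * Ho * Wo) / (N1 * N2). A small worked example; the tensor sizes below are hypothetical, only N1 = 2 and N2 = 4 come from the driver configurations:

    // Worked example of the B computation in the hunk at line 761.
    constexpr int N  = 128; // batch         (hypothetical)
    constexpr int Ho = 28;  // output height (hypothetical)
    constexpr int Wo = 28;  // output width  (hypothetical)

    constexpr int N1 = 2;                         // GemmNRepeat
    constexpr int N2 = 4;                         // GemmNPerThread (was GemmNPerThreadSubC)
    constexpr int B  = (N * Ho * Wo) / (N1 * N2); // 128*28*28 / 8 = 12544

    static_assert((N * Ho * Wo) % (N1 * N2) == 0, "N*Ho*Wo must be divisible by N1*N2");
    static_assert(B == 12544, "");

    int main() { return 0; }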
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated.hpp  deleted (100644 -> 0, -305 lines)

Deleted file content (reconstructed; wrapped lines rejoined, and the three disabled #elif 0 tuning blocks summarized in place):

#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp"

template <class T, class InDesc, class WeiDesc, class OutDesc, class ConvStrides, class ConvDilations>
void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_deprecated(InDesc,
                                                                     const Tensor<T>& in_nchw,
                                                                     WeiDesc,
                                                                     const Tensor<T>& wei_kcyx,
                                                                     OutDesc,
                                                                     Tensor<T>& out_nkhw,
                                                                     ConvStrides,
                                                                     ConvDilations,
                                                                     ck::index_t nrepeat)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t K  = out_nkhw_desc.GetLength(I1);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 0
    // BlockSize = 256, blockwise-GEMM 128x128, each thread hold 64 data
    constexpr index_t BlockSize = 256;
    constexpr index_t BPerBlock = 16;
    constexpr index_t KPerBlock = 128;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmNRepeat = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]
    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]
    constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
    // BlockSize = 256, EPerBlock = 16, each thread hold 64 data
    // (block mirrors the one above with EPerBlock = 16,
    //  InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>,
    //  InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>,
    //  WeiBlockCopySubLengths_E_K = Sequence<4, 2>, WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>)
#elif 0
    // BlockSize = 64, blockwise-GEMM 64x64, each thread hold 64 data
    // (BlockSize 64, BPerBlock 8, KPerBlock 64, EPerBlock 8,
    //  GemmMLevel1Cluster = GemmNLevel1Cluster = 2,
    //  InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>,
    //  WeiBlockCopySubLengths_E_K = Sequence<4, 2>, WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>)
#elif 0
    // BlockSize = 256, blockwise-GEMM 64x128, each thread hold 32 data
    // (BPerBlock 16, KPerBlock 64, EPerBlock 8, GemmMPerThreadSubC 2, GemmDataPerReadA 2,
    //  InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>,
    //  InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>,
    //  WeiBlockCopySubLengths_E_K = Sequence<2, 1>, WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>,
    //  WeiBlockCopySrcDataPerRead_E = 2)
#elif 1
    constexpr index_t BlockSize = 64;
    constexpr index_t BPerBlock = 16;
    constexpr index_t KPerBlock = 32;
    constexpr index_t EPerBlock = 4;

    constexpr index_t GemmNRepeat = 2;
    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 1;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>;            // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>;            // [E, N1, B, N2]
    constexpr index_t InBlockCopySrcDataPerRead_B = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;

    using WeiBlockCopySubLengths_E_K = Sequence<1, 2>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<4, 16>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>;            // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>;            // [E, K]
    constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
#endif

    constexpr index_t N1 = GemmNRepeat;
    constexpr index_t N2 = GemmNPerThreadSubC;

    constexpr index_t B = (N * Ho * Wo) / (N1 * N2);

    constexpr index_t GridSize = ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_conv = GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated<
        GridSize,
        BlockSize,
        T,
        T,
        decltype(in_nchw_desc),
        decltype(wei_kcyx_desc),
        decltype(out_nkhw_desc),
        ConvStrides,
        ConvDilations,
        ConvolutionDirection::Forward,
        BPerBlock,
        KPerBlock,
        EPerBlock,
        GemmNRepeat,
        GemmMPerThreadSubC,
        GemmNPerThreadSubC,
        GemmMLevel0Cluster,
        GemmNLevel0Cluster,
        GemmMLevel1Cluster,
        GemmNLevel1Cluster,
        GemmKPerThreadLoop,
        GemmDataPerReadA,
        GemmDataPerReadB,
        InBlockCopySubLengths_E_N1_B_N2,
        InBlockCopyClusterLengths_E_N1_B_N2,
        InBlockCopyThreadClusterArrangeOrder,
        InBlockCopySrcAccessOrder,
        InBlockCopyDstAccessOrder,
        InBlockCopySrcDataPerRead_B,
        InBlockCopyDstDataPerWrite_N2,
        WeiBlockCopySubLengths_E_K,
        WeiBlockCopyClusterLengths_E_K,
        WeiBlockCopyThreadClusterArrangeOrder,
        WeiBlockCopySrcAccessOrder,
        WeiBlockCopyDstAccessOrder,
        WeiBlockCopySrcDataPerRead_E,
        WeiBlockCopyDstDataPerWrite_K>{};

    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time = launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                            dim3(GridSize),
                                            dim3(BlockSize),
                                            0,
                                            0,
                                            static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                            static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                            static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s \n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);

        usleep(std::min(time * 1000, float(10000)));
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
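The deleted driver derived its launch grid by rounding B and K up to whole blocks and reported throughput as flops / 1e9 / elapsed-milliseconds, which is numerically TFlop/s. A standalone sketch of that arithmetic with hypothetical sizes and timing (the formulas follow the deleted code; every concrete number below is illustrative):

    #include <cstdio>

    int main()
    {
        const long long B = 12544, K = 256;             // hypothetical merged-output and filter counts
        const long long BPerBlock = 16, KPerBlock = 32; // block tile from the active configuration

        // GridSize = ceil(B / BPerBlock) * ceil(K / KPerBlock) = 784 * 8 = 6272
        const long long GridSize =
            ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

        const double flops      = 2.0e10; // hypothetical convolution flop count
        const double elapsed_ms = 1.5;    // hypothetical kernel time
        // Gflop per millisecond equals Tflop per second, hence the deleted printf's label.
        const double tflops = flops / 1.0e9 / elapsed_ms;

        std::printf("GridSize %lld, %f TFlop/s\n", GridSize, tflops);
        return 0;
    }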
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
8d15144c
...
@@ -62,9 +62,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -62,9 +62,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmNPerBlock = 256;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmKPerBlock = 8;
constexpr index_t GemmMPerThread
SubC
= 4;
constexpr index_t GemmMPerThread = 4;
constexpr index_t GemmNPerThread
SubC
= 4;
constexpr index_t GemmNPerThread = 4;
constexpr index_t GemmKPerThread
Loop
= 1;
constexpr index_t GemmKPerThread = 1;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmMLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmNLevel0Cluster = 2;
...
@@ -95,9 +95,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -95,9 +95,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmKPerBlock
=
4
;
constexpr
index_t
GemmMPerThread
SubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
SubC
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
Loop
=
1
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
...
@@ -128,9 +128,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -128,9 +128,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThread
SubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
SubC
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
Loop
=
1
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
...
@@ -162,9 +162,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -162,9 +162,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThread
SubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
SubC
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
Loop
=
1
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
...
@@ -195,9 +195,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -195,9 +195,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerThread
SubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
SubC
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
Loop
=
1
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
...
@@ -230,9 +230,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -230,9 +230,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThread
SubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
SubC
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
Loop
=
1
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
...
@@ -265,9 +265,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -265,9 +265,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerThread
SubC
=
4
;
constexpr
index_t
GemmMPerThread
=
4
;
constexpr
index_t
GemmNPerThread
SubC
=
4
;
constexpr
index_t
GemmNPerThread
=
4
;
constexpr
index_t
GemmKPerThread
Loop
=
1
;
constexpr
index_t
GemmKPerThread
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
...
@@ -298,9 +298,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;

@@ -333,9 +333,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

@@ -366,9 +366,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

@@ -401,9 +401,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

@@ -434,9 +434,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 16;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;

@@ -467,9 +467,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 4;

@@ -502,9 +502,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

@@ -535,9 +535,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

@@ -570,9 +570,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

@@ -603,9 +603,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 16;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;

@@ -636,9 +636,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;

@@ -669,9 +669,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;

@@ -702,9 +702,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 2;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

@@ -735,9 +735,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 4;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

@@ -768,9 +768,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 3;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

@@ -801,9 +801,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 3;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

@@ -834,9 +834,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 64;
     constexpr index_t GemmKPerBlock = 3;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 2;

@@ -867,9 +867,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThreadSubC = 4;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 4;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 2;
     constexpr index_t GemmNLevel0Cluster = 4;

@@ -900,9 +900,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 8;
-    constexpr index_t GemmMPerThreadSubC = 2;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 2;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;

@@ -933,9 +933,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr index_t GemmNPerBlock = 128;
     constexpr index_t GemmKPerBlock = 16;
-    constexpr index_t GemmMPerThreadSubC = 2;
-    constexpr index_t GemmNPerThreadSubC = 4;
-    constexpr index_t GemmKPerThreadLoop = 1;
+    constexpr index_t GemmMPerThread = 2;
+    constexpr index_t GemmNPerThread = 4;
+    constexpr index_t GemmKPerThread = 1;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;

@@ -983,9 +983,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
         GemmMPerBlock,
         GemmNPerBlock,
         GemmKPerBlock,
-        GemmMPerThreadSubC,
-        GemmNPerThreadSubC,
-        GemmKPerThreadLoop,
+        GemmMPerThread,
+        GemmNPerThread,
+        GemmKPerThread,
        GemmMLevel0Cluster,
        GemmNLevel0Cluster,
        GemmMLevel1Cluster,
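Each of the tuning blocks above feeds the same launch-configuration arithmetic: the driver covers the implicit GEMM with one workgroup per (M, N) block tile. A small standalone sketch of that grid-size calculation, assuming GemmM = K (output channels) and GemmN = N·Ho·Wo (flattened output pixels) as in the deprecated driver shown below, with purely illustrative problem sizes:

// Sketch of the grid-size arithmetic used by the convolution drivers.
// grid_size() is a hypothetical helper added here for illustration only.
#include <cstdio>

constexpr unsigned grid_size(unsigned GemmM, unsigned GemmN,
                             unsigned GemmMPerBlock, unsigned GemmNPerBlock)
{
    // ceiling division in both dimensions: one workgroup per block tile
    return ((GemmM + GemmMPerBlock - 1) / GemmMPerBlock) *
           ((GemmN + GemmNPerBlock - 1) / GemmNPerBlock);
}

int main()
{
    // Example problem: N = 128, K = 256, Ho = Wo = 14 (illustrative numbers only).
    constexpr unsigned N = 128, K = 256, Ho = 14, Wo = 14;
    constexpr unsigned GemmM = K;           // output channels
    constexpr unsigned GemmN = N * Ho * Wo; // flattened output pixels
    std::printf("GridSize = %u\n", grid_size(GemmM, GemmN, 128, 128));
    return 0;
}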
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated.hpp deleted 100644 → 0

#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated.hpp"

using namespace ck;

template <class T,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          class ConvStrides,
          class ConvDilations>
void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_deprecated(InDesc,
                                                                      const Tensor<T>& in_nchw,
                                                                      WeiDesc,
                                                                      const Tensor<T>& wei_kcyx,
                                                                      OutDesc,
                                                                      Tensor<T>& out_nkhw,
                                                                      ConvStrides,
                                                                      ConvDilations,
                                                                      ck::index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc  = InDesc{};
    constexpr auto wei_kcyx_desc = WeiDesc{};
    constexpr auto out_nkhw_desc = OutDesc{};

    // output tensor lengths: N (batch), K (output channels), Ho, Wo (output image size)
    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t K  = out_nkhw_desc.GetLength(I1);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    // allocate device buffers and copy the host tensors to the device
    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

#if 1
    constexpr index_t BlockSize = 256;

    constexpr index_t BPerBlock = 128;
    constexpr index_t KPerBlock = 128;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;

    using InBlockCopySubLengths_E_B     = Sequence<4, 1>;
    using InBlockCopyClusterLengths_E_B = Sequence<2, 128>;

    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 1>; // [E, B]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, B]

    constexpr index_t InBlockCopyDataPerAccess_B = 1;

    using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;

    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;

    constexpr index_t OutThreadCopyDataPerAccess_B = 1;
#elif 1
    // 1x1 filter, 8x8 image
    constexpr index_t BlockSize = 256;

    constexpr index_t BPerBlock = 128;
    constexpr index_t KPerBlock = 128;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;

    using InBlockCopySubLengths_E_B     = Sequence<1, 4>;
    using InBlockCopyClusterLengths_E_B = Sequence<8, 32>;

    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 1>; // [E, B]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, B]

    constexpr index_t InBlockCopyDataPerAccess_B = 4;

    using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;

    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;

    constexpr index_t OutThreadCopyDataPerAccess_B = 4;
#elif 0
    // 1x1 filter, 14x14 image
    constexpr index_t BlockSize = 256;

    constexpr index_t BPerBlock = 128;
    constexpr index_t KPerBlock = 128;
    constexpr index_t EPerBlock = 8;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;

    using InBlockCopySubLengths_E_B     = Sequence<2, 2>;
    using InBlockCopyClusterLengths_E_B = Sequence<4, 64>;

    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 1>; // [E, B]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, B]

    constexpr index_t InBlockCopyDataPerAccess_B = 2;

    using WeiBlockCopySubLengths_E_K     = Sequence<4, 1>;
    using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;

    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;

    constexpr index_t OutThreadCopyDataPerAccess_B = 2;
#endif

    // B = N * Ho * Wo is the flattened output-pixel dimension of the implicit GEMM;
    // one workgroup covers one (KPerBlock x BPerBlock) tile
    constexpr index_t B = N * Ho * Wo;

    constexpr index_t GridSize =
        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);

    printf("%s: BlockSize %u, GridSize %u\n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_conv =
#if 0
        GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
#else
        GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated
#endif
        <GridSize,
         BlockSize,
         T,
         decltype(in_nchw_desc),
         decltype(wei_kcyx_desc),
         decltype(out_nkhw_desc),
         ConvStrides,
         ConvDilations,
         BPerBlock,
         KPerBlock,
         EPerBlock,
         GemmMPerThreadSubC,
         GemmNPerThreadSubC,
         GemmMLevel0Cluster,
         GemmNLevel0Cluster,
         GemmMLevel1Cluster,
         GemmNLevel1Cluster,
         GemmKPerThreadLoop,
         GemmDataPerReadA,
         GemmDataPerReadB,
         InBlockCopySubLengths_E_B,
         InBlockCopyClusterLengths_E_B,
         InBlockCopyThreadClusterArrangeOrder,
         InBlockCopySrcAccessOrder,
         InBlockCopyDstAccessOrder,
         InBlockCopyDataPerAccess_B,
         WeiBlockCopySubLengths_E_K,
         WeiBlockCopyClusterLengths_E_K,
         WeiBlockCopyThreadClusterArrangeOrder,
         WeiBlockCopySrcAccessOrder,
         WeiBlockCopyDstAccessOrder,
         WeiBlockCopySrcDataPerRead_E,
         WeiBlockCopyDstDataPerWrite_K,
         OutThreadCopyDataPerAccess_B>{};

    // launch the kernel nrepeat times and report elapsed time and achieved TFlop/s
    for(index_t i = 0; i < nrepeat; ++i)
    {
        float time =
            launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s\n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);

        usleep(std::min(time * 1000, float(10000)));
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
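The timing loop in the deleted driver converts a kernel time measured in milliseconds into TFlop/s. Assuming calculate_convolution_flops returns the conventional count 2·N·K·C·Y·X·Ho·Wo (an assumption about that helper, not confirmed by this diff), the arithmetic reduces to:

// Sketch of the TFlop/s arithmetic in the timing loop above, assuming the
// conventional multiply-add flop count for a direct convolution.
#include <cstdio>

int main()
{
    // Illustrative problem sizes only.
    const double N = 128, K = 256, C = 256, Y = 3, X = 3, Ho = 14, Wo = 14;
    const double flops   = 2.0 * N * K * C * Y * X * Ho * Wo;
    const double time_ms = 1.25; // example kernel time in milliseconds

    // flops / 1e9 gives GFlops; dividing by milliseconds gives GFlop/ms == TFlop/s,
    // which is exactly the (1000 * 1000 * 1000) / time expression in the driver.
    const double tflops = flops / 1.0e9 / time_ms;
    std::printf("%.0f flops in %.2f ms -> %.2f TFlop/s\n", flops, time_ms, tflops);
    return 0;
}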
driver/src/conv_driver.cpp

@@ -13,14 +13,9 @@
 #include "conv_common.hpp"
 #include "host_conv.hpp"
 #include "device_tensor.hpp"
-//#include "device_convolution_direct_v2_nchw_kcyx_nkhw.hpp"
-//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
-//#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn_padded.hpp"
-//#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
-//#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
-//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
 #include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
 #include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_fp16.hpp"
 //#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
 //#include "device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp"

@@ -611,7 +606,7 @@ int main(int argc, char* argv[])
                                                      LeftPads{},
                                                      RightPads{},
                                                      nrepeat);
-#elif 1
+#elif 0
     device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
                                                           in_nchw,
                                                           wei_kcyx_desc,

@@ -623,6 +618,18 @@ int main(int argc, char* argv[])
                                                          LeftPads{},
                                                          RightPads{},
                                                          nrepeat);
+#elif 1
+    device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_fp16(in_nchw_desc,
+                                                               in_nchw,
+                                                               wei_kcyx_desc,
+                                                               wei_kcyx,
+                                                               out_nkhw_desc,
+                                                               out_nkhw_device,
+                                                               ConvStrides{},
+                                                               ConvDilations{},
+                                                               LeftPads{},
+                                                               RightPads{},
+                                                               nrepeat);
 #elif 1
     device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw(in_nchw_desc,
                                                                       in_nchw,
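conv_driver.cpp selects exactly one device implementation at compile time through this chain of #if/#elif branches; the commit disables the fp32 v4r4 path and enables the new fp16 path. A stripped-down sketch of that dispatch pattern (with placeholder function names chosen here for illustration):

// Stripped-down sketch of the compile-time dispatch used in conv_driver.cpp:
// exactly one branch is active, and switching implementations means flipping
// a 1 to a 0. Function names below are placeholders, not repository symbols.
#include <cstdio>

static void run_v4r1()      { std::printf("running v4r1 driver\n"); }
static void run_v4r4()      { std::printf("running v4r4 driver\n"); }
static void run_v4r4_fp16() { std::printf("running v4r4 fp16 driver\n"); }

int main()
{
#if 0
    run_v4r1();
#elif 0
    run_v4r4();
#elif 1
    run_v4r4_fp16(); // the branch this commit enables in conv_driver.cpp
#endif
    return 0;
}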