gaoqiong / composable_kernel

Commit df6dd915, authored Apr 16, 2020 by Jing Zhang
Parent: e9f05865

    formating

Showing 10 changed files with 991 additions and 472 deletions (+991 -472)

  composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp (+0 -2)
  composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp (+18 -8)
  composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp (+19 -18)
  composable_kernel/include/tensor_operation/xdlops_gemm.hpp (+40 -39)
  composable_kernel/include/utility/amd_xdlops_emulate.hpp (+64 -33)
  composable_kernel/include/utility/common_header.hpp (+0 -1)
  driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp (+70 -71)
  driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp (+217 -218)
  driver/include/device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp (+74 -74)
  driver/src/conv_driver.cpp (+489 -8)

composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp

@@ -158,7 +158,5 @@ struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
        }
    }
};
} // namespace ck
#endif

composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp

@@ -26,17 +26,22 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
        index_t col;
    };

    // static constexpr XdlopsGemm_t XdlopsGemm = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave,
    // GemmDataPerReadA, GemmDataPerReadB>{};

    index_t mMyWaveOffsetA;
    index_t mMyWaveOffsetB;

    static constexpr index_t WaveSize = 64;

    __device__ constexpr auto GetOutputLayout() const
    {
        return XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .GetOutputLayout();
    }

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
    {
        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
                      "wrong! K dimension not consistent\n");

@@ -67,8 +72,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
        constexpr index_t N = BlockMatrixB::NCol();
        constexpr index_t K = BlockMatrixA::NRow();

        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .template Run<M, N, K>(
                &p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
    }

    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)

@@ -76,7 +82,9 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
        const index_t waveId = get_thread_local_1d_id() / WaveSize;

        const auto thread_mtx_on_blk =
            XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
                .GetBeginOfThreadBlk(i);

        const index_t col = waveId % GemmNWaves * GemmNPerWave + thread_mtx_on_blk.col;

@@ -94,14 +102,16 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
    __device__ void XdlopsMatrixCSetZero() const
    {
        constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;

        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .SetZeroXdlopsRegs(Number<thread_mtx_size>{});
    }

    template <class FloatC>
    __device__ void XdlopsMatrixCRead(FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;

        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .ReadXdlopsRegs(Number<thread_mtx_size>{}, p_c_thread);
    }
};

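Aside (not part of the commit): the column index computed in GetBeginOfThreadMatrixC above packs the wave id and the per-thread offset into one expression. A minimal host-side sketch of the same arithmetic, where only WaveSize = 64 is taken from the diff; GemmNWaves, GemmNPerWave and the GetBeginOfThreadBlk result are assumed sample values:

    #include <cstdio>

    int main()
    {
        const int WaveSize     = 64; // from the diff
        const int GemmNWaves   = 2;  // assumed
        const int GemmNPerWave = 64; // assumed
        const int thread_mtx_on_blk_col = 3; // assumed result of GetBeginOfThreadBlk(i)

        const int tids[] = {0, 63, 64, 127};
        for(int tid : tids)
        {
            // same expression as the diff: (waveId % GemmNWaves) * GemmNPerWave + per-thread column
            const int waveId = tid / WaveSize;
            const int col    = waveId % GemmNWaves * GemmNPerWave + thread_mtx_on_blk_col;
            std::printf("tid %3d -> waveId %d, col %d\n", tid, waveId, col);
        }
        return 0;
    }
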
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp

@@ -90,15 +90,15 @@ struct BlockwiseGenericTensorSliceCopy_v4
        if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
        {
            // TODO: threadwise copy is still being tweaked
            if(has_optimized_address_calculation)
            {
                mThreadwiseLoad.Run_optimized_src_address_calculation(p_block_src, p_thread_buffer);
            }
            else
            {
                mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
            }
        }
    }

@@ -114,15 +114,16 @@ struct BlockwiseGenericTensorSliceCopy_v4
        if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
        {
            // TODO: threadwise copy is still being tweaked
            if(has_optimized_address_calculation)
            {
                mThreadwiseStore.Run_optimized_dst_address_calculation(p_thread_buffer,
                                                                       p_block_dst);
            }
            else
            {
                mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
            }
        }
    }

composable_kernel/include/tensor_operation/xdlops_gemm.hpp

@@ -497,174 +497,171 @@ struct mfma_info<mfma_instr::mfma_f32_4x4x2bf16>
    }
};

template <class data_type, index_t MPerWave, index_t NPerWave>
__device__ constexpr auto GetMFMAInfo();

template <>
__device__ constexpr auto GetMFMAInfo<float, 32, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 32, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x2xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 16, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x4xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 16, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 64, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 8, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<float, 4, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x1xf32>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 32, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 32, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x8f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 16, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x16f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 16, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 64, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 4, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<half, 8, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x4f16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 32, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 32, 32>()
{
    return mfma_info<mfma_instr::mfma_f32_32x32x4bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 16, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x8bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 16, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 64, 16>()
{
    return mfma_info<mfma_instr::mfma_f32_16x16x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 4, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{};
}

template <>
__device__ constexpr auto GetMFMAInfo<ushort, 8, 64>()
{
    return mfma_info<mfma_instr::mfma_f32_4x4x2bf16>{};
}

template <class data_type,
          index_t MPerWave,
          index_t NPerWave,
          ...
@@ -685,7 +682,10 @@ struct XdlopsGemm_t
    __device__ static constexpr index_t M0() { return M0_; }
    __device__ static constexpr index_t N1() { return N1_; }
    __device__ static constexpr index_t N0() { return N0_; }

    __device__ static constexpr index_t GetBlkSize()
    {
        return GetMFMAInfo<data_type, MPerWave, NPerWave>().num_regs_blk;
    }

    __device__ static constexpr index_t GetNumBlks()
    {

@@ -726,7 +726,6 @@ struct XdlopsGemm_t
        return mfma_type.num_output_blks == 1 && mfma_type.num_input_blks != 1;
    }

#if CK_USE_AMD_XDLOPS_EMULATE
    // emulate xdlops
    template <index_t M, index_t N, index_t K, class FloatA, class FloatB, class FloatC>

@@ -843,7 +842,8 @@ struct XdlopsGemm_t
            for(index_t k = 0; k < K; ++k)
            {
                mfma_type.run(
                    Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
            }
        }).Else([&](auto) {

@@ -852,7 +852,8 @@ struct XdlopsGemm_t
            for(index_t k = 0; k < K; k += mfma_type.num_input_blks)
            {
                mfma_type.run(
                    Number<MPerWave>{}, Number<NPerWave>{}, p_a_wave, p_b_wave, p_c_thread);
            }
        });

@@ -898,7 +899,7 @@ struct XdlopsGemm_t
    __device__ void SetZeroXdlopsRegs(Number<Size>) const
    {
#if !CK_USE_AMD_XDLOPS_EMULATE
        // gcnasm_accvgpr_zero<Size>();
#endif
    }

@@ -907,8 +908,8 @@ struct XdlopsGemm_t
    {
#if !CK_USE_AMD_XDLOPS_EMULATE
        constexpr auto mfma_type = GetMFMAInfo<data_type, MPerWave, NPerWave>();
        // gcnasm_nop<mfma_type.cycles>();
        // gcnasm_accvgpr_read<Size>(p_c_thread);
#else
        (void)p_c_thread;
#endif

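Aside (not part of the commit): GetMFMAInfo above is a compile-time dispatch table from (data type, MPerWave, NPerWave) to an mfma_info describing one MFMA instruction, and GetBlkSize reads num_regs_blk off the selected entry. A simplified standalone analogue of that dispatch pattern, with invented traits fields and register counts (the real mfma_info in this file is not reproduced here):

    #include <cstdio>

    // Small traits struct standing in for mfma_info; both members are illustrative.
    struct mfma_traits
    {
        const char* instr;
        int num_regs_blk; // assumed value, mirrors the field GetBlkSize reads
    };

    // Primary template is only declared; each supported shape gets a specialization,
    // just like the GetMFMAInfo specializations in the diff.
    template <class data_type, int MPerWave, int NPerWave>
    constexpr mfma_traits get_mfma_traits();

    template <>
    constexpr mfma_traits get_mfma_traits<float, 32, 64>()
    {
        return {"mfma_f32_32x32x1xf32", 16}; // register count is an assumed value
    }

    template <>
    constexpr mfma_traits get_mfma_traits<float, 16, 16>()
    {
        return {"mfma_f32_16x16x4xf32", 4}; // assumed value
    }

    int main()
    {
        constexpr auto t = get_mfma_traits<float, 32, 64>(); // resolved at compile time
        std::printf("%s, regs per block: %d\n", t.instr, t.num_regs_blk);
        return 0;
    }
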
composable_kernel/include/utility/amd_xdlops_emulate.hpp

@@ -7,7 +7,8 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x1f32(const float&, const float&, float32_t*);

template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 32; i++)

@@ -17,7 +18,8 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const flo
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 16; i++)

@@ -27,7 +29,8 @@ __device__ void gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const flo
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<64, 32>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 16; i++)

@@ -53,12 +56,14 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x1f32(const float&, const float&, float16_t*);

template <>
__device__ void
gcnasm_mfma_f32_16x16x1f32<16, 64>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_16x16x1f32<64, 16>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}

@@ -66,66 +71,77 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x1f32(const float& reg_a, const float& reg_b, float4_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_4x4x1f32<4, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_4x4x1f32<8, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x4f16(const half4_t&, const half4_t&, float32_t*);

template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<64, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<32, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<64, 32>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}

__device__ void gcnasm_mfma_f32_32x32x8f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}

__device__ void gcnasm_mfma_f32_16x16x16f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x4f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_16x16x4f16<16, 64>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_16x16x4f16<64, 16>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x4f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_4x4x4f16<4, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_4x4x4f16<8, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}

@@ -133,54 +149,69 @@ template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x2bf16(const ushort2_t&, const ushort2_t&, float32_t*);

template <>
__device__ void
gcnasm_mfma_f32_32x32x2bf16<64, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x2bf16<32, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x2bf16<64, 32>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}

__device__ void gcnasm_mfma_f32_32x32x4bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}

__device__ void gcnasm_mfma_f32_16x16x8bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_4x4x2bf16<4, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_4x4x2bf16<8, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}

// clang-format on
}
#endif

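Aside (not part of the commit): most emulation bodies above are empty stubs or have their loops elided in this view. Mathematically, one K-step of an MFMA-style block GEMM is a rank-1 update of the C block; a scalar reference of that update, with an assumed 4x4 block shape (this is not the emulation code from the file), is:

    #include <cstdio>
    #include <vector>

    int main()
    {
        const int M = 4, N = 4; // assumed block shape for illustration
        std::vector<float> A = {1, 2, 3, 4};
        std::vector<float> B = {10, 20, 30, 40};
        std::vector<float> C(M * N, 0.0f);

        // C[m][n] += A[m] * B[n]: the rank-1 (outer product) update that one
        // K-step of a block MFMA accumulates; the hardware spreads this work
        // across the lanes of a wavefront.
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                C[m * N + n] += A[m] * B[n];

        std::printf("C[0][0] = %.1f, C[3][3] = %.1f\n", C[0], C[M * N - 1]);
        return 0;
    }
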
composable_kernel/include/utility/common_header.hpp

@@ -31,5 +31,4 @@
#include "amd_xdlops_emulate.hpp"
#endif

#endif

driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp

@@ -111,8 +111,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>;

@@ -150,8 +150,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;

@@ -189,8 +189,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;

@@ -229,8 +229,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<4, 1, 1, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>;

@@ -268,8 +268,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 1>;

@@ -307,8 +307,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>;

@@ -346,8 +346,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<2, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>;

@@ -385,8 +385,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 8, 2>;

@@ -424,8 +424,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 1>;

@@ -463,10 +463,10 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 8;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 8, 1>;

    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]

@@ -502,8 +502,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 2;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>;

@@ -541,8 +541,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 1;
    constexpr index_t GemmNLevel1Cluster = 2;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<3, 1, 1, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 2>;

@@ -580,8 +580,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 1;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<3, 1, 1, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 16, 2>;

@@ -619,8 +619,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<3, 1, 1, 1>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 4>;

@@ -658,8 +658,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;

@@ -697,8 +697,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 1;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 4;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<2, 2, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;

@@ -736,8 +736,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMLevel1Cluster = 4;
    constexpr index_t GemmNLevel1Cluster = 4;

    constexpr index_t GemmDataPerReadA = 2;
    constexpr index_t GemmDataPerReadB = 4;

    using InBlockCopySubLengths_E_N1_B_N2     = Sequence<1, 1, 1, 4>;
    using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;

@@ -813,19 +813,19 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    for(index_t i = 0; i < 10; ++i)
    {
        float time =
            launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   0,
                                   static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s\n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);
    }

    // warm up

@@ -833,14 +833,14 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    for(index_t i = 0; i < nrepeat; ++i)
    {
        launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }

    printf("Start running %d times...\n", nrepeat);

@@ -850,26 +850,25 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    for(index_t i = 0; i < nrepeat; ++i)
    {
        launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }

    cudaDeviceSynchronize();
    auto end = std::chrono::steady_clock::now();

    float ave_time = std::chrono::duration<float, std::milli>(end - start).count() / nrepeat;

    printf("Average elapsed time : %f ms, %f TFlop/s\n",
           ave_time,
           (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
               (std::size_t(1000) * 1000 * 1000) / ave_time);

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}

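Aside (not part of the commit): the TFlop/s value printed above works out because the kernel time is reported in milliseconds (the printf labels it "ms"), so flops / 1e9 / time_ms equals flops / (1e12 * time_s), i.e. TFlop/s. A standalone check with assumed numbers:

    #include <cstdio>
    #include <cstddef>

    int main()
    {
        const double flops   = 2.0e12; // assumed flop count for one kernel launch
        const double time_ms = 10.0;   // assumed elapsed time in milliseconds

        // Same bookkeeping as the driver's printf: 1e9 * 1 ms = 1e12 * 1 s,
        // so this quotient is in TFlop/s.
        const double tflops = flops / (std::size_t(1000) * 1000 * 1000) / time_ms;
        std::printf("%.1f TFlop/s (expected 200.0)\n", tflops);
        return 0;
    }
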
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp

(This diff is collapsed.)

driver/include/device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp

@@ -13,16 +13,16 @@ template <class T,
          class InLeftPads,
          class InRightPads>
void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
                                                                  const Tensor<T>& in_nchw,
                                                                  WeiDesc,
                                                                  const Tensor<T>& wei_kcyx,
                                                                  OutDesc,
                                                                  Tensor<T>& out_nkhw,
                                                                  ConvStrides,
                                                                  ConvDilations,
                                                                  InLeftPads,
                                                                  InRightPads,
                                                                  ck::index_t nrepeat)
{
    using namespace ck;

@@ -58,7 +58,7 @@ void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 16;

    constexpr index_t GemmMPerWave = 64;
    constexpr index_t GemmNPerWave = 64;

@@ -85,50 +85,51 @@ void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_conv =
        GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw<
            GridSize,
            BlockSize,
            T,
            T,
            decltype(in_nchw_desc),
            decltype(wei_kcyx_desc),
            decltype(out_nkhw_desc),
            ConvStrides,
            ConvDilations,
            InLeftPads,
            InRightPads,
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            ThreadGemmDataPerReadM,
            ThreadGemmDataPerReadN,
            GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
            GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
            GemmABlockCopySrcDataPerRead_GemmK,
            GemmABlockCopyDstDataPerWrite_GemmM,
            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
            GemmBBlockCopySrcDataPerRead_GemmN,
            GemmBBlockCopyDstDataPerWrite_GemmN>{};

    for(index_t i = 0; i < 10; ++i)
    {
        float time =
            launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   0,
                                   static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s\n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);
    }

    // warm up

@@ -136,14 +137,14 @@ void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
    for(index_t i = 0; i < nrepeat; ++i)
    {
        launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }

    printf("Start running %d times...\n", nrepeat);

@@ -153,26 +154,25 @@ void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
    for(index_t i = 0; i < nrepeat; ++i)
    {
        launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }

    cudaDeviceSynchronize();
    auto end = std::chrono::steady_clock::now();

    float ave_time = std::chrono::duration<float, std::milli>(end - start).count() / nrepeat;

    printf("Average elapsed time : %f ms, %f TFlop/s\n",
           ave_time,
           (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
               (std::size_t(1000) * 1000 * 1000) / ave_time);

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}

driver/src/conv_driver.cpp
View file @
df6dd915
...
@@ -20,26 +20,495 @@
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"

int main(int argc, char* argv[])
{
    using namespace ck;
#if 0
    // 1x1, 17x17
    constexpr index_t N = 128;
    constexpr index_t C = 1024;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K = 256;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 8x8
    constexpr index_t N = 128;
    constexpr index_t C = 1536;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K = 256;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 73x73
    constexpr index_t N = 128;
    constexpr index_t C = 160;
    constexpr index_t HI = 73;
    constexpr index_t WI = 73;
    constexpr index_t K = 64;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 35x35
    constexpr index_t N = 128;
    constexpr index_t C = 96;
    constexpr index_t HI = 35;
    constexpr index_t WI = 35;
    constexpr index_t K = 96;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 0
    // 3x3, 71x71
    constexpr index_t N = 128;
    constexpr index_t C = 192;
    constexpr index_t HI = 71;
    constexpr index_t WI = 71;
    constexpr index_t K = 192;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 0
    // 7x1, 17x17
    constexpr index_t N = 128;
    constexpr index_t C = 256;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K = 320;
    constexpr index_t Y = 7;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<3, 0>;
    using RightPads = Sequence<3, 0>;
#elif 0
    // 1x7, 17x17
    constexpr index_t N = 128;
    constexpr index_t C = 224;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K = 224;
    constexpr index_t Y = 1;
    constexpr index_t X = 7;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 3>;
    using RightPads = Sequence<0, 3>;
#elif 1
    // 3x3, 299x299 stride=2
    constexpr index_t N = 128;
    constexpr index_t C = 3;
    constexpr index_t HI = 299;
    constexpr index_t WI = 299;
    constexpr index_t K = 32;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 147x147
    // v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
    constexpr index_t N = 128;
    constexpr index_t C = 32;
    constexpr index_t HI = 147;
    constexpr index_t WI = 147;
    constexpr index_t K = 64;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 0
    // 3x3, 149x149
    // v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
    constexpr index_t N = 128;
    constexpr index_t C = 32;
    constexpr index_t HI = 149;
    constexpr index_t WI = 149;
    constexpr index_t K = 32;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 17x17, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 192;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K = 192;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 35x35
    constexpr index_t N = 128;
    constexpr index_t C = 384;
    constexpr index_t HI = 35;
    constexpr index_t WI = 35;
    constexpr index_t K = 96;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 35x35, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 384;
    constexpr index_t HI = 35;
    constexpr index_t WI = 35;
    constexpr index_t K = 384;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x3, 8x8
    constexpr index_t N = 128;
    constexpr index_t C = 384;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K = 448;
    constexpr index_t Y = 1;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 1>;
    using RightPads = Sequence<0, 1>;
#elif 0
    // 3x1, 8x8
    constexpr index_t N = 128;
    constexpr index_t C = 448;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K = 512;
    constexpr index_t Y = 3;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 0>;
    using RightPads = Sequence<1, 0>;
#elif 0
    // 3x1, 8x8
    constexpr index_t N = 128;
    constexpr index_t C = 448;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K = 512;
    constexpr index_t Y = 3;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 0>;
    using RightPads = Sequence<1, 0>;
#elif 1
    // 3x3, 147x147
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t HI = 147;
    constexpr index_t WI = 147;
    constexpr index_t K = 96;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 7x1, 73x73
    // v44@v100 xx.xx%, cudnn@v100 xx.xx%
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t HI = 73;
    constexpr index_t WI = 73;
    constexpr index_t K = 64;
    constexpr index_t Y = 7;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<3, 0>;
    using RightPads = Sequence<3, 0>;
#elif 1
    // 3x3, 73x73
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t HI = 73;
    constexpr index_t WI = 73;
    constexpr index_t K = 96;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 1024;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K = 2048;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14
    constexpr index_t N = 128;
    constexpr index_t C = 1024;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K = 256;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 1024;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K = 512;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 28x28
    constexpr index_t N = 128;
    constexpr index_t C = 128;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K = 128;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 1
    // 3x3, 14x14
    constexpr index_t N = 128;
    constexpr index_t C = 256;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K = 256;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 1
    // 1x1, 56x56, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 256;
    constexpr index_t HI = 56;
    constexpr index_t WI = 56;
    constexpr index_t K = 128;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 7x7, 230x230 stride=2
    constexpr index_t N = 128;
    constexpr index_t C = 3;
    constexpr index_t HI = 230;
    constexpr index_t WI = 230;
    constexpr index_t K = 64;
    constexpr index_t Y = 7;
    constexpr index_t X = 7;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 28x28, stride = 2
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K = 1024;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 28x28, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K = 256;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 7x7
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K = 2048;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 7x7
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K = 512;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 1
    // 1x1, 56x56
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t HI = 56;
    constexpr index_t WI = 56;
    constexpr index_t K = 64;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 1
    // 3x3, 56x56
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t HI = 56;
    constexpr index_t WI = 56;
    constexpr index_t K = 64;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#endif
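    // Output spatial sizes implied by the selected configuration (Ho/Wo are written out
    // here only for illustration; the driver itself derives them through its tensor
    // descriptor transforms):
    //
    //   Ho = (HI + LeftPad_h + RightPad_h - ConvDilation_h * (Y - 1) - 1) / ConvStride_h + 1
    //   Wo = (WI + LeftPad_w + RightPad_w - ConvDilation_w * (X - 1) - 1) / ConvStride_w + 1
    //
    // e.g. the 3x3, 299x299 stride=2 configuration gives Ho = Wo = (299 - 2 - 1) / 2 + 1 = 149.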
    auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
    auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
...
@@ -133,8 +602,8 @@ int main(int argc, char* argv[])
                                                          LeftPads{},
                                                          RightPads{},
                                                          nrepeat);
#elif 0
    device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
                                                         in_nchw,
                                                         wei_kcyx_desc,
                                                         wei_kcyx,
...
@@ -145,6 +614,18 @@ int main(int argc, char* argv[])
                                                         LeftPads{},
                                                         RightPads{},
                                                         nrepeat);
#elif 1
    device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
                                                                in_nchw,
                                                                wei_kcyx_desc,
                                                                wei_kcyx,
                                                                out_nkhw_desc,
                                                                out_nkhw_device,
                                                                ConvStrides{},
                                                                ConvDilations{},
                                                                LeftPads{},
                                                                RightPads{},
                                                                nrepeat);
#endif

    if(do_verification)
...