gaoqiong / composable_kernel

Commit 82a15a27, authored Apr 16, 2020 by Jing Zhang
add xdlops emulation on v100
Parent: e69b1970
Showing 12 changed files with 2359 additions and 475 deletions.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp  +164 -0
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp  +109 -0
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp  +14 -0
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops.hpp  +656 -0
composable_kernel/include/tensor_operation/xdlops_gemm.hpp  +955 -0
composable_kernel/include/utility/amd_xdlops_emulate.hpp  +189 -0
composable_kernel/include/utility/common_header.hpp  +3 -0
composable_kernel/include/utility/float_type.nvidia.hpp.in  +13 -0
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp  +40 -1
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp  +34 -1
driver/include/device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp  +178 -0
driver/src/conv_driver.cpp  +4 -473
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp (new file, mode 100644)

#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops.hpp"

namespace ck {

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class AccDataType,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          class ConvStrides,
          class ConvDilations,
          class LeftPads,
          class RightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmThreadGemmDataPerReadM,
          index_t GemmThreadGemmDataPerReadN,
          class GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          class GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmK,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          class GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          class GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN>
struct GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t GemmM = K;
        constexpr index_t GemmK = C * Y * X;
        constexpr index_t GemmN = N * Ho * Wo;

        static_assert(GemmM % GemmMPerBlock == 0 && GemmN % GemmNPerBlock == 0 &&
                          GemmK % GemmKPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
                          (X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor will "
                      "be violated");

        // input tensor
        //   global mem
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr index_t Hip = in_n_c_hip_wip_global_desc.GetLengths()[2];
        constexpr index_t Wip = in_n_c_hip_wip_global_desc.GetLengths()[3];

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Hip, Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Wip, Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_gemmk_gemmn_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // weight tensor
        constexpr auto wei_gemmk_gemmm_global_desc = reorder_tensor_descriptor_given_upper2lower(
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});

        // output tensor
        constexpr auto out_gemmm_gemmn_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // GEMM
        constexpr auto gridwise_gemm = GridwiseGemmTransposedANormalBNormalCXdlops_v1<
            GridSize,
            BlockSize,
            Float,
            AccDataType,
            decltype(wei_gemmk_gemmm_global_desc),
            decltype(in_gemmk_gemmn_global_desc),
            decltype(out_gemmm_gemmn_global_desc),
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            GemmThreadGemmDataPerReadM,
            GemmThreadGemmDataPerReadN,
            GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
            GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
            Sequence<1, 0>,
            Sequence<1, 0>,
            Sequence<0, 1>,
            0,
            GemmABlockCopySrcDataPerRead_GemmK,
            GemmABlockCopyDstDataPerWrite_GemmM,
            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
            Sequence<0, 1>,
            Sequence<0, 1>,
            Sequence<0, 1>,
            1,
            GemmBBlockCopySrcDataPerRead_GemmN,
            GemmBBlockCopyDstDataPerWrite_GemmN,
            InMemoryDataOperation::Set>{};

        gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
    }
};

} // namespace ck
#endif
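The GemmM/GemmN/GemmK comments at the top of this kernel define the implicit-GEMM view of forward convolution that the descriptor transforms below implement. A minimal host-side check of that mapping, using a hypothetical NCHW problem size that is not taken from this commit, is sketched here; the dimension formulas are exactly the ones stated in the header comment.

#include <cstdio>

// Hypothetical sizes for illustration only.
constexpr int N = 2, C = 8, Hi = 6, Wi = 6; // input  NCHW
constexpr int K = 16, Y = 3, X = 3;         // weight KCYX
constexpr int stride = 1, pad = 0, dilation = 1;

constexpr int Ho = (Hi + 2 * pad - dilation * (Y - 1) - 1) / stride + 1; // 4
constexpr int Wo = (Wi + 2 * pad - dilation * (X - 1) - 1) / stride + 1; // 4

// Implicit-GEMM view used by the kernel above:
//   GemmM = K, GemmN = N * Ho * Wo, GemmK = C * Y * X
constexpr int GemmM = K;
constexpr int GemmN = N * Ho * Wo;
constexpr int GemmK = C * Y * X;

int main()
{
    // Each output element out[n][k][ho][wo] is a dot product of length C*Y*X,
    // i.e. one element of a GemmM x GemmN matrix produced by a GemmK-long reduction.
    static_assert(GemmM == 16 && GemmN == 2 * 4 * 4 && GemmK == 8 * 3 * 3, "mapping check");
    std::printf("GemmM=%d GemmN=%d GemmK=%d\n", GemmM, GemmN, GemmK);
    return 0;
}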
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp (new file, mode 100644)

#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP

#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "xdlops_gemm.hpp"
#include "threadwise_gemm.hpp"

namespace ck {

template <index_t BlockSize,
          class BlockMatrixA,
          class BlockMatrixB,
          class Float,
          index_t GemmMPerWave,
          index_t GemmNPerWave,
          index_t GemmMWaves,
          index_t GemmNWaves,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB>
struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
    struct MatrixIndex
    {
        index_t row;
        index_t col;
    };

    // static constexpr XdlopsGemm_t XdlopsGemm =
    //     XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};

    index_t mMyWaveOffsetA;
    index_t mMyWaveOffsetB;

    static constexpr index_t WaveSize = 64;

    __device__ constexpr auto GetOutputLayout() const
    {
        return XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .GetOutputLayout();
    }

    __device__ BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops()
    {
        static_assert(BlockMatrixA::NRow() == BlockMatrixB::NRow(),
                      "wrong! K dimension not consistent\n");

        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();

        static_assert(GemmMPerWave * GemmMWaves == M, "GemmMWaves * GemmMPerWave != M");
        static_assert(GemmNPerWave * GemmNWaves == N, "GemmNWaves * GemmNPerWave != N");

        static_assert(BlockSize == GemmMWaves * GemmNWaves * WaveSize,
                      "BlockSize != GemmMWaves * GemmNWaves * WaveSize\n");

        const index_t waveId   = get_thread_local_1d_id() / WaveSize;
        const index_t waveId_m = waveId / GemmNWaves;
        const index_t waveId_n = waveId % GemmNWaves;

        mMyWaveOffsetA = waveId_m * GemmMPerWave;
        mMyWaveOffsetB = waveId_n * GemmNPerWave;
    }

    template <class FloatA, class FloatB, class FloatC>
    __device__ void Run(const FloatA* __restrict__ p_a_block,
                        const FloatB* __restrict__ p_b_block,
                        FloatC* __restrict__ p_c_thread) const
    {
        constexpr index_t M = BlockMatrixA::NCol(); // A is transposed
        constexpr index_t N = BlockMatrixB::NCol();
        constexpr index_t K = BlockMatrixA::NRow();

        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .template Run<M, N, K>(
                &p_a_block[mMyWaveOffsetA], &p_b_block[mMyWaveOffsetB], p_c_thread);
    }

    __device__ static MatrixIndex GetBeginOfThreadMatrixC(index_t i)
    {
        const index_t waveId = get_thread_local_1d_id() / WaveSize;

        const auto thread_mtx_on_blk =
            XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
                .GetBeginOfThreadBlk(i);

        const index_t col = waveId % GemmNWaves * GemmNPerWave + thread_mtx_on_blk.col;
        const index_t row = waveId / GemmNWaves * GemmMPerWave + thread_mtx_on_blk.row;

        return MatrixIndex{row, col};
    }

    __device__ constexpr auto GetThreadMatrixCDescriptor() const
    {
        const index_t reg_size = GemmMPerWave * GemmNPerWave / WaveSize;
        return make_ConstantMatrixDescriptor_packed(Number<reg_size>{}, Number<1>{});
    }

    __device__ void XdlopsMatrixCSetZero() const
    {
        constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .SetZeroXdlopsRegs(Number<thread_mtx_size>{});
    }

    template <class FloatC>
    __device__ void XdlopsMatrixCRead(FloatC* __restrict__ p_c_thread) const
    {
        constexpr auto thread_mtx_size = GemmMPerWave * GemmNPerWave / WaveSize;
        XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{}
            .ReadXdlopsRegs(Number<thread_mtx_size>{}, p_c_thread);
    }
};

} // namespace ck
#endif
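The constructor above splits the thread block into 64-lane waves and derives each wave's M/N tile offsets from its wave id. A small host-side sketch of that arithmetic, assuming BlockSize = 256 with GemmMWaves = GemmNWaves = 2 and GemmMPerWave = GemmNPerWave = 64 (the shape the new xdlops driver later in this commit uses), is:

#include <cstdio>

int main()
{
    constexpr int WaveSize     = 64;
    constexpr int GemmMWaves   = 2, GemmNWaves = 2;     // 2x2 waves per block
    constexpr int GemmMPerWave = 64, GemmNPerWave = 64; // each wave owns a 64x64 tile
    constexpr int BlockSize    = GemmMWaves * GemmNWaves * WaveSize; // 256, as asserted above

    for(int tid = 0; tid < BlockSize; tid += WaveSize)
    {
        const int waveId   = tid / WaveSize;
        const int waveId_m = waveId / GemmNWaves; // row of the wave grid
        const int waveId_n = waveId % GemmNWaves; // column of the wave grid

        // Same formulas as mMyWaveOffsetA / mMyWaveOffsetB in the constructor.
        std::printf("wave %d -> offsetA (M) = %d, offsetB (N) = %d\n",
                    waveId, waveId_m * GemmMPerWave, waveId_n * GemmNPerWave);
    }
    return 0;
}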
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp

@@ -56,8 +56,10 @@ struct BlockwiseGenericTensorSliceCopy_v4
        constexpr auto thread_cluster_desc =
            make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

#if 0
        static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
                      "wrong! BlockSize not consistent with ThreadClusterLengths");
#endif

        const auto thread_cluster_id =
            thread_cluster_desc.CalculateClusterIndex(get_thread_local_1d_id());
@@ -83,6 +85,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
        constexpr bool has_optimized_address_calculation =
            decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();

        constexpr auto thread_cluster_desc =
            make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

        if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
        {
            // TODO: threadwise copy is still being tweaked
            if(has_optimized_address_calculation)
            {
@@ -92,6 +99,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
            {
                mThreadwiseLoad.Run(p_block_src, p_thread_buffer);
            }
        }
    }

    template <typename ThreadBufferData, typename BlockDstData>
@@ -101,6 +109,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
        constexpr bool has_optimized_address_calculation =
            decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();

        constexpr auto thread_cluster_desc =
            make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

        if(get_thread_local_1d_id() < thread_cluster_desc.GetElementSize())
        {
            // TODO: threadwise copy is still being tweaked
            if(has_optimized_address_calculation)
            {
@@ -110,6 +123,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
            {
                mThreadwiseStore.Run(p_thread_buffer, p_block_dst);
            }
        }
    }

    template <typename BlockSrcData, typename BlockDstData>
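The effect of this change is that the blockwise copy no longer requires every thread in the block to belong to the copy's thread cluster: the disabled static_assert used to demand BlockSize == cluster size, while the new guard simply idles the surplus threads. A schematic host-side illustration of the guard, using numbers from the 32x128x3 configuration added to the v4r4 driver in this same commit (BlockSize = 64, A-copy cluster Sequence<1, 32>), is:

#include <cstdio>

// Schematic only: mirrors the new "if(get_thread_local_1d_id() < GetElementSize())" guard.
constexpr int BlockSize          = 64;
constexpr int ClusterElementSize = 1 * 32; // product of ThreadClusterLengths

void run_copy_for_thread(int tid)
{
    // Previously a static_assert required BlockSize == ClusterElementSize.
    // With the guard, threads outside the cluster skip the copy instead.
    if(tid < ClusterElementSize)
        std::printf("thread %d participates in the copy\n", tid);
    else
        std::printf("thread %d is idle for this copy\n", tid);
}

int main()
{
    for(int tid = 0; tid < BlockSize; ++tid)
        run_copy_for_thread(tid);
    return 0;
}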
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops.hpp (new file, mode 100644; diff collapsed, 656 additions not shown)

composable_kernel/include/tensor_operation/xdlops_gemm.hpp (new file, mode 100644; diff collapsed, 955 additions not shown)
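The two collapsed files are not shown on this page, but the interface that blockwise_gemm_xdlops.hpp relies on can be read off from the calls above. The sketch below is an inferred, simplified stand-in for XdlopsGemm_t, not the actual contents of xdlops_gemm.hpp; it only lists the members the blockwise GEMM uses, with hypothetical signatures.

// Inferred interface sketch only; the real xdlops_gemm.hpp in this commit is not reproduced here.
namespace ck_sketch {

struct MatrixIndex { int row; int col; };

template <class Float, int GemmMPerWave, int GemmNPerWave, int GemmDataPerReadA, int GemmDataPerReadB>
struct XdlopsGemm
{
    // Layout of the per-thread slice of the C tile (queried by GetOutputLayout() above).
    struct OutputLayout { };
    constexpr OutputLayout GetOutputLayout() const { return {}; }

    // Wave-level GEMM: A is K x M (transposed), B is K x N; accumulates into per-thread registers.
    template <int M, int N, int K>
    void Run(const Float* p_a_wave, const Float* p_b_wave, Float* p_c_thread) const;

    // Where a thread's i-th output block starts inside the wave's MPerWave x NPerWave tile.
    MatrixIndex GetBeginOfThreadBlk(int i) const;

    // Accumulator-register management used by XdlopsMatrixCSetZero / XdlopsMatrixCRead.
    template <class RegSizeConstant>
    void SetZeroXdlopsRegs(RegSizeConstant) const;

    template <class RegSizeConstant, class FloatC>
    void ReadXdlopsRegs(RegSizeConstant, FloatC* p_c_thread) const;
};

} // namespace ck_sketch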
composable_kernel/include/utility/amd_xdlops_emulate.hpp (new file, mode 100644)

#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP

namespace ck {

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x1f32(const float&, const float&, float32_t*);

template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<64, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 32; i++)
    {
        reg_c_[i] += reg_a * reg_b;
        reg_c_[i + 32] = reg_c[i];
    }
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<32, 64>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 16; i++)
    {
        reg_c_[i] += reg_a * reg_b;
        reg_c_[i + 16] = reg_c[i];
    }
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x1f32<64, 32>(const float& reg_a, const float& reg_b, float32_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 16; i++)
    {
        reg_c_[i] += reg_a * reg_b;
        reg_c_[i + 16] = reg_c[i];
    }
}

__device__ void
gcnasm_mfma_f32_32x32x2f32(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
    auto reg_c_ = reinterpret_cast<float_t*>(reg_c);
    for(index_t i = 0; i < 16; i++)
    {
        reg_c_[i] += reg_a * reg_b;
    }
}

__device__ void
gcnasm_mfma_f32_16x16x4f32(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_16x16x1f32(const float&, const float&, float16_t*);

template <>
__device__ void
gcnasm_mfma_f32_16x16x1f32<16, 64>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_16x16x1f32<64, 16>(const float& reg_a, const float& reg_b, float16_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_4x4x1f32(const float& reg_a, const float& reg_b, float4_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_4x4x1f32<4, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_4x4x1f32<8, 64>(const float& reg_a, const float& reg_b, float4_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x4f16(const half4_t&, const half4_t&, float32_t*);

template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<64, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<32, 64>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x4f16<64, 32>(const half4_t& reg_a, const half4_t& reg_b, float32_t* reg_c)
{
}

__device__ void
gcnasm_mfma_f32_32x32x8f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}

__device__ void
gcnasm_mfma_f32_16x16x16f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_16x16x4f16(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_16x16x4f16<16, 64>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_16x16x4f16<64, 16>(const half4_t& reg_a, const half4_t& reg_b, float16_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_4x4x4f16(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_4x4x4f16<4, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_4x4x4f16<8, 64>(const half4_t& reg_a, const half4_t& reg_b, float4_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void gcnasm_mfma_f32_32x32x2bf16(const ushort2_t&, const ushort2_t&, float32_t*);

template <>
__device__ void
gcnasm_mfma_f32_32x32x2bf16<64, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x2bf16<32, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_32x32x2bf16<64, 32>(const ushort2_t& reg_a, const ushort2_t& reg_b, float32_t* reg_c)
{
}

__device__ void
gcnasm_mfma_f32_32x32x4bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}

__device__ void
gcnasm_mfma_f32_16x16x8bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_16x16x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t& reg_a, const ushort2_t& reg_b, float16_t* reg_c)
{
}

template <index_t MPerWave, index_t NPerWave>
__device__ void
gcnasm_mfma_f32_4x4x2bf16(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c);

template <>
__device__ void
gcnasm_mfma_f32_4x4x2bf16<4, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}

template <>
__device__ void
gcnasm_mfma_f32_4x4x2bf16<8, 64>(const ushort2_t& reg_a, const ushort2_t& reg_b, float4_t* reg_c)
{
}
// clang-format on
} // namespace ck
#endif
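Most of these stubs exist only so the grid-level code compiles and steps through on a V100; apart from the 32x32x1f32 and 32x32x2f32 bodies, they do not reproduce the MFMA arithmetic. For reference, a plain host-side model of what a single k-step of a wave-level 32x32x1 multiply-accumulate computes, namely a rank-1 update of a 32x32 accumulator tile from a column of A and a row of B, is sketched below. This is a conceptual model of the emulation target, not code from this commit, and it works on a whole tile rather than on the per-lane register slices the device code sees.

#include <array>
#include <cstdio>

// Conceptual model: one k-step of a 32x32 outer-product accumulation,
// i.e. C[m][n] += A[m] * B[n]. The device-side emulation above instead holds C
// spread across the 64 lanes of a wave.
void mfma_32x32x1_reference(const std::array<float, 32>& a_col,
                            const std::array<float, 32>& b_row,
                            std::array<std::array<float, 32>, 32>& c_tile)
{
    for(int m = 0; m < 32; ++m)
        for(int n = 0; n < 32; ++n)
            c_tile[m][n] += a_col[m] * b_row[n];
}

int main()
{
    std::array<float, 32> a{}, b{};
    std::array<std::array<float, 32>, 32> c{};
    for(int i = 0; i < 32; ++i) { a[i] = 1.0f; b[i] = 2.0f; }

    mfma_32x32x1_reference(a, b, c);
    std::printf("c[0][0] = %f (expected 2.0)\n", c[0][0]);
    return 0;
}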
composable_kernel/include/utility/common_header.hpp

@@ -27,6 +27,9 @@
#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
#else
#include "amd_xdlops_emulate.hpp"
#endif

#endif
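This is the switch that makes the rest of the commit usable on a V100: when CK_USE_AMD_XDLOPS is not enabled, the emulation header above is pulled in instead of the real inline-assembly wrappers. A minimal sketch of how a translation unit sees the two paths follows, under the assumption that CK_USE_AMD_XDLOPS is supplied as a compile-time definition by the build system (how it is defined is not shown in this diff); the includes are commented out so the sketch stands alone.

// Sketch of the dispatch added to common_header.hpp; the macro value would normally
// come from the build configuration, e.g. something like -DCK_USE_AMD_XDLOPS=0 on
// non-AMD targets (an assumption, not shown in this commit).
#ifndef CK_USE_AMD_XDLOPS
#define CK_USE_AMD_XDLOPS 0 // assumption for this sketch: take the emulation path
#endif

#if CK_USE_AMD_XDLOPS
// #include "amd_xdlops.hpp"         // real MFMA wrappers (AMD hardware)
#else
// #include "amd_xdlops_emulate.hpp" // stubs/emulation added by this commit (V100)
#endif

#include <cstdio>

int main()
{
    std::printf("xdlops path: %s\n", CK_USE_AMD_XDLOPS ? "hardware MFMA" : "emulated");
    return 0;
}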
composable_kernel/include/utility/float_type.nvidia.hpp.in

@@ -13,6 +13,19 @@ namespace ck {
using float2_t = float2;
using float4_t = float4;

// float
typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32)));

// float16
typedef float half4_t __attribute__((ext_vector_type(2)));
typedef float half8_t __attribute__((ext_vector_type(4)));

// bfloat16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));

// float16
using half2_t = half2;
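These typedefs give the emulation path register-sized vector types on the NVIDIA build. Note that half4_t and half8_t are declared here with float elements and half the lane count, so on this path they are placeholders rather than true fp16 vectors. Assuming Clang-style ext_vector_type semantics, the sizes work out as in this small check; the check itself is illustrative and not part of the commit.

#include <cstdio>

// Illustrative re-declarations of the typedefs added above (ushort spelled out).
typedef unsigned short ushort;
typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32)));
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));

int main()
{
    // A 32-wide float accumulator vector occupies 32 * 4 = 128 bytes.
    static_assert(sizeof(float32_t) == 32 * sizeof(float), "32-wide float vector");
    static_assert(sizeof(float16_t) == 16 * sizeof(float), "16-wide float vector");
    static_assert(sizeof(ushort2_t) == 2 * sizeof(ushort), "2-wide ushort vector");
    std::printf("float32_t: %zu bytes, float16_t: %zu bytes\n",
                sizeof(float32_t), sizeof(float16_t));
    return 0;
}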
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp

@@ -522,7 +522,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-#elif 1
+#elif 0
    // cdata = 64, BlockSize = 32, 32x64x3
    constexpr index_t BlockSize = 32;
@@ -559,6 +559,45 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
    using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 1;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 1
    // cdata = 64, BlockSize = 64, 32x128x3
    constexpr index_t BlockSize = 64;

    constexpr index_t KPerBlock = 32;
    constexpr index_t BPerBlock = 16;
    constexpr index_t EPerBlock = 3;

    constexpr index_t GemmNRepeat = 2;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmMLevel0Cluster = 4;
    constexpr index_t GemmNLevel0Cluster = 4;
    constexpr index_t GemmMLevel1Cluster = 1;
    constexpr index_t GemmNLevel1Cluster = 4;
    constexpr index_t GemmDataPerReadA   = 4;
    constexpr index_t GemmDataPerReadB   = 4;

    using InBlockCopySubLengths_E_N1_B_N2      = Sequence<3, 1, 1, 2>;
    using InBlockCopyClusterLengths_E_N1_B_N2  = Sequence<1, 2, 16, 2>;
    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
    using InBlockCopySrcAccessOrder            = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
    using InBlockCopyDstAccessOrder            = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]

    constexpr index_t InBlockCopySrcDataPerRead_B   = 1;
    constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;

    using WeiBlockCopySubLengths_E_K            = Sequence<3, 1>;
    using WeiBlockCopyClusterLengths_E_K        = Sequence<1, 32>;
    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]

    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 1;
    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
#elif 0
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp

@@ -758,7 +758,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
-#elif 1
+#elif 0
    // cdata = 64, BlockSize = 32, 32x64x3
    constexpr index_t BlockSize = 32;
@@ -790,6 +790,39 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
    // cdata = 64, BlockSize = 64, 32x128x3
    constexpr index_t BlockSize = 64;

    constexpr index_t GemmMPerBlock = 32;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 3;

    constexpr index_t GemmMPerThreadSubC = 4;
    constexpr index_t GemmNPerThreadSubC = 4;
    constexpr index_t GemmKPerThreadLoop = 1;
    constexpr index_t GemmMLevel0Cluster = 2;
    constexpr index_t GemmNLevel0Cluster = 2;
    constexpr index_t GemmMLevel1Cluster = 2;
    constexpr index_t GemmNLevel1Cluster = 8;

    constexpr index_t ThreadGemmDataPerReadM = 4;
    constexpr index_t ThreadGemmDataPerReadN = 4;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<3, 1>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<1, 32>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK  = 1;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<3, 2>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;

    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
#elif 1
    // cdata = 64, BlockSize = 64, 64x64x3
driver/include/device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp (new file, mode 100644)

#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"

template <class T,
          class InDesc,
          class WeiDesc,
          class OutDesc,
          class ConvStrides,
          class ConvDilations,
          class InLeftPads,
          class InRightPads>
void device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
                                                                  const Tensor<T>& in_nchw,
                                                                  WeiDesc,
                                                                  const Tensor<T>& wei_kcyx,
                                                                  OutDesc,
                                                                  Tensor<T>& out_nkhw,
                                                                  ConvStrides,
                                                                  ConvDilations,
                                                                  InLeftPads,
                                                                  InRightPads,
                                                                  ck::index_t nrepeat)
{
    using namespace ck;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc =
        make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
    constexpr auto wei_kcyx_desc =
        make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
    constexpr auto out_nkhw_desc =
        make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t K  = out_nkhw_desc.GetLength(I1);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    std::size_t data_sz = sizeof(T);
    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

    // cdata = 64, BlockSize = 256, 128x128x16
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 16;
    constexpr index_t GemmMPerWave  = 64;
    constexpr index_t GemmNPerWave  = 64;

    constexpr index_t ThreadGemmDataPerReadM = 1;
    constexpr index_t ThreadGemmDataPerReadN = 1;

    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<4, 2>;
    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK  = 4;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;

    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<4, 2>;
    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;

    constexpr index_t GemmM = K;
    constexpr index_t GemmN = N * Ho * Wo;

    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

    printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

    constexpr auto gridwise_conv =
        GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw<
            GridSize,
            BlockSize,
            T,
            T,
            decltype(in_nchw_desc),
            decltype(wei_kcyx_desc),
            decltype(out_nkhw_desc),
            ConvStrides,
            ConvDilations,
            InLeftPads,
            InRightPads,
            GemmMPerBlock,
            GemmNPerBlock,
            GemmKPerBlock,
            GemmMPerWave,
            GemmNPerWave,
            ThreadGemmDataPerReadM,
            ThreadGemmDataPerReadN,
            GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
            GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
            GemmABlockCopySrcDataPerRead_GemmK,
            GemmABlockCopyDstDataPerWrite_GemmM,
            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
            GemmBBlockCopySrcDataPerRead_GemmN,
            GemmBBlockCopyDstDataPerWrite_GemmN>{};

    for(index_t i = 0; i < 10; ++i)
    {
        float time =
            launch_and_time_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                                   dim3(GridSize),
                                   dim3(BlockSize),
                                   0,
                                   0,
                                   static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                                   static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));

        printf("Elapsed time : %f ms, %f TFlop/s \n",
               time,
               (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                   (std::size_t(1000) * 1000 * 1000) / time);
    }

    // warm up
    printf("Warm up running %d times...\n", nrepeat);
    for(index_t i = 0; i < nrepeat; ++i)
    {
        launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }

    printf("Start running %d times...\n", nrepeat);

    cudaDeviceSynchronize();
    auto start = std::chrono::steady_clock::now();

    for(index_t i = 0; i < nrepeat; ++i)
    {
        launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
                      dim3(GridSize),
                      dim3(BlockSize),
                      0,
                      0,
                      static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                      static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
    }

    cudaDeviceSynchronize();
    auto end = std::chrono::steady_clock::now();

    float ave_time = std::chrono::duration<float, std::milli>(end - start).count() / nrepeat;

    printf("Average elapsed time : %f ms, %f TFlop/s \n",
           ave_time,
           (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
               (std::size_t(1000) * 1000 * 1000) / ave_time);

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
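The grid size and the TFlop/s prints above are simple arithmetic, and a worked example makes the units concrete. The sketch below uses a hypothetical 1x1, 14x14 problem (N=128, C=1024, K=256), similar to one of the disabled cases in driver/src/conv_driver.cpp, and assumes calculate_convolution_flops counts 2 * N * K * C * Y * X * Ho * Wo operations (a common convention; the function body is not part of this diff).

#include <cstdio>

int main()
{
    // Hypothetical problem: N=128, C=1024, 14x14 input, K=256, 1x1 filter, stride 1 -> Ho=Wo=14.
    constexpr long long N = 128, C = 1024, K = 256, Ho = 14, Wo = 14, Y = 1, X = 1;

    // Implicit-GEMM sizes and the grid-size formula used above:
    // integer_divide_ceil(GemmM, 128) * integer_divide_ceil(GemmN, 128).
    constexpr long long GemmM = K;                                             // 256
    constexpr long long GemmN = N * Ho * Wo;                                   // 25088
    constexpr long long GemmK = C * Y * X;                                     // 1024
    constexpr long long GridSize = ((GemmM + 127) / 128) * ((GemmN + 127) / 128); // 2 * 196 = 392

    // Divisibility required by the kernel's static_assert: 128 | GemmM, 128 | GemmN, 16 | GemmK.
    static_assert(GemmM % 128 == 0 && GemmN % 128 == 0 && GemmK % 16 == 0, "tile divisibility");

    // Assumed FLOP convention: each multiply-add counted as 2 operations.
    constexpr double flops = 2.0 * N * K * C * Y * X * Ho * Wo; // ~1.32e10

    std::printf("GridSize = %lld blocks of 256 threads\n", GridSize);
    // The driver prints flops / 1e9 / time_ms, which is TFlop/s when time is in milliseconds.
    std::printf("at 1 ms per launch: %.2f TFlop/s\n", flops / 1e9 / 1.0);
    return 0;
}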
driver/src/conv_driver.cpp

@@ -20,321 +20,18 @@
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"

int main(int argc, char* argv[])
{
    using namespace ck;

#if 0
    // 1x1, 17x17
    constexpr index_t N = 128;
    constexpr index_t C = 1024;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K = 256;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 8x8
    constexpr index_t N = 128;
    constexpr index_t C = 1536;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K = 256;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 73x73
    constexpr index_t N = 128;
    constexpr index_t C = 160;
    constexpr index_t HI = 73;
    constexpr index_t WI = 73;
    constexpr index_t K = 64;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 35x35
    constexpr index_t N = 128;
    constexpr index_t C = 96;
    constexpr index_t HI = 35;
    constexpr index_t WI = 35;
    constexpr index_t K = 96;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 0
    // 3x3, 71x71
    constexpr index_t N = 128;
    constexpr index_t C = 192;
    constexpr index_t HI = 71;
    constexpr index_t WI = 71;
    constexpr index_t K = 192;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 0
    // 7x1, 17x17
    constexpr index_t N = 128;
    constexpr index_t C = 256;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K = 320;
    constexpr index_t Y = 7;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<3, 0>;
    using RightPads = Sequence<3, 0>;
#elif 0
    // 1x7, 17x17
    constexpr index_t N = 128;
    constexpr index_t C = 224;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K = 224;
    constexpr index_t Y = 1;
    constexpr index_t X = 7;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 3>;
    using RightPads = Sequence<0, 3>;
#elif 1
    // 3x3, 299x299 stride=2
    constexpr index_t N = 128;
    constexpr index_t C = 3;
    constexpr index_t HI = 299;
    constexpr index_t WI = 299;
    constexpr index_t K = 32;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 147x147
    // v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
    constexpr index_t N = 128;
    constexpr index_t C = 32;
    constexpr index_t HI = 147;
    constexpr index_t WI = 147;
    constexpr index_t K = 64;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 0
    // 3x3, 149x149
    // v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
    constexpr index_t N = 128;
    constexpr index_t C = 32;
    constexpr index_t HI = 149;
    constexpr index_t WI = 149;
    constexpr index_t K = 32;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 17x17, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 192;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K = 192;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 35x35
    constexpr index_t N = 128;
    constexpr index_t C = 384;
    constexpr index_t HI = 35;
    constexpr index_t WI = 35;
    constexpr index_t K = 96;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 35x35, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 384;
    constexpr index_t HI = 35;
    constexpr index_t WI = 35;
    constexpr index_t K = 384;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x3, 8x8
    constexpr index_t N = 128;
    constexpr index_t C = 384;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K = 448;
    constexpr index_t Y = 1;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 1>;
    using RightPads = Sequence<0, 1>;
#elif 0
    // 3x1, 8x8
    constexpr index_t N = 128;
    constexpr index_t C = 448;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K = 512;
    constexpr index_t Y = 3;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 0>;
    using RightPads = Sequence<1, 0>;
#elif 0
    // 3x1, 8x8
    constexpr index_t N = 128;
    constexpr index_t C = 448;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K = 512;
    constexpr index_t Y = 3;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 0>;
    using RightPads = Sequence<1, 0>;
#elif 1
    // 3x3, 147x147
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t HI = 147;
    constexpr index_t WI = 147;
    constexpr index_t K = 96;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 7x1, 73x73
    // v44@v100 xx.xx%, cudnn@v100 xx.xx%
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t HI = 73;
    constexpr index_t WI = 73;
    constexpr index_t K = 64;
    constexpr index_t Y = 7;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<3, 0>;
    using RightPads = Sequence<3, 0>;
#elif 1
    // 3x3, 73x73
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t HI = 73;
    constexpr index_t WI = 73;
    constexpr index_t K = 96;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 1024;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K = 2048;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14
    constexpr index_t N = 128;
    constexpr index_t N = 64;
    constexpr index_t C = 1024;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K = 256;
    constexpr index_t K = 1024;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
@@ -343,172 +40,6 @@ int main(int argc, char* argv[])
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 14x14, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 1024;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K = 512;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 28x28
    constexpr index_t N = 128;
    constexpr index_t C = 128;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K = 128;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 0
    // 3x3, 14x14
    constexpr index_t N = 128;
    constexpr index_t C = 256;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K = 256;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 1
    // 1x1, 56x56, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 256;
    constexpr index_t HI = 56;
    constexpr index_t WI = 56;
    constexpr index_t K = 128;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 7x7, 230x230 stride=2
    constexpr index_t N = 128;
    constexpr index_t C = 3;
    constexpr index_t HI = 230;
    constexpr index_t WI = 230;
    constexpr index_t K = 64;
    constexpr index_t Y = 7;
    constexpr index_t X = 7;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 28x28, stride = 2
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K = 1024;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 28x28, stride 2
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K = 256;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1, 7x7
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K = 2048;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3, 7x7
    constexpr index_t N = 128;
    constexpr index_t C = 512;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K = 512;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#elif 1
    // 1x1, 56x56
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t HI = 56;
    constexpr index_t WI = 56;
    constexpr index_t K = 64;
    constexpr index_t Y = 1;
    constexpr index_t X = 1;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 1
    // 3x3, 56x56
    constexpr index_t N = 128;
    constexpr index_t C = 64;
    constexpr index_t HI = 56;
    constexpr index_t WI = 56;
    constexpr index_t K = 64;
    constexpr index_t Y = 3;
    constexpr index_t X = 3;
    using ConvStrides = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;
    using LeftPads = Sequence<1, 1>;
    using RightPads = Sequence<1, 1>;
#endif

    auto in_nchw_desc  = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
    auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
@@ -603,7 +134,7 @@ int main(int argc, char* argv[])
                                                         RightPads{},
                                                         nrepeat);
#elif 1
-    device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
+    device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,
                                                          wei_kcyx,