Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
82a15a27
Commit
82a15a27
authored
Apr 16, 2020
by
Jing Zhang
Browse files
add xdlops emulation on v100
parent
e69b1970
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
2359 additions
and
475 deletions
+2359
-475
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
..._convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+164
-0
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
...kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+109
-0
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
.../tensor_operation/blockwise_generic_tensor_slice_copy.hpp
+14
-0
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops.hpp
..._kernel/include/tensor_operation/gridwise_gemm_xdlops.hpp
+656
-0
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
+955
-0
composable_kernel/include/utility/amd_xdlops_emulate.hpp
composable_kernel/include/utility/amd_xdlops_emulate.hpp
+189
-0
composable_kernel/include/utility/common_header.hpp
composable_kernel/include/utility/common_header.hpp
+3
-0
composable_kernel/include/utility/float_type.nvidia.hpp.in
composable_kernel/include/utility/float_type.nvidia.hpp.in
+13
-0
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+40
-1
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+34
-1
driver/include/device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
..._convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
+178
-0
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+4
-473
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
0 → 100644
View file @
82a15a27
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_XDLOPS_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops.hpp"
namespace
ck
{
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template
<
index_t
GridSize
,
index_t
BlockSize
,
class
Float
,
class
AccDataType
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
OutGlobalDesc
,
class
ConvStrides
,
class
ConvDilations
,
class
LeftPads
,
class
RightPads
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerWave
,
index_t
GemmNPerWave
,
index_t
GemmThreadGemmDataPerReadM
,
index_t
GemmThreadGemmDataPerReadN
,
class
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
class
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockCopySrcDataPerRead_GemmK
,
index_t
GemmABlockCopyDstDataPerWrite_GemmM
,
class
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
class
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockCopySrcDataPerRead_GemmN
,
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
>
struct
GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_n_c_hi_wi_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_ho_wo_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
);
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
);
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
);
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
constexpr
index_t
GemmM
=
K
;
constexpr
index_t
GemmK
=
C
*
Y
*
X
;
constexpr
index_t
GemmN
=
N
*
Ho
*
Wo
;
static_assert
(
GemmM
%
GemmMPerBlock
==
0
&&
GemmN
%
GemmNPerBlock
==
0
&&
GemmK
%
GemmKPerBlock
==
0
,
"wrong! cannot divide work evenly among block"
);
// sanity-check for vectorized memory load
static_assert
((
Wo
==
1
||
(
ConvStrideW
==
1
||
GemmBBlockCopySrcDataPerRead_GemmN
==
1
))
&&
(
X
==
1
||
ConvDilationW
%
GemmBBlockCopySrcDataPerRead_GemmN
==
0
),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated"
);
// input tensor
// global mem
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
LeftPads
,
RightPads
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
3
];
constexpr
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Hip
,
Sequence
<
Y
,
Ho
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Wip
,
Sequence
<
X
,
Wo
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
in_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
Y
,
X
>>
{},
Merge
<
Sequence
<
N
,
Ho
,
Wo
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
constexpr
auto
wei_gemmk_gemmm_global_desc
=
reorder_tensor_descriptor_given_upper2lower
(
unfold_tensor_descriptor
(
wei_k_c_y_x_global_desc
,
I1
,
I3
),
Sequence
<
1
,
0
>
{});
constexpr
auto
out_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
Merge
<
Sequence
<
N
,
Ho
,
Wo
>>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// GEMM
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalCXdlops_v1
<
GridSize
,
BlockSize
,
Float
,
AccDataType
,
decltype
(
wei_gemmk_gemmm_global_desc
),
decltype
(
in_gemmk_gemmn_global_desc
),
decltype
(
out_gemmm_gemmn_global_desc
),
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerWave
,
GemmNPerWave
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
Sequence
<
0
,
1
>
,
0
,
GemmABlockCopySrcDataPerRead_GemmK
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
InMemoryDataOperation
::
Set
>
{};
gridwise_gemm
.
Run
(
p_wei_global
,
p_in_global
,
p_out_global
);
}
};
}
// namespace ck
#endif
composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
0 → 100644
View file @
82a15a27
#ifndef CK_BLOCKWISE_GEMM_XDLOPS_HPP
#define CK_BLOCKWISE_GEMM_XDLOPS_HPP
#include "common_header.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "xdlops_gemm.hpp"
#include "threadwise_gemm.hpp"
namespace
ck
{
template
<
index_t
BlockSize
,
class
BlockMatrixA
,
class
BlockMatrixB
,
class
Float
,
index_t
GemmMPerWave
,
index_t
GemmNPerWave
,
index_t
GemmMWaves
,
index_t
GemmNWaves
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
>
struct
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
{
struct
MatrixIndex
{
index_t
row
;
index_t
col
;
};
//static constexpr XdlopsGemm_t XdlopsGemm = XdlopsGemm_t<Float, GemmMPerWave, GemmNPerWave, GemmDataPerReadA, GemmDataPerReadB>{};
index_t
mMyWaveOffsetA
;
index_t
mMyWaveOffsetB
;
static
constexpr
index_t
WaveSize
=
64
;
__device__
constexpr
auto
GetOutputLayout
()
const
{
return
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}.
GetOutputLayout
();
}
__device__
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops
()
{
static_assert
(
BlockMatrixA
::
NRow
()
==
BlockMatrixB
::
NRow
(),
"wrong! K dimension not consistent
\n
"
);
constexpr
index_t
M
=
BlockMatrixA
::
NCol
();
// A is transposed
constexpr
index_t
N
=
BlockMatrixB
::
NCol
();
static_assert
(
GemmMPerWave
*
GemmMWaves
==
M
,
"GemmMWaves * GemmMPerWave != M"
);
static_assert
(
GemmNPerWave
*
GemmNWaves
==
N
,
"GemmNWaves * GemmNPerWave != N"
);
static_assert
(
BlockSize
==
GemmMWaves
*
GemmNWaves
*
WaveSize
,
"BlockSize != GemmMWaves * GemmNWaves * WaveSize
\n
"
);
const
index_t
waveId
=
get_thread_local_1d_id
()
/
WaveSize
;
const
index_t
waveId_m
=
waveId
/
GemmNWaves
;
const
index_t
waveId_n
=
waveId
%
GemmNWaves
;
mMyWaveOffsetA
=
waveId_m
*
GemmMPerWave
;
mMyWaveOffsetB
=
waveId_n
*
GemmNPerWave
;
}
template
<
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
Run
(
const
FloatA
*
__restrict__
p_a_block
,
const
FloatB
*
__restrict__
p_b_block
,
FloatC
*
__restrict__
p_c_thread
)
const
{
constexpr
index_t
M
=
BlockMatrixA
::
NCol
();
// A is transposed
constexpr
index_t
N
=
BlockMatrixB
::
NCol
();
constexpr
index_t
K
=
BlockMatrixA
::
NRow
();
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}.
template
Run
<
M
,
N
,
K
>(
&
p_a_block
[
mMyWaveOffsetA
],
&
p_b_block
[
mMyWaveOffsetB
],
p_c_thread
);
}
__device__
static
MatrixIndex
GetBeginOfThreadMatrixC
(
index_t
i
)
{
const
index_t
waveId
=
get_thread_local_1d_id
()
/
WaveSize
;
const
auto
thread_mtx_on_blk
=
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}.
GetBeginOfThreadBlk
(
i
);
const
index_t
col
=
waveId
%
GemmNWaves
*
GemmNPerWave
+
thread_mtx_on_blk
.
col
;
const
index_t
row
=
waveId
/
GemmNWaves
*
GemmMPerWave
+
thread_mtx_on_blk
.
row
;
return
MatrixIndex
{
row
,
col
};
}
__device__
constexpr
auto
GetThreadMatrixCDescriptor
()
const
{
const
index_t
reg_size
=
GemmMPerWave
*
GemmNPerWave
/
WaveSize
;
return
make_ConstantMatrixDescriptor_packed
(
Number
<
reg_size
>
{},
Number
<
1
>
{});
}
__device__
void
XdlopsMatrixCSetZero
()
const
{
constexpr
auto
thread_mtx_size
=
GemmMPerWave
*
GemmNPerWave
/
WaveSize
;
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}.
SetZeroXdlopsRegs
(
Number
<
thread_mtx_size
>
{});
}
template
<
class
FloatC
>
__device__
void
XdlopsMatrixCRead
(
FloatC
*
__restrict__
p_c_thread
)
const
{
constexpr
auto
thread_mtx_size
=
GemmMPerWave
*
GemmNPerWave
/
WaveSize
;
XdlopsGemm_t
<
Float
,
GemmMPerWave
,
GemmNPerWave
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{}.
ReadXdlopsRegs
(
Number
<
thread_mtx_size
>
{},
p_c_thread
);
}
};
}
// namespace ck
#endif
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
View file @
82a15a27
...
@@ -56,8 +56,10 @@ struct BlockwiseGenericTensorSliceCopy_v4
...
@@ -56,8 +56,10 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr
auto
thread_cluster_desc
=
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
#if 0
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
static_assert(BlockSize == thread_cluster_desc.GetElementSize(),
"wrong! BlockSize not consistent with ThreadClusterLengths");
"wrong! BlockSize not consistent with ThreadClusterLengths");
#endif
const
auto
thread_cluster_id
=
const
auto
thread_cluster_id
=
thread_cluster_desc
.
CalculateClusterIndex
(
get_thread_local_1d_id
());
thread_cluster_desc
.
CalculateClusterIndex
(
get_thread_local_1d_id
());
...
@@ -83,6 +85,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
...
@@ -83,6 +85,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr
bool
has_optimized_address_calculation
=
constexpr
bool
has_optimized_address_calculation
=
decltype
(
mThreadwiseStore
)
::
HasWorkingOptimizedAddressCalculation
();
decltype
(
mThreadwiseStore
)
::
HasWorkingOptimizedAddressCalculation
();
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
if
(
get_thread_local_1d_id
()
<
thread_cluster_desc
.
GetElementSize
())
{
// TODO: threadwise copy is still being tweaked
// TODO: threadwise copy is still being tweaked
if
(
has_optimized_address_calculation
)
if
(
has_optimized_address_calculation
)
{
{
...
@@ -92,6 +99,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
...
@@ -92,6 +99,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
{
{
mThreadwiseLoad
.
Run
(
p_block_src
,
p_thread_buffer
);
mThreadwiseLoad
.
Run
(
p_block_src
,
p_thread_buffer
);
}
}
}
}
}
template
<
typename
ThreadBufferData
,
typename
BlockDstData
>
template
<
typename
ThreadBufferData
,
typename
BlockDstData
>
...
@@ -101,6 +109,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
...
@@ -101,6 +109,11 @@ struct BlockwiseGenericTensorSliceCopy_v4
constexpr
bool
has_optimized_address_calculation
=
constexpr
bool
has_optimized_address_calculation
=
decltype
(
mThreadwiseStore
)
::
HasWorkingOptimizedAddressCalculation
();
decltype
(
mThreadwiseStore
)
::
HasWorkingOptimizedAddressCalculation
();
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
if
(
get_thread_local_1d_id
()
<
thread_cluster_desc
.
GetElementSize
())
{
// TODO: threadwise copy is still being tweaked
// TODO: threadwise copy is still being tweaked
if
(
has_optimized_address_calculation
)
if
(
has_optimized_address_calculation
)
{
{
...
@@ -110,6 +123,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
...
@@ -110,6 +123,7 @@ struct BlockwiseGenericTensorSliceCopy_v4
{
{
mThreadwiseStore
.
Run
(
p_thread_buffer
,
p_block_dst
);
mThreadwiseStore
.
Run
(
p_thread_buffer
,
p_block_dst
);
}
}
}
}
}
template
<
typename
BlockSrcData
,
typename
BlockDstData
>
template
<
typename
BlockSrcData
,
typename
BlockDstData
>
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops.hpp
0 → 100644
View file @
82a15a27
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/xdlops_gemm.hpp
0 → 100644
View file @
82a15a27
This diff is collapsed.
Click to expand it.
composable_kernel/include/utility/amd_xdlops_emulate.hpp
0 → 100644
View file @
82a15a27
#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
namespace
ck
{
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_32x32x1f32
(
const
float
&
,
const
float
&
,
float32_t
*
);
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x1f32
<
64
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float32_t
*
reg_c
)
{
auto
reg_c_
=
reinterpret_cast
<
float_t
*>
(
reg_c
);
for
(
index_t
i
=
0
;
i
<
32
;
i
++
)
{
reg_c_
[
i
]
+=
reg_a
*
reg_b
;
reg_c_
[
i
+
32
]
=
reg_c
[
i
];
}
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x1f32
<
32
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float32_t
*
reg_c
)
{
auto
reg_c_
=
reinterpret_cast
<
float_t
*>
(
reg_c
);
for
(
index_t
i
=
0
;
i
<
16
;
i
++
)
{
reg_c_
[
i
]
+=
reg_a
*
reg_b
;
reg_c_
[
i
+
16
]
=
reg_c
[
i
];
}
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x1f32
<
64
,
32
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float32_t
*
reg_c
)
{
auto
reg_c_
=
reinterpret_cast
<
float_t
*>
(
reg_c
);
for
(
index_t
i
=
0
;
i
<
16
;
i
++
)
{
reg_c_
[
i
]
+=
reg_a
*
reg_b
;
reg_c_
[
i
+
16
]
=
reg_c
[
i
];
}
}
__device__
void
gcnasm_mfma_f32_32x32x2f32
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float16_t
*
reg_c
)
{
auto
reg_c_
=
reinterpret_cast
<
float_t
*>
(
reg_c
);
for
(
index_t
i
=
0
;
i
<
16
;
i
++
)
{
reg_c_
[
i
]
+=
reg_a
*
reg_b
;
}
}
__device__
void
gcnasm_mfma_f32_16x16x4f32
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_16x16x1f32
(
const
float
&
,
const
float
&
,
float16_t
*
);
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x1f32
<
16
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float16_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x1f32
<
64
,
16
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float16_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_4x4x1f32
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float4_t
*
reg_c
);
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x1f32
<
4
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x1f32
<
8
,
64
>
(
const
float
&
reg_a
,
const
float
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_32x32x4f16
(
const
half4_t
&
,
const
half4_t
&
,
float32_t
*
);
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x4f16
<
64
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x4f16
<
32
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x4f16
<
64
,
32
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
__device__
void
gcnasm_mfma_f32_32x32x8f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
__device__
void
gcnasm_mfma_f32_16x16x16f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_16x16x4f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
);
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x4f16
<
16
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x4f16
<
64
,
16
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_4x4x4f16
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
);
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x4f16
<
4
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x4f16
<
8
,
64
>
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_32x32x2bf16
(
const
ushort2_t
&
,
const
ushort2_t
&
,
float32_t
*
);
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x2bf16
<
64
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x2bf16
<
32
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_32x32x2bf16
<
64
,
32
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float32_t
*
reg_c
)
{
}
__device__
void
gcnasm_mfma_f32_32x32x4bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
__device__
void
gcnasm_mfma_f32_16x16x8bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_16x16x2bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
);
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x2bf16
<
16
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_16x16x2bf16
<
64
,
16
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float16_t
*
reg_c
)
{
}
template
<
index_t
MPerWave
,
index_t
NPerWave
>
__device__
void
gcnasm_mfma_f32_4x4x2bf16
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
);
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x2bf16
<
4
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
template
<
>
__device__
void
gcnasm_mfma_f32_4x4x2bf16
<
8
,
64
>
(
const
ushort2_t
&
reg_a
,
const
ushort2_t
&
reg_b
,
float4_t
*
reg_c
)
{
}
// clang-format on
}
#endif
composable_kernel/include/utility/common_header.hpp
View file @
82a15a27
...
@@ -27,6 +27,9 @@
...
@@ -27,6 +27,9 @@
#if CK_USE_AMD_XDLOPS
#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
#include "amd_xdlops.hpp"
#else
#include "amd_xdlops_emulate.hpp"
#endif
#endif
#endif
#endif
composable_kernel/include/utility/float_type.nvidia.hpp.in
View file @
82a15a27
...
@@ -13,6 +13,19 @@ namespace ck {
...
@@ -13,6 +13,19 @@ namespace ck {
using float2_t = float2;
using float2_t = float2;
using float4_t = float4;
using float4_t = float4;
// float
typedef float float16_t __attribute__((ext_vector_type(16)));
typedef float float32_t __attribute__((ext_vector_type(32)));
// float16
typedef float half4_t __attribute__((ext_vector_type(2)));
typedef float half8_t __attribute__((ext_vector_type(4)));
// bfloat16
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
typedef ushort ushort8_t __attribute__((ext_vector_type(8)));
// float16
// float16
using half2_t = half2;
using half2_t = half2;
...
...
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
82a15a27
...
@@ -522,7 +522,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -522,7 +522,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif
1
#elif
0
// cdata = 64, BlockSize = 32, 32x64x3
// cdata = 64, BlockSize = 32, 32x64x3
constexpr
index_t
BlockSize
=
32
;
constexpr
index_t
BlockSize
=
32
;
...
@@ -559,6 +559,45 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -559,6 +559,45 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif 1
// cdata = 64, BlockSize = 64, 32x128x3
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
KPerBlock
=
32
;
constexpr
index_t
BPerBlock
=
16
;
constexpr
index_t
EPerBlock
=
3
;
constexpr
index_t
GemmNRepeat
=
2
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
1
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopySubLengths_E_N1_B_N2
=
Sequence
<
3
,
1
,
1
,
2
>
;
using
InBlockCopyClusterLengths_E_N1_B_N2
=
Sequence
<
1
,
2
,
16
,
2
>
;
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
,
3
,
2
>
;
// [E, N1, N2, B]
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
2
,
1
,
3
>
;
// [E, B, N1, N2]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
,
2
,
3
>
;
// [E, N1, B, N2]
constexpr
index_t
InBlockCopySrcDataPerRead_B
=
1
;
constexpr
index_t
InBlockCopyDstDataPerWrite_N2
=
2
;
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
3
,
1
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
1
,
32
>
;
using
WeiBlockCopyThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
1
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif 0
#elif 0
...
...
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
82a15a27
...
@@ -758,7 +758,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -758,7 +758,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif
1
#elif
0
// cdata = 64, BlockSize = 32, 32x64x3
// cdata = 64, BlockSize = 32, 32x64x3
constexpr
index_t
BlockSize
=
32
;
constexpr
index_t
BlockSize
=
32
;
...
@@ -790,6 +790,39 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -790,6 +790,39 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
2
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
2
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif 1
// cdata = 64, BlockSize = 64, 32x128x3
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
GemmMPerBlock
=
32
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
3
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel1Cluster
=
2
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
3
,
1
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
1
,
32
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmK
=
1
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
3
,
2
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
1
,
64
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
2
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif 1
#elif 1
// cdata = 64, BlockSize = 64, 64x64x3
// cdata = 64, BlockSize = 64, 64x64x3
...
...
driver/include/device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp
0 → 100644
View file @
82a15a27
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
,
class
ConvStrides
,
class
ConvDilations
,
class
InLeftPads
,
class
InRightPads
>
void
device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvDilations
,
InLeftPads
,
InRightPads
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_nchw_desc
=
make_native_tensor_descriptor
(
InDesc
::
GetLengths
(),
InDesc
::
GetStrides
());
constexpr
auto
wei_kcyx_desc
=
make_native_tensor_descriptor
(
WeiDesc
::
GetLengths
(),
WeiDesc
::
GetStrides
());
constexpr
auto
out_nkhw_desc
=
make_native_tensor_descriptor
(
OutDesc
::
GetLengths
(),
OutDesc
::
GetStrides
());
constexpr
index_t
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
out_nkhw_desc
.
GetLength
(
I1
);
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
DeviceMem
out_nkhw_device_buf
(
data_sz
*
out_nkhw
.
mDesc
.
GetElementSpace
());
in_nchw_device_buf
.
ToDevice
(
in_nchw
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
// cdata = 64, BlockSize = 256, 128x128x16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerWave
=
64
;
constexpr
index_t
GemmNPerWave
=
64
;
constexpr
index_t
ThreadGemmDataPerReadM
=
1
;
constexpr
index_t
ThreadGemmDataPerReadN
=
1
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmK
=
4
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
2
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
4
,
64
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmM
=
K
;
constexpr
index_t
GemmN
=
N
*
Ho
*
Wo
;
constexpr
index_t
GridSize
=
math
::
integer_divide_ceil
(
GemmM
,
GemmMPerBlock
)
*
math
::
integer_divide_ceil
(
GemmN
,
GemmNPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
constexpr
auto
gridwise_conv
=
GridwiseConvolutionImplicitGemm_v4r4_xdlops_fwd_fp32_nchw_kcyx_nkhw
<
GridSize
,
BlockSize
,
T
,
T
,
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
out_nkhw_desc
),
ConvStrides
,
ConvDilations
,
InLeftPads
,
InRightPads
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerWave
,
GemmNPerWave
,
ThreadGemmDataPerReadM
,
ThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
GemmABlockCopySrcDataPerRead_GemmK
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
>
{};
for
(
index_t
i
=
0
;
i
<
10
;
++
i
)
{
float
time
=
launch_and_time_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
time
,
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
);
}
// warm up
printf
(
"Warn up running %d times...
\n
"
,
nrepeat
);
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
cudaDeviceSynchronize
();
auto
start
=
std
::
chrono
::
steady_clock
::
now
();
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
cudaDeviceSynchronize
();
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
float
ave_time
=
std
::
chrono
::
duration
<
float
,
std
::
milli
>
(
end
-
start
).
count
()
/
nrepeat
;
printf
(
"Average elapsed time : %f ms, %f TFlop/s
\n
"
,
ave_time
,
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
);
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
}
driver/src/conv_driver.cpp
View file @
82a15a27
...
@@ -20,321 +20,18 @@
...
@@ -20,321 +20,18 @@
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
using
namespace
ck
;
using
namespace
ck
;
#if 0
// 1x1, 17x17
constexpr index_t N = 128;
constexpr index_t C = 1024;
constexpr index_t HI = 17;
constexpr index_t WI = 17;
constexpr index_t K = 256;
constexpr index_t Y = 1;
constexpr index_t X = 1;
using ConvStrides = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif
0
// 1x1, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1536
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 73x73
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
160
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
WI
=
73
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 35x35
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
96
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 3x3, 71x71
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
HI
=
71
;
constexpr
index_t
WI
=
71
;
constexpr
index_t
K
=
192
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 7x1, 17x17
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
320
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif 0
// 1x7, 17x17
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
224
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
224
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
7
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
#elif 1
// 3x3, 299x299 stride=2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
3
;
constexpr
index_t
HI
=
299
;
constexpr
index_t
WI
=
299
;
constexpr
index_t
K
=
32
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 147x147
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
32
;
constexpr
index_t
HI
=
147
;
constexpr
index_t
WI
=
147
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 3x3, 149x149
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
32
;
constexpr
index_t
HI
=
149
;
constexpr
index_t
WI
=
149
;
constexpr
index_t
K
=
32
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 17x17, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
192
;
constexpr
index_t
HI
=
17
;
constexpr
index_t
WI
=
17
;
constexpr
index_t
K
=
192
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 35x35
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
384
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 35x35, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
384
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
384
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x3, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
384
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
448
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
1
>
;
using
RightPads
=
Sequence
<
0
,
1
>
;
#elif 0
// 3x1, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
448
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
0
>
;
using
RightPads
=
Sequence
<
1
,
0
>
;
#elif 0
// 3x1, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
448
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
0
>
;
using
RightPads
=
Sequence
<
1
,
0
>
;
#elif 1
// 3x3, 147x147
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
147
;
constexpr
index_t
WI
=
147
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 7x1, 73x73
// v44@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
WI
=
73
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif 1
// 3x3, 73x73
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
WI
=
73
;
constexpr
index_t
K
=
96
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 14x14, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
2048
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 14x14
// 1x1, 14x14
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
256
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
constexpr
index_t
X
=
1
;
...
@@ -343,172 +40,6 @@ int main(int argc, char* argv[])
...
@@ -343,172 +40,6 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 14x14, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 28x28
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
// 3x3, 14x14
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
14
;
constexpr
index_t
WI
=
14
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 1
// 1x1, 56x56, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 7x7, 230x230 stride=2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
3
;
constexpr
index_t
HI
=
230
;
constexpr
index_t
WI
=
230
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
7
;
constexpr
index_t
X
=
7
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 28x28, stride = 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
1024
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 28x28, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
256
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 1x1, 7x7
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
K
=
2048
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
// 3x3, 7x7
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
512
;
constexpr
index_t
HI
=
7
;
constexpr
index_t
WI
=
7
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 1
// 1x1, 56x56
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
1
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 1
// 3x3, 56x56
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
64
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#endif
auto
in_nchw_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
in_nchw_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
wei_kcyx_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
K
,
C
,
Y
,
X
>
{});
auto
wei_kcyx_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
K
,
C
,
Y
,
X
>
{});
...
@@ -603,7 +134,7 @@ int main(int argc, char* argv[])
...
@@ -603,7 +134,7 @@ int main(int argc, char* argv[])
RightPads
{},
RightPads
{},
nrepeat
);
nrepeat
);
#elif 1
#elif 1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r4_
xdlops_
nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
wei_kcyx
,
wei_kcyx
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment