Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9359ebfd
"...composable_kernel.git" did not exist on "ad09ebdb531285c35f7c45be68db7fd52b5dc082"
Commit
9359ebfd
authored
Dec 04, 2019
by
Chao Liu
Browse files
updated fwd v4r4 to use gridwise gemm
parent
19c3c9a8
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
374 additions
and
626 deletions
+374
-626
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
+3
-0
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+3
-0
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...ridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+167
-0
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp
+0
-408
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
.../tensor_operation/blockwise_generic_tensor_slice_copy.hpp
+6
-0
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
+8
-5
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
...tensor_operation/threadwise_generic_tensor_slice_copy.hpp
+28
-30
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
+2
-4
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+152
-174
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+1
-1
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+4
-4
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp
View file @
9359ebfd
...
...
@@ -144,15 +144,18 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmABlockCopySrcDataPerRead_GemmN
,
GemmABlockCopyDstDataPerWrite_GemmN
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
9359ebfd
...
...
@@ -198,15 +198,18 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmABlockCopySrcDataPerRead_GemmM
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
...
...
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
0 → 100644
View file @
9359ebfd
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"
namespace
ck
{
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template
<
index_t
GridSize
,
index_t
BlockSize
,
typename
Float
,
typename
AccFloat
,
typename
InGlobalDesc
,
typename
WeiGlobalDesc
,
typename
OutGlobalDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
GemmThreadGemmDataPerReadM
,
index_t
GemmThreadGemmDataPerReadN
,
typename
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
typename
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
index_t
GemmABlockCopySrcDataPerRead_GemmK
,
index_t
GemmABlockCopyDstDataPerWrite_GemmM
,
typename
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
typename
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
index_t
GemmBBlockCopySrcDataPerRead_GemmN
,
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
,
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
>
struct
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_n_c_hi_wi_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_ho_wo_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
0
];
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
// sanity-check for vectorized memory load
static_assert
((
Wo
==
1
||
(
ConvStrideW
==
1
||
GemmBBlockCopySrcDataPerRead_GemmN
==
1
))
&&
(
X
==
1
||
ConvDilationW
%
GemmBBlockCopySrcDataPerRead_GemmN
==
0
)
&&
InLeftPads
{}[
1
]
%
GemmBBlockCopySrcDataPerRead_GemmN
==
0
&&
InRightPads
{}[
1
]
%
GemmBBlockCopySrcDataPerRead_GemmN
==
0
,
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated"
);
// weight tensor
constexpr
auto
wei_e_k_global_desc
=
reorder_tensor_descriptor_given_upper2lower
(
unfold_tensor_descriptor
(
wei_k_c_y_x_global_desc
,
I1
,
I3
),
Sequence
<
1
,
0
>
{});
// input tensor
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Sequence
<
Y
,
Ho
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Sequence
<
X
,
Wo
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
in_e_b_global_desc
=
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
Y
,
X
>>
{},
Merge
<
Sequence
<
N
,
Ho
,
Wo
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
constexpr
auto
out_k_b_global_desc
=
transform_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
Merge
<
Sequence
<
N
,
Ho
,
Wo
>>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// GEMM
constexpr
auto
gridwise_gemm
=
GridwiseGemmTransposedANormalBNormalC_v1
<
GridSize
,
BlockSize
,
Float
,
AccFloat
,
decltype
(
wei_e_k_global_desc
),
decltype
(
in_e_b_global_desc
),
decltype
(
out_k_b_global_desc
),
InMemoryDataOperation
::
none
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmThreadGemmDataPerReadM
,
GemmThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
Sequence
<
1
,
0
>
,
Sequence
<
1
,
0
>
,
0
,
GemmABlockCopySrcDataPerRead_GemmK
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
1
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
gridwise_gemm
.
Run
(
p_wei_global
,
p_in_global
,
p_out_global
);
}
};
}
// namespace ck
#endif
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp
deleted
100644 → 0
View file @
19c3c9a8
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
namespace
ck
{
// B = merge(N, Ho, Wo)
template
<
index_t
GridSize
,
index_t
BlockSize
,
typename
Float
,
typename
InGlobalDesc
,
typename
WeiGlobalDesc
,
typename
OutGlobalDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
LeftPads
,
typename
RightPads
,
index_t
BPerBlock
,
index_t
KPerBlock
,
index_t
EPerBlock
,
index_t
GemmMPerThreadSubC
,
index_t
GemmNPerThreadSubC
,
index_t
GemmMLevel0Cluster
,
index_t
GemmNLevel0Cluster
,
index_t
GemmMLevel1Cluster
,
index_t
GemmNLevel1Cluster
,
index_t
GemmKPerThreadLoop
,
index_t
GemmDataPerReadA
,
index_t
GemmDataPerReadB
,
typename
InBlockCopySubLengths_E_B
,
typename
InBlockCopyClusterLengths_E_B
,
typename
InBlockCopyThreadClusterArrangeOrder
,
typename
InBlockCopySrcAccessOrder
,
typename
InBlockCopyDstAccessOrder
,
index_t
InBlockCopyDataPerAccess_B
,
typename
WeiBlockCopySubLengths_E_K
,
typename
WeiBlockCopyClusterLengths_E_K
,
typename
WeiBlockCopyThreadClusterArrangeOrder
,
typename
WeiBlockCopySrcAccessOrder
,
typename
WeiBlockCopyDstAccessOrder
,
index_t
WeiBlockCopySrcDataPerRead_E
,
index_t
WeiBlockCopyDstDataPerWrite_K
,
index_t
OutThreadCopyDataPerAccess_B
>
struct
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
{
__device__
void
Run
(
const
Float
*
const
__restrict__
p_in_global
,
const
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
in_n_c_hi_wi_global_desc
=
make_native_tensor_descriptor
(
InGlobalDesc
::
GetLengths
(),
InGlobalDesc
::
GetStrides
());
constexpr
auto
wei_k_c_y_x_global_desc
=
make_native_tensor_descriptor
(
WeiGlobalDesc
::
GetLengths
(),
WeiGlobalDesc
::
GetStrides
());
constexpr
auto
out_n_k_ho_wo_global_desc
=
make_native_tensor_descriptor
(
OutGlobalDesc
::
GetLengths
(),
OutGlobalDesc
::
GetStrides
());
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
);
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
);
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
);
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
);
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
);
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
);
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
);
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
constexpr
index_t
E
=
C
*
Y
*
X
;
constexpr
index_t
B
=
N
*
Ho
*
Wo
;
// sanity-check for vectorized memory load
static_assert
((
Wo
==
1
||
(
ConvStrideW
==
1
||
InBlockCopyDataPerAccess_B
==
1
))
&&
(
X
==
1
||
ConvDilationW
%
InBlockCopyDataPerAccess_B
==
0
),
"wrong! aligment requirement for vectorized global load of input tensor will "
"be violated"
);
// divide block work by [K, B]
static_assert
(
K
%
KPerBlock
==
0
&&
B
%
BPerBlock
==
0
&&
E
%
EPerBlock
==
0
,
"wrong! cannot divide work evenly among block"
);
constexpr
index_t
KBlockWork
=
K
/
KPerBlock
;
constexpr
index_t
BBlockWork
=
B
/
BPerBlock
;
constexpr
auto
block_work_desc
=
make_cluster_descriptor
(
Sequence
<
KBlockWork
,
BBlockWork
>
{});
const
auto
block_work_id
=
block_work_desc
.
CalculateClusterIndex
(
get_block_1d_id
());
const
index_t
k_block_data_on_global
=
block_work_id
[
0
]
*
KPerBlock
;
const
index_t
b_block_data_on_global
=
block_work_id
[
1
]
*
BPerBlock
;
// input tensor
// global mem
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
LeftPads
,
RightPads
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Sequence
<
Y
,
Ho
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Sequence
<
X
,
Wo
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
in_e_b_global_desc
=
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
Y
,
X
>>
{},
Merge
<
Sequence
<
N
,
Ho
,
Wo
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// LDS mem
// be careful of LDS alignment
constexpr
auto
in_e_b_block_desc
=
make_native_tensor_descriptor_packed
(
Sequence
<
EPerBlock
,
BPerBlock
>
{});
// input blockwise copy
auto
blockwise_in_copy
=
BlockwiseGenericTensorSliceCopy_v4
<
BlockSize
,
decltype
(
in_e_b_global_desc
),
decltype
(
in_e_b_block_desc
),
decltype
(
in_e_b_block_desc
.
GetLengths
()),
InBlockCopySubLengths_E_B
,
InBlockCopyClusterLengths_E_B
,
InBlockCopyThreadClusterArrangeOrder
,
InBlockCopySrcAccessOrder
,
InBlockCopyDstAccessOrder
,
1
,
1
,
InBlockCopyDataPerAccess_B
,
InBlockCopyDataPerAccess_B
,
AddressSpace
::
global
,
AddressSpace
::
vgpr
,
AddressSpace
::
lds
,
InMemoryDataOperation
::
none
>
(
{
0
,
b_block_data_on_global
},
{
0
,
0
});
// weight tensor
// global mem
constexpr
auto
wei_e_k_global_desc
=
reorder_tensor_descriptor_given_upper2lower
(
unfold_tensor_descriptor
(
wei_k_c_y_x_global_desc
,
I1
,
I3
),
Sequence
<
1
,
0
>
{});
// LDS
// be careful of LDS alignment
constexpr
auto
wei_e_k_block_desc
=
make_native_tensor_descriptor_aligned
(
Sequence
<
EPerBlock
,
KPerBlock
>
{},
Number
<
math
::
lcm
(
WeiBlockCopyDstDataPerWrite_K
,
GemmDataPerReadA
)
>
{});
// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with multiple alignment
// requirements
static_assert
(
wei_e_k_block_desc
.
GetStride
(
I0
)
%
GemmDataPerReadA
==
0
,
"GemmDataPerReadA alignment requirement is not satisfied"
);
// weight blockwise copy
auto
blockwise_wei_copy
=
BlockwiseGenericTensorSliceCopy_v4
<
BlockSize
,
decltype
(
wei_e_k_global_desc
),
decltype
(
wei_e_k_block_desc
),
decltype
(
wei_e_k_block_desc
.
GetLengths
()),
WeiBlockCopySubLengths_E_K
,
WeiBlockCopyClusterLengths_E_K
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopyDstAccessOrder
,
0
,
1
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopyDstDataPerWrite_K
,
AddressSpace
::
global
,
AddressSpace
::
vgpr
,
AddressSpace
::
lds
,
InMemoryDataOperation
::
none
>
(
{
0
,
k_block_data_on_global
},
{
0
,
0
});
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[EPerBlock, KPerBlock] is in LDS
// b_mtx[EPerBlocl, BPerBlock] is in LDS
// c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
// register
constexpr
auto
a_e_k_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
wei_e_k_block_desc
);
constexpr
auto
b_e_b_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
in_e_b_block_desc
);
// sanity check
static_assert
(
KPerBlock
%
(
GemmMPerThreadSubC
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
)
==
0
&&
BPerBlock
%
(
GemmNPerThreadSubC
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
)
==
0
,
"wrong!"
);
constexpr
index_t
GemmMRepeat
=
KPerBlock
/
(
GemmMPerThreadSubC
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
);
constexpr
index_t
GemmNRepeat
=
BPerBlock
/
(
GemmNPerThreadSubC
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
);
// c_thread_mtx definition: this is a mess
// TODO:: more elegent way of defining c_thread_mtx
constexpr
auto
c_k0k1_b0b1_thread_mtx_desc
=
make_ConstantMatrixDescriptor_packed
(
Number
<
GemmMRepeat
*
GemmMPerThreadSubC
>
{},
Number
<
GemmNRepeat
*
GemmNPerThreadSubC
>
{});
const
auto
blockwise_gemm
=
BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
<
BlockSize
,
decltype
(
a_e_k_block_mtx_desc
),
decltype
(
b_e_b_block_mtx_desc
),
decltype
(
c_k0k1_b0b1_thread_mtx_desc
),
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{};
// LDS allocation for input and weight: be careful of alignment
constexpr
index_t
max_align
=
math
::
lcm
(
InBlockCopyDataPerAccess_B
,
WeiBlockCopyDstDataPerWrite_K
,
GemmDataPerReadA
,
GemmDataPerReadB
);
constexpr
index_t
in_block_space
=
math
::
integer_least_multiple
(
in_e_b_block_desc
.
GetElementSpace
(),
max_align
);
constexpr
index_t
wei_block_space
=
math
::
integer_least_multiple
(
wei_e_k_block_desc
.
GetElementSpace
(),
max_align
);
__shared__
Float
p_in_block_double
[
2
*
in_block_space
];
__shared__
Float
p_wei_block_double
[
2
*
wei_block_space
];
// register allocation for output
Float
p_out_thread
[
c_k0k1_b0b1_thread_mtx_desc
.
GetElementSpace
()];
// zero out threadwise output
threadwise_matrix_set_zero
(
c_k0k1_b0b1_thread_mtx_desc
,
p_out_thread
);
// LDS double buffer: preload data into LDS
{
blockwise_in_copy
.
Run
(
p_in_global
,
p_in_block_double
);
blockwise_wei_copy
.
Run
(
p_wei_global
,
p_wei_block_double
);
}
// LDS double buffer: main body
for
(
index_t
e_block_data_begin
=
0
;
e_block_data_begin
+
2
*
EPerBlock
<
E
;
e_block_data_begin
+=
2
*
EPerBlock
)
{
#pragma unroll
for
(
index_t
iloop
=
0
;
iloop
<
2
;
++
iloop
)
{
const
bool
even_loop
=
(
iloop
%
2
==
0
);
Float
*
p_in_block_now
=
even_loop
?
p_in_block_double
:
p_in_block_double
+
in_block_space
;
Float
*
p_wei_block_now
=
even_loop
?
p_wei_block_double
:
p_wei_block_double
+
wei_block_space
;
Float
*
p_in_block_next
=
even_loop
?
p_in_block_double
+
in_block_space
:
p_in_block_double
;
Float
*
p_wei_block_next
=
even_loop
?
p_wei_block_double
+
wei_block_space
:
p_wei_block_double
;
Float
p_in_thread_buffer
[
blockwise_in_copy
.
GetThreadBufferSize
()];
Float
p_wei_thread_buffer
[
blockwise_wei_copy
.
GetThreadBufferSize
()];
blockwise_in_copy
.
MoveSrcSliceWindow
(
Sequence
<
EPerBlock
,
0
>
{},
True
);
blockwise_wei_copy
.
MoveSrcSliceWindow
(
Sequence
<
EPerBlock
,
0
>
{},
True
);
__syncthreads
();
// LDS doubel buffer: load next data from device mem
blockwise_in_copy
.
RunLoadThreadBuffer
(
p_in_global
,
p_in_thread_buffer
);
blockwise_wei_copy
.
RunLoadThreadBuffer
(
p_wei_global
,
p_wei_thread_buffer
);
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreThreadBuffer
(
p_in_thread_buffer
,
p_in_block_next
);
blockwise_wei_copy
.
RunStoreThreadBuffer
(
p_wei_thread_buffer
,
p_wei_block_next
);
}
}
// LDS double buffer: tail
{
constexpr
bool
has_two_iteration_left
=
(
E
%
(
2
*
EPerBlock
)
==
0
);
if
(
has_two_iteration_left
)
// if has 2 iteration left
{
Float
p_in_thread_buffer
[
blockwise_in_copy
.
GetThreadBufferSize
()];
Float
p_wei_thread_buffer
[
blockwise_wei_copy
.
GetThreadBufferSize
()];
blockwise_in_copy
.
MoveSrcSliceWindow
(
Sequence
<
EPerBlock
,
0
>
{},
True
);
blockwise_wei_copy
.
MoveSrcSliceWindow
(
Sequence
<
EPerBlock
,
0
>
{},
True
);
__syncthreads
();
// LDS double buffer: load last data from device mem
blockwise_in_copy
.
RunLoadThreadBuffer
(
p_in_global
,
p_in_thread_buffer
);
blockwise_wei_copy
.
RunLoadThreadBuffer
(
p_wei_global
,
p_wei_thread_buffer
);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
// LDS double buffer: store last data to LDS
blockwise_in_copy
.
RunStoreThreadBuffer
(
p_in_thread_buffer
,
p_in_block_double
+
in_block_space
);
blockwise_wei_copy
.
RunStoreThreadBuffer
(
p_wei_thread_buffer
,
p_wei_block_double
+
wei_block_space
);
__syncthreads
();
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
p_wei_block_double
+
wei_block_space
,
p_in_block_double
+
in_block_space
,
p_out_thread
);
}
else
// if has 1 iteration left
{
__syncthreads
();
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
}
}
// copy output: register to global memory
{
// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const
auto
c_thread_mtx_on_block
=
blockwise_gemm
.
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
const
index_t
k_thread_data_on_global
=
k_block_data_on_global
+
c_thread_mtx_on_block
.
row
;
const
index_t
b_thread_data_on_global
=
b_block_data_on_global
+
c_thread_mtx_on_block
.
col
;
// src descriptor
constexpr
auto
out_k0_k1_b0_b1_thread_desc
=
make_native_tensor_descriptor_packed
(
Sequence
<
GemmMRepeat
,
GemmMPerThreadSubC
,
GemmNRepeat
,
GemmNPerThreadSubC
>
{});
// dst descriptor
constexpr
index_t
K1
=
GemmMPerThreadSubC
*
GemmMLevel0Cluster
*
GemmMLevel1Cluster
;
constexpr
index_t
B1
=
GemmNPerThreadSubC
*
GemmNLevel0Cluster
*
GemmNLevel1Cluster
;
constexpr
index_t
K0
=
K
/
K1
;
constexpr
index_t
B0
=
B
/
B1
;
constexpr
auto
out_k_b_global_desc
=
transform_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
make_tuple
(
PassThrough
<
K
>
{},
Merge
<
Sequence
<
N
,
Ho
,
Wo
>>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
constexpr
auto
out_k0_k1_b0_b1_global_desc
=
transform_tensor_descriptor
(
out_k_b_global_desc
,
make_tuple
(
UnMerge
<
Sequence
<
K0
,
K1
>>
{},
UnMerge
<
Sequence
<
B0
,
B1
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
// output threadwise copy
ThreadwiseGenericTensorSliceCopy_v4r2
<
decltype
(
out_k0_k1_b0_b1_thread_desc
),
decltype
(
out_k0_k1_b0_b1_global_desc
),
decltype
(
out_k0_k1_b0_b1_thread_desc
.
GetLengths
()),
arithmetic_sequence_gen
<
0
,
4
,
1
>::
type
,
3
,
OutThreadCopyDataPerAccess_B
,
OutThreadCopyDataPerAccess_B
,
AddressSpace
::
vgpr
,
AddressSpace
::
global
>
({
0
,
0
,
0
,
0
},
{
k_thread_data_on_global
/
K1
,
k_thread_data_on_global
%
K1
,
b_thread_data_on_global
/
B1
,
b_thread_data_on_global
%
B1
})
.
Run
(
p_out_thread
,
p_out_global
);
}
}
};
}
// namespace ck
#endif
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp
View file @
9359ebfd
...
...
@@ -9,6 +9,12 @@
namespace
ck
{
// This threadwise copy allow vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimension of vector access can be different for src and dst.
// The dimension access order can be different for src and dst.
// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
// Will do valid mapping check on dst data: No write if dst data has a invalid mapping
template
<
index_t
BlockSize
,
typename
BlockSrcDesc
,
typename
BlockDstDesc
,
...
...
composable_kernel/include/tensor_operation/gridwise_gemm.hpp
View file @
9359ebfd
...
...
@@ -34,16 +34,19 @@ template <index_t GridSize,
typename
ABlockCopyThreadSliceLengths_K_M
,
typename
ABlockCopyThreadClusterLengths_K_M
,
typename
ABlockCopyThreadClusterArrangeOrder
,
typename
ABlockCopySrcAccessOrder
,
index_t
ABlockCopySrcVectorReadDim
,
index_t
ABlockCopySrcDataPerRead
,
index_t
ABlockCopyDstDataPerWrite_M
,
typename
BBlockCopyThreadSliceLengths_K_N
,
typename
BBlockCopyThreadClusterLengths_K_N
,
typename
BBlockCopyThreadClusterArrangeOrder
,
typename
BBlockCopySrcAccessOrder
,
index_t
BBlockCopySrcVectorReadDim
,
index_t
BBlockCopySrcDataPerRead
,
index_t
BBlockCopyDstDataPerWrite_N
,
index_t
CThreadCopyVectorReadWriteDim
,
typename
CThreadCopySrcDstAccessOrder
,
index_t
CThreadCopySrcDstVectorReadWriteDim
,
index_t
CThreadCopyDstDataPerWrite
>
struct
GridwiseGemmTransposedANormalBNormalC_v1
{
...
...
@@ -96,7 +99,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
ABlockCopyThreadSliceLengths_K_M
,
ABlockCopyThreadClusterLengths_K_M
,
ABlockCopyThreadClusterArrangeOrder
,
Sequence
<
0
,
1
>
,
ABlockCopySrcAccessOrder
,
Sequence
<
0
,
1
>
,
ABlockCopySrcVectorReadDim
,
1
,
...
...
@@ -122,7 +125,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
BBlockCopyThreadSliceLengths_K_N
,
BBlockCopyThreadClusterLengths_K_N
,
BBlockCopyThreadClusterArrangeOrder
,
Sequence
<
0
,
1
>
,
BBlockCopySrcAccessOrder
,
Sequence
<
0
,
1
>
,
BBlockCopySrcVectorReadDim
,
1
,
...
...
@@ -311,8 +314,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
ThreadwiseGenericTensorSliceCopy_v4r2
<
decltype
(
c_m0_m1_n0_n1_thread_desc
),
decltype
(
c_m0_m1_n0_n1_global_desc
),
decltype
(
c_m0_m1_n0_n1_thread_desc
.
GetLengths
()),
Sequence
<
0
,
1
,
2
,
3
>
,
CThreadCopyVectorReadWriteDim
,
CThreadCopySrcDstAccessOrder
,
CThreadCopy
SrcDst
VectorReadWriteDim
,
1
,
CThreadCopyDstDataPerWrite
,
AddressSpace
::
vgpr
,
...
...
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
View file @
9359ebfd
...
...
@@ -8,20 +8,17 @@
namespace
ck
{
// This version use multi-index transformation
// This threadwise copy allow vector access of src and dst.
// It allows the vector size to be different on src and dst.
// The dimensions of vector access should be the same on src and dst.
// The dimension access order should be the same on src and dst.
// It is designed for cases, where one of src and dst is register, and
// the other is device memory or LDS
// Will do valid mapping check on src data: Read 0 if src data has a invalid mapping
// Will do valid mapping check on dst data: No write if dst data has a invalid mapping
template
<
typename
SrcDesc
,
typename
DstDesc
,
typename
SliceLengths
,
typename
DimAccessOrder
,
index_t
VectorReadWriteDim
,
typename
SrcDst
DimAccessOrder
,
index_t
SrcDst
VectorReadWriteDim
,
index_t
SrcDataPerRead
,
index_t
DstDataPerWrite
,
AddressSpace
SrcAddressSpace
=
AddressSpace
::
generic
,
...
...
@@ -41,14 +38,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
{
static_assert
(
nDim
==
SrcDesc
::
GetNumOfDimension
()
&&
nDim
==
DstDesc
::
GetNumOfDimension
()
&&
nDim
==
SliceLengths
::
Size
()
&&
nDim
==
DimAccessOrder
::
Size
(),
nDim
==
SrcDst
DimAccessOrder
::
Size
(),
"wrong! # of dimensions not the same"
);
static_assert
(
is_valid_sequence_map
<
DimAccessOrder
>
{},
"wrong! map is not valid"
);
static_assert
(
is_valid_sequence_map
<
SrcDst
DimAccessOrder
>
{},
"wrong! map is not valid"
);
static_assert
(
SliceLengths
{}[
VectorReadWriteDim
]
%
math
::
lcm
(
SrcDataPerRead
,
DstDataPerWrite
)
==
0
,
"wrong! cannot evenly divide"
);
static_assert
(
SliceLengths
{}[
SrcDstVectorReadWriteDim
]
%
math
::
lcm
(
SrcDataPerRead
,
DstDataPerWrite
)
==
0
,
"wrong! cannot evenly divide"
);
// TODO:: sanity-check if vectorized memory read/write is allowed on src and dst
}
...
...
@@ -72,7 +70,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
template
<
typename
SrcData
,
typename
DstData
>
__device__
void
Run
(
const
SrcData
*
p_src
,
DstData
*
p_dst
)
const
{
constexpr
auto
vector_access_dim
=
Number
<
VectorReadWriteDim
>
{};
constexpr
auto
vector_access_dim
=
Number
<
SrcDst
VectorReadWriteDim
>
{};
constexpr
auto
src_data_per_access
=
Number
<
SrcDataPerRead
>
{};
constexpr
auto
dst_data_per_access
=
Number
<
DstDataPerWrite
>
{};
...
...
@@ -82,7 +80,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
constexpr
auto
long_vector_access_lengths
=
SliceLengths
::
Modify
(
vector_access_dim
,
SliceLengths
::
Get
(
vector_access_dim
)
/
long_vector_size
);
ford
<
decltype
(
long_vector_access_lengths
),
DimAccessOrder
>
{}([
&
](
ford
<
decltype
(
long_vector_access_lengths
),
SrcDst
DimAccessOrder
>
{}([
&
](
auto
long_vector_access_id
)
{
// data id w.r.t slicing-window
...
...
@@ -173,7 +171,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
__device__
void
Run_optimized_src_address_calculation
(
const
SrcData
*
p_src
,
DstData
*
p_dst
)
const
{
constexpr
auto
vector_access_dim
=
Number
<
VectorReadWriteDim
>
{};
constexpr
auto
vector_access_dim
=
Number
<
SrcDst
VectorReadWriteDim
>
{};
constexpr
auto
src_data_per_access
=
Number
<
SrcDataPerRead
>
{};
constexpr
auto
dst_data_per_access
=
Number
<
DstDataPerWrite
>
{};
...
...
@@ -187,10 +185,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
constexpr
auto
src_linear_dim_mask
=
SrcDesc
::
GetLinearDimensionMask
();
constexpr
auto
src_nonlinear_dim_mask
=
SrcDesc
::
GetNonLinearDimensionMask
();
static_assert
(
src_linear_dim_mask
.
At
(
VectorReadWriteDim
)
||
long_vector_size
==
SrcDataPerRead
,
"Warning! VectorReadWriteDim is not SrcDesc's linear dimension, performance "
"would drop"
);
static_assert
(
src_linear_dim_mask
.
At
(
SrcDstVectorReadWriteDim
)
||
long_vector_size
==
SrcDataPerRead
,
"Warning!
SrcDst
VectorReadWriteDim is not SrcDesc's linear dimension, performance "
"would drop"
);
// separate steps into linear and non-linear components, accoording to src tensor
constexpr
auto
linear_long_vector_access_lengths
=
...
...
@@ -230,13 +228,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
p_src_long_vector
[
i
]
=
0
;
}
// Loop over VectorReadWriteDim, and load data from src to the
// Loop over
SrcDst
VectorReadWriteDim, and load data from src to the
// long-vector buffer.
// If VectorReadWriteDim is src's linear dimension, then src's
// If
SrcDst
VectorReadWriteDim is src's linear dimension, then src's
// offset-diff due to this looping is known at compile-time. If
// VectorReadWriteDim is src's nonlinear dimension, then src's
//
SrcDst
VectorReadWriteDim is src's nonlinear dimension, then src's
// offset-diff due to this looping is only known at run-time. For best
// performance, VectorReadWriteDim, should be src's linear dimension
// performance,
SrcDst
VectorReadWriteDim, should be src's linear dimension
for
(
index_t
i
=
0
;
i
<
long_vector_size
/
src_data_per_access
;
++
i
)
{
auto
scalar_id
=
make_zero_array
<
index_t
,
nDim
>
();
...
...
@@ -321,7 +319,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
__device__
void
Run_optimized_dst_address_calculation
(
const
SrcData
*
p_src
,
DstData
*
p_dst
)
const
{
constexpr
auto
vector_access_dim
=
Number
<
VectorReadWriteDim
>
{};
constexpr
auto
vector_access_dim
=
Number
<
SrcDst
VectorReadWriteDim
>
{};
constexpr
auto
src_data_per_access
=
Number
<
SrcDataPerRead
>
{};
constexpr
auto
dst_data_per_access
=
Number
<
DstDataPerWrite
>
{};
...
...
@@ -335,10 +333,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
constexpr
auto
dst_linear_dim_mask
=
DstDesc
::
GetLinearDimensionMask
();
constexpr
auto
dst_nonlinear_dim_mask
=
DstDesc
::
GetNonLinearDimensionMask
();
static_assert
(
dst_linear_dim_mask
.
At
(
VectorReadWriteDim
)
||
long_vector_size
==
DstDataPerWrite
,
"Warning! VectorReadWriteDim is not DstDesc's linear dimension, performance "
"would drop"
);
static_assert
(
dst_linear_dim_mask
.
At
(
SrcDstVectorReadWriteDim
)
||
long_vector_size
==
DstDataPerWrite
,
"Warning!
SrcDst
VectorReadWriteDim is not DstDesc's linear dimension, performance "
"would drop"
);
// separate steps into linear and non-linear components, accoording to dst tensor
constexpr
auto
linear_long_vector_access_lengths
=
...
...
@@ -378,13 +376,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
p_src_long_vector
[
i
]
=
0
;
}
// Loop over VectorReadWriteDim, and load data from src to the
// Loop over
SrcDst
VectorReadWriteDim, and load data from src to the
// long-vector buffer.
// If VectorReadWriteDim is dst's linear dimension, then dst's
// If
SrcDst
VectorReadWriteDim is dst's linear dimension, then dst's
// offset-diff due to this looping is known at compile-time. If
// VectorReadWriteDim is dst's nonlinear dimension, then dst's
//
SrcDst
VectorReadWriteDim is dst's nonlinear dimension, then dst's
// offset-diff due to this looping is only known at run-time. For best
// performance, VectorReadWriteDim, should be dst's linear dimension
// performance,
SrcDst
VectorReadWriteDim, should be dst's linear dimension
for
(
index_t
i
=
0
;
i
<
long_vector_size
/
src_data_per_access
;
++
i
)
{
auto
scalar_id
=
make_zero_array
<
index_t
,
nDim
>
();
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp
View file @
9359ebfd
...
...
@@ -83,13 +83,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#endif
// TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for
// simplicity
constexpr
index_t
hcf_stride_dilation_h
=
math
::
hcf
(
ConvStrideH
,
ConvDilationH
);
constexpr
index_t
hcf_stride_dilation_w
=
math
::
hcf
(
ConvStrideW
,
ConvDilationW
);
constexpr
index_t
Ytilda
=
ConvStrideH
/
hcf_stride_dilation_h
;
// may be wrong
constexpr
index_t
Xtilda
=
ConvStrideW
/
hcf_stride_dilation_w
;
// may be wrong
constexpr
index_t
Ytilda
=
ConvStrideH
/
hcf_stride_dilation_h
;
constexpr
index_t
Xtilda
=
ConvStrideW
/
hcf_stride_dilation_w
;
constexpr
index_t
Ydot
=
math
::
integer_divide_ceil
(
Y
,
Ytilda
);
constexpr
index_t
Xdot
=
math
::
integer_divide_ceil
(
X
,
Xtilda
);
...
...
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
9359ebfd
...
...
@@ -3,7 +3,7 @@
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_convolution_kernel_wrapper.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
_lds_double_buffer
.hpp"
#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
template
<
class
T
,
class
InDesc
,
...
...
@@ -32,9 +32,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_nchw_desc
=
InDesc
{};
constexpr
auto
wei_kcyx_desc
=
WeiDesc
{};
constexpr
auto
out_nkhw_desc
=
OutDesc
{};
constexpr
auto
in_nchw_desc
=
make_native_tensor_descriptor
(
InDesc
::
GetLengths
(),
InDesc
::
GetStrides
());
constexpr
auto
wei_kcyx_desc
=
make_native_tensor_descriptor
(
WeiDesc
::
GetLengths
(),
WeiDesc
::
GetStrides
());
constexpr
auto
out_nkhw_desc
=
make_native_tensor_descriptor
(
OutDesc
::
GetLengths
(),
OutDesc
::
GetStrides
());
constexpr
index_t
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
out_nkhw_desc
.
GetLength
(
I1
);
...
...
@@ -51,198 +54,173 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if 1
// BlockSize = 256,
E
PerBlock = 8
// BlockSize = 256,
GemmK
PerBlock = 8
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopySubLengths_E_B
=
Sequence
<
4
,
1
>
;
using
InBlockCopyClusterLengths_E_B
=
Sequence
<
2
,
128
>
;
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
constexpr
index_t
InBlockCopyDataPerAccess_B
=
1
;
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
4
,
1
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
2
,
128
>
;
using
WeiBlockCopyThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
OutThreadCopyDataPerAccess_B
=
1
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
1
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmK
=
4
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
4
,
1
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
1
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif 0
// BlockSize = 256,
E
PerBlock = 8
// BlockSize = 256,
GemmK
PerBlock = 8
// 1x1 filter, 8x8 image
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopySubLengths_E_B
=
Sequence
<
1
,
4
>
;
using
InBlockCopyClusterLengths_E_B
=
Sequence
<
8
,
32
>
;
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
constexpr
index_t
InBlockCopyDataPerAccess_B
=
4
;
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
4
,
1
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
2
,
128
>
;
using
WeiBlockCopyThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
OutThreadCopyDataPerAccess_B
=
4
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
1
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmK
=
4
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
1
,
4
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
4
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
4
;
#elif 0
// BlockSize = 256,
E
PerBlock = 16
// BlockSize = 256,
GemmK
PerBlock = 16
// 1x1 filter, 8x8 image
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
EPerBlock
=
16
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopySubLengths_E_B
=
Sequence
<
2
,
4
>
;
using
InBlockCopyClusterLengths_E_B
=
Sequence
<
8
,
32
>
;
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
constexpr
index_t
InBlockCopyDataPerAccess_B
=
4
;
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
4
,
2
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
4
,
64
>
;
using
WeiBlockCopyThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
OutThreadCopyDataPerAccess_B
=
4
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
16
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
2
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
4
,
64
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmK
=
4
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
4
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
8
,
32
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
4
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
4
;
#elif 1
// 1x1 filter, 14x14 image
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
128
;
constexpr
index_t
EPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
InBlockCopySubLengths_E_B
=
Sequence
<
2
,
2
>
;
using
InBlockCopyClusterLengths_E_B
=
Sequence
<
4
,
64
>
;
using
InBlockCopyThreadClusterArrangeOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
using
InBlockCopySrcAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
using
InBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, B]
constexpr
index_t
InBlockCopyDataPerAccess_B
=
2
;
using
WeiBlockCopySubLengths_E_K
=
Sequence
<
4
,
1
>
;
using
WeiBlockCopyClusterLengths_E_K
=
Sequence
<
2
,
128
>
;
using
WeiBlockCopyThreadClusterArrangeOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopySrcAccessOrder
=
Sequence
<
1
,
0
>
;
// [K, E]
using
WeiBlockCopyDstAccessOrder
=
Sequence
<
0
,
1
>
;
// [E, K]
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
OutThreadCopyDataPerAccess_B
=
2
;
constexpr
index_t
GemmNPerBlock
=
128
;
constexpr
index_t
GemmMPerBlock
=
128
;
constexpr
index_t
GemmKPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
ThreadGemmDataPerReadM
=
4
;
constexpr
index_t
ThreadGemmDataPerReadN
=
4
;
using
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
=
Sequence
<
4
,
1
>
;
using
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
=
Sequence
<
2
,
128
>
;
constexpr
index_t
GemmABlockCopySrcDataPerRead_GemmK
=
4
;
constexpr
index_t
GemmABlockCopyDstDataPerWrite_GemmM
=
1
;
using
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
=
Sequence
<
2
,
2
>
;
using
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
=
Sequence
<
4
,
64
>
;
constexpr
index_t
GemmBBlockCopySrcDataPerRead_GemmN
=
2
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
2
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
2
;
#endif
constexpr
index_t
B
=
N
*
Ho
*
Wo
;
constexpr
index_t
GridSize
=
((
B
+
B
PerBlock
-
1
)
/
B
PerBlock
)
*
((
K
+
K
PerBlock
-
1
)
/
K
PerBlock
);
((
B
+
GemmN
PerBlock
-
1
)
/
GemmN
PerBlock
)
*
((
K
+
GemmM
PerBlock
-
1
)
/
GemmM
PerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
constexpr
auto
gridwise_conv
=
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
<
GridSize
,
BlockSize
,
T
,
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
out_nkhw_desc
),
ConvStrides
,
ConvDilations
,
LeftPads
,
RightPads
,
BPerBlock
,
KPerBlock
,
EPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmDataPerReadA
,
GemmDataPerReadB
,
InBlockCopySubLengths_E_B
,
InBlockCopyClusterLengths_E_B
,
InBlockCopyThreadClusterArrangeOrder
,
InBlockCopySrcAccessOrder
,
InBlockCopyDstAccessOrder
,
InBlockCopyDataPerAccess_B
,
WeiBlockCopySubLengths_E_K
,
WeiBlockCopyClusterLengths_E_K
,
WeiBlockCopyThreadClusterArrangeOrder
,
WeiBlockCopySrcAccessOrder
,
WeiBlockCopyDstAccessOrder
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopyDstDataPerWrite_K
,
OutThreadCopyDataPerAccess_B
>
{};
constexpr
auto
gridwise_conv
=
GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
<
GridSize
,
BlockSize
,
T
,
T
,
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
out_nkhw_desc
),
ConvStrides
,
ConvDilations
,
LeftPads
,
RightPads
,
GemmMPerBlock
,
GemmNPerBlock
,
GemmKPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
ThreadGemmDataPerReadM
,
ThreadGemmDataPerReadN
,
GemmABlockCopyThreadSliceLengths_GemmK_GemmM
,
GemmABlockCopyThreadClusterLengths_GemmK_GemmM
,
GemmABlockCopySrcDataPerRead_GemmK
,
GemmABlockCopyDstDataPerWrite_GemmM
,
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN
,
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN
,
GemmBBlockCopySrcDataPerRead_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
...
...
driver/src/conv_bwd_data_driver.cpp
View file @
9359ebfd
...
...
@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
{
using
namespace
ck
;
#if
0
#if
1
constexpr
index_t
N
=
8
;
constexpr
index_t
C
=
128
;
constexpr
index_t
HI
=
16
;
...
...
driver/src/conv_driver.cpp
View file @
9359ebfd
...
...
@@ -43,7 +43,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif
0
#elif
1
// 3x3, 34x34
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
256
;
...
...
@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
1
#elif
0
// 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
// cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
constexpr
index_t
N
=
128
;
...
...
@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif
0
#elif
1
// 1x7 filter, 0x3 pad, 17x17 input
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
...
...
@@ -403,7 +403,7 @@ int main(int argc, char* argv[])
ConvStrides
{},
ConvDilations
{},
nrepeat
);
#elif
1
#elif
0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment