gaoqiong / composable_kernel · Commits

Commit 9359ebfd, authored Dec 04, 2019 by Chao Liu

    updated fwd v4r4 to use gridwise gemm

Parent: 19c3c9a8

Showing 11 changed files with 374 additions and 626 deletions (+374 −626).
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp  (+3 −0)
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp  (+3 −0)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp  (+167 −0)
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp  (+0 −408)
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp  (+6 −0)
composable_kernel/include/tensor_operation/gridwise_gemm.hpp  (+8 −5)
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp  (+28 −30)
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp  (+2 −4)
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp  (+152 −174)
driver/src/conv_bwd_data_driver.cpp  (+1 −1)
driver/src/conv_driver.cpp  (+4 −4)
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp

```diff
@@ -144,15 +144,18 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw
             GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
             GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
             Sequence<0, 1>,
+            Sequence<0, 1>,
             1,
             GemmABlockCopySrcDataPerRead_GemmN,
             GemmABlockCopyDstDataPerWrite_GemmN,
             GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
             GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
             Sequence<0, 1>,
+            Sequence<0, 1>,
             1,
             GemmBBlockCopySrcDataPerRead_GemmN,
             GemmBBlockCopyDstDataPerWrite_GemmN,
+            Sequence<0, 1, 2, 3>,
             3,
             GemmCThreadCopyDstDataPerWrite_GemmN1>{};
```
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp

```diff
@@ -198,15 +198,18 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
             GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
             GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
             Sequence<0, 1>,
+            Sequence<0, 1>,
             1,
             GemmABlockCopySrcDataPerRead_GemmM,
             GemmABlockCopyDstDataPerWrite_GemmM,
             GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
             GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
             Sequence<0, 1>,
+            Sequence<0, 1>,
             1,
             GemmBBlockCopySrcDataPerRead_GemmN,
             GemmBBlockCopyDstDataPerWrite_GemmN,
+            Sequence<0, 1, 2, 3>,
             3,
             GemmCThreadCopyDstDataPerWrite_GemmN1>{};
```
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp  (new file, mode 100644)

```cpp
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm.hpp"

namespace ck {

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccFloat,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t GemmMPerBlock,
          index_t GemmNPerBlock,
          index_t GemmKPerBlock,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmThreadGemmDataPerReadM,
          index_t GemmThreadGemmDataPerReadN,
          typename GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
          typename GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
          index_t GemmABlockCopySrcDataPerRead_GemmK,
          index_t GemmABlockCopyDstDataPerWrite_GemmM,
          typename GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
          typename GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
          index_t GemmBBlockCopySrcDataPerRead_GemmN,
          index_t GemmBBlockCopyDstDataPerWrite_GemmN,
          index_t GemmCThreadCopyDstDataPerWrite_GemmN1>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        // sanity-check for vectorized memory load
        static_assert(
            (Wo == 1 || (ConvStrideW == 1 || GemmBBlockCopySrcDataPerRead_GemmN == 1)) &&
                (X == 1 || ConvDilationW % GemmBBlockCopySrcDataPerRead_GemmN == 0) &&
                InLeftPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0 &&
                InRightPads{}[1] % GemmBBlockCopySrcDataPerRead_GemmN == 0,
            "wrong! alignment requirement for vectorized global load of input tensor will "
            "be violated");

        // weight tensor
        constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});

        // input tensor
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Pad<Sequence<Hi, Wi>, InLeftPads, InRightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // output tensor
        constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
            out_n_k_ho_wo_global_desc,
            make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        // GEMM
        constexpr auto gridwise_gemm =
            GridwiseGemmTransposedANormalBNormalC_v1<GridSize,
                                                     BlockSize,
                                                     Float,
                                                     AccFloat,
                                                     decltype(wei_e_k_global_desc),
                                                     decltype(in_e_b_global_desc),
                                                     decltype(out_k_b_global_desc),
                                                     InMemoryDataOperation::none,
                                                     GemmMPerBlock,
                                                     GemmNPerBlock,
                                                     GemmKPerBlock,
                                                     GemmMPerThreadSubC,
                                                     GemmNPerThreadSubC,
                                                     GemmMLevel0Cluster,
                                                     GemmNLevel0Cluster,
                                                     GemmMLevel1Cluster,
                                                     GemmNLevel1Cluster,
                                                     GemmKPerThreadLoop,
                                                     GemmThreadGemmDataPerReadM,
                                                     GemmThreadGemmDataPerReadN,
                                                     GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
                                                     GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
                                                     Sequence<1, 0>,
                                                     Sequence<1, 0>,
                                                     0,
                                                     GemmABlockCopySrcDataPerRead_GemmK,
                                                     GemmABlockCopyDstDataPerWrite_GemmM,
                                                     GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
                                                     GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
                                                     Sequence<0, 1>,
                                                     Sequence<0, 1>,
                                                     1,
                                                     GemmBBlockCopySrcDataPerRead_GemmN,
                                                     GemmBBlockCopyDstDataPerWrite_GemmN,
                                                     Sequence<0, 1, 2, 3>,
                                                     3,
                                                     GemmCThreadCopyDstDataPerWrite_GemmN1>{};

        gridwise_gemm.Run(p_wei_global, p_in_global, p_out_global);
    }
};

} // namespace ck
#endif
```
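Editorial note: the header above reduces forward convolution to one gridwise GEMM through the descriptor chain Pad → Embed → Merge, with GemmM = K, GemmN = N·Ho·Wo, GemmK = C·Y·X. As a plain-C++ cross-check of that index algebra, here is a naive host-side reference under assumed NCHW/KCYX/NKHW layouts; the function and its bounds check are illustrative and not part of this commit. The `continue` on out-of-range coordinates mirrors the "read 0 on invalid mapping" behavior the padded descriptor gives the kernel.

```cpp
#include <vector>

// Naive reference of the implicit-GEMM view used by the v4r4 kernel:
//   C[GemmM][GemmN] with GemmM = K, GemmN = N*Ho*Wo, reduced over GemmK = C*Y*X.
void naive_implicit_gemm_fwd(const std::vector<float>& wei, // [K, C, Y, X]
                             const std::vector<float>& in,  // [N, C, Hi, Wi]
                             std::vector<float>& out,       // [N, K, Ho, Wo]
                             int N, int C, int Hi, int Wi, int K, int Y, int X,
                             int Ho, int Wo, int stride_h, int stride_w,
                             int dilation_h, int dilation_w, int pad_h, int pad_w)
{
    const int GemmM = K;           // rows of the implicit C matrix
    const int GemmN = N * Ho * Wo; // columns
    const int GemmK = C * Y * X;   // reduction length

    for(int gm = 0; gm < GemmM; ++gm)
        for(int gn = 0; gn < GemmN; ++gn)
        {
            // decompose the GemmN index into (n, ho, wo), as Merge<N, Ho, Wo> does
            const int n = gn / (Ho * Wo), ho = (gn / Wo) % Ho, wo = gn % Wo;
            float acc = 0;
            for(int gk = 0; gk < GemmK; ++gk)
            {
                // decompose the GemmK index into (c, y, x), as Merge<C, Y, X> does
                const int c = gk / (Y * X), y = (gk / X) % Y, x = gk % X;
                // Embed<...> on the padded input: hi = y*dilation + ho*stride - left pad
                const int hi = y * dilation_h + ho * stride_h - pad_h;
                const int wi = x * dilation_w + wo * stride_w - pad_w;
                if(hi < 0 || hi >= Hi || wi < 0 || wi >= Wi)
                    continue; // padded region reads as zero
                acc += wei[((gm * C + c) * Y + y) * X + x] *
                       in[((n * C + c) * Hi + hi) * Wi + wi];
            }
            out[((n * K + gm) * Ho + ho) * Wo + wo] = acc;
        }
}
```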
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp  (deleted, mode 100644 → 0)

```cpp
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

// B = merge(N, Ho, Wo)
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t EPerBlock,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          typename InBlockCopySubLengths_E_B,
          typename InBlockCopyClusterLengths_E_B,
          typename InBlockCopyThreadClusterArrangeOrder,
          typename InBlockCopySrcAccessOrder,
          typename InBlockCopyDstAccessOrder,
          index_t InBlockCopyDataPerAccess_B,
          typename WeiBlockCopySubLengths_E_K,
          typename WeiBlockCopyClusterLengths_E_K,
          typename WeiBlockCopyThreadClusterArrangeOrder,
          typename WeiBlockCopySrcAccessOrder,
          typename WeiBlockCopyDstAccessOrder,
          index_t WeiBlockCopySrcDataPerRead_E,
          index_t WeiBlockCopyDstDataPerWrite_K,
          index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_hi_wi_global_desc =
            make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
        constexpr auto wei_k_c_y_x_global_desc =
            make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
        constexpr auto out_n_k_ho_wo_global_desc =
            make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLength(I1);
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        constexpr index_t E = C * Y * X;
        constexpr index_t B = N * Ho * Wo;

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InBlockCopyDataPerAccess_B == 1)) &&
                          (X == 1 || ConvDilationW % InBlockCopyDataPerAccess_B == 0),
                      "wrong! alignment requirement for vectorized global load of input tensor "
                      "will be violated");

        // divide block work by [K, B]
        static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t KBlockWork = K / KPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
        const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;

        // input tensor
        //   global mem
        constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(
                PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

        constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(PassThrough<N>{},
                       PassThrough<C>{},
                       Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                       Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        //   LDS mem, be careful of LDS alignment
        constexpr auto in_e_b_block_desc =
            make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});

        // input blockwise copy
        auto blockwise_in_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(in_e_b_global_desc),
                                               decltype(in_e_b_block_desc),
                                               decltype(in_e_b_block_desc.GetLengths()),
                                               InBlockCopySubLengths_E_B,
                                               InBlockCopyClusterLengths_E_B,
                                               InBlockCopyThreadClusterArrangeOrder,
                                               InBlockCopySrcAccessOrder,
                                               InBlockCopyDstAccessOrder,
                                               1,
                                               1,
                                               InBlockCopyDataPerAccess_B,
                                               InBlockCopyDataPerAccess_B,
                                               AddressSpace::global,
                                               AddressSpace::vgpr,
                                               AddressSpace::lds,
                                               InMemoryDataOperation::none>(
                {0, b_block_data_on_global}, {0, 0});

        // weight tensor
        //   global mem
        constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});

        //   LDS, be careful of LDS alignment
        constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<EPerBlock, KPerBlock>{},
            Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with multiple alignment
        // requirements
        static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
                      "GemmDataPerReadA alignment requirement is not satisfied");

        // weight blockwise copy
        auto blockwise_wei_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(wei_e_k_global_desc),
                                               decltype(wei_e_k_block_desc),
                                               decltype(wei_e_k_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_E_K,
                                               WeiBlockCopyClusterLengths_E_K,
                                               WeiBlockCopyThreadClusterArrangeOrder,
                                               WeiBlockCopySrcAccessOrder,
                                               WeiBlockCopyDstAccessOrder,
                                               0,
                                               1,
                                               WeiBlockCopySrcDataPerRead_E,
                                               WeiBlockCopyDstDataPerWrite_K,
                                               AddressSpace::global,
                                               AddressSpace::vgpr,
                                               AddressSpace::lds,
                                               InMemoryDataOperation::none>(
                {0, k_block_data_on_global}, {0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[EPerBlock, KPerBlock] is in LDS
        //     b_mtx[EPerBlock, BPerBlock] is in LDS
        //     c_mtx[KPerBlock, BPerBlock] is distributed among threads, and saved in
        //     register
        constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
        constexpr auto b_e_b_block_mtx_desc = make_ConstantMatrixDescriptor(in_e_b_block_desc);

        // sanity check
        static_assert(
            KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
                BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            KPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);
        constexpr index_t GemmNRepeat =
            BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO: more elegant way of defining c_thread_mtx
        constexpr auto c_k0k1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});

        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
            BlockSize,
            decltype(a_e_k_block_mtx_desc),
            decltype(b_e_b_block_mtx_desc),
            decltype(c_k0k1_b0b1_thread_mtx_desc),
            GemmMPerThreadSubC,
            GemmNPerThreadSubC,
            GemmMLevel0Cluster,
            GemmNLevel0Cluster,
            GemmMLevel1Cluster,
            GemmNLevel1Cluster,
            GemmKPerThreadLoop,
            GemmDataPerReadA,
            GemmDataPerReadB>{};

        // LDS allocation for input and weight: be careful of alignment
        constexpr index_t max_align = math::lcm(InBlockCopyDataPerAccess_B,
                                                WeiBlockCopyDstDataPerWrite_K,
                                                GemmDataPerReadA,
                                                GemmDataPerReadB);

        constexpr index_t in_block_space =
            math::integer_least_multiple(in_e_b_block_desc.GetElementSpace(), max_align);
        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align);

        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for output
        Float p_out_thread[c_k0k1_b0b1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise output
        threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);

        // LDS double buffer: preload data into LDS
        {
            blockwise_in_copy.Run(p_in_global, p_in_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        for(index_t e_block_data_begin = 0; e_block_data_begin + 2 * EPerBlock < E;
            e_block_data_begin += 2 * EPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_in_block_now =
                    even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_in_block_next =
                    even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                // LDS double buffer: store next data to LDS
                blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer, p_in_block_next);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iterations left
            {
                Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                blockwise_in_copy.RunLoadThreadBuffer(p_in_global, p_in_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

                // LDS double buffer: store last data to LDS
                blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
                                                       p_in_block_double + in_block_space);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                        p_wei_block_double + wei_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                                   p_in_block_double + in_block_space,
                                   p_out_thread);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
            }
        }

        // copy output: register to global memory
        {
            // calculate origin of thread output tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t k_thread_data_on_global =
                k_block_data_on_global + c_thread_mtx_on_block.row;
            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col;

            // src descriptor
            constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});

            // dst descriptor
            constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;

            constexpr index_t K0 = K / K1;
            constexpr index_t B0 = B / B1;

            constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
                out_n_k_ho_wo_global_desc,
                make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
                make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
                out_k_b_global_desc,
                make_tuple(UnMerge<Sequence<K0, K1>>{}, UnMerge<Sequence<B0, B1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

            // output threadwise copy
            ThreadwiseGenericTensorSliceCopy_v4r2<
                decltype(out_k0_k1_b0_b1_thread_desc),
                decltype(out_k0_k1_b0_b1_global_desc),
                decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
                arithmetic_sequence_gen<0, 4, 1>::type,
                3,
                OutThreadCopyDataPerAccess_B,
                OutThreadCopyDataPerAccess_B,
                AddressSpace::vgpr,
                AddressSpace::global>({0, 0, 0, 0},
                                      {k_thread_data_on_global / K1,
                                       k_thread_data_on_global % K1,
                                       b_thread_data_on_global / B1,
                                       b_thread_data_on_global % B1})
                .Run(p_out_thread, p_out_global);
        }
    }
};

} // namespace ck
#endif
```
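Editorial note: the deleted kernel's distinguishing feature was the two-phase LDS schedule: while the blockwise GEMM consumes buffer `now`, the next E-tile is loaded into registers and then committed to buffer `next`. Below is a self-contained toy HIP kernel sketching just that ping-pong schedule; the GEMM body is reduced to a per-thread dot product and every name is illustrative, so only the synchronization pattern corresponds to the code above.

```cpp
#include <hip/hip_runtime.h>

#define TILE_K 8
#define TILE_MN 64 // launch with blockDim.x == TILE_MN

// Toy double-buffer schedule: prefetch of tile t+1 into buffer `next` overlaps
// with compute on tile t in buffer `now`; one __syncthreads per tile.
__global__ void toy_double_buffer(const float* a, // [num_k_tiles*TILE_K, TILE_MN]
                                  const float* b, // [num_k_tiles*TILE_K, TILE_MN]
                                  float* c,       // [TILE_MN]
                                  int num_k_tiles)
{
    __shared__ float lds_a[2][TILE_K][TILE_MN];
    __shared__ float lds_b[2][TILE_K][TILE_MN];

    const int tid = threadIdx.x; // each thread owns one column of the tile
    float acc = 0;

    // preload tile 0 into buffer 0
    for(int k = 0; k < TILE_K; ++k)
    {
        lds_a[0][k][tid] = a[k * TILE_MN + tid];
        lds_b[0][k][tid] = b[k * TILE_MN + tid];
    }

    for(int t = 0; t < num_k_tiles; ++t)
    {
        const int now  = t % 2;
        const int next = 1 - now;
        __syncthreads(); // make the previously prefetched tile visible

        // prefetch the next tile while computing on the current one
        if(t + 1 < num_k_tiles)
            for(int k = 0; k < TILE_K; ++k)
            {
                lds_a[next][k][tid] = a[((t + 1) * TILE_K + k) * TILE_MN + tid];
                lds_b[next][k][tid] = b[((t + 1) * TILE_K + k) * TILE_MN + tid];
            }

        // stand-in for the blockwise GEMM on buffer `now`
        for(int k = 0; k < TILE_K; ++k)
            acc += lds_a[now][k][tid] * lds_b[now][k][tid];
    }

    c[tid] = acc;
}
```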
composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp

```diff
@@ -9,6 +9,12 @@
 namespace ck {

+// This threadwise copy allows vector access of src and dst.
+// It allows the vector size to be different on src and dst.
+// The dimension of vector access can be different for src and dst.
+// The dimension access order can be different for src and dst.
+// Will do valid mapping check on src data: read 0 if src data has an invalid mapping.
+// Will do valid mapping check on dst data: no write if dst data has an invalid mapping.
 template <index_t BlockSize,
           typename BlockSrcDesc,
           typename BlockDstDesc,
```
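Editorial note: the "valid mapping" rules documented in the new comment block (read 0 for an out-of-range source, skip the write for an out-of-range destination) are what make padded descriptors safe to copy through. A scalar sketch of that contract, with `src_valid`/`dst_valid` standing in for the descriptor's mapping check (names are illustrative):

```cpp
#include <cstddef>

// Scalar sketch of the copy contract described above: an invalid src mapping
// yields 0, an invalid dst mapping suppresses the write.
void copy_with_valid_mapping_check(const float* src, std::size_t src_len, long src_off,
                                   float* dst, std::size_t dst_len, long dst_off,
                                   std::size_t n)
{
    for(std::size_t i = 0; i < n; ++i)
    {
        const long si = src_off + static_cast<long>(i);
        const long di = dst_off + static_cast<long>(i);
        const bool src_valid = si >= 0 && si < static_cast<long>(src_len);
        const bool dst_valid = di >= 0 && di < static_cast<long>(dst_len);

        const float v = src_valid ? src[si] : 0.0f; // read 0 on invalid src mapping
        if(dst_valid)                               // no write on invalid dst mapping
            dst[di] = v;
    }
}
```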
composable_kernel/include/tensor_operation/gridwise_gemm.hpp

```diff
@@ -34,16 +34,19 @@ template <index_t GridSize,
           typename ABlockCopyThreadSliceLengths_K_M,
           typename ABlockCopyThreadClusterLengths_K_M,
           typename ABlockCopyThreadClusterArrangeOrder,
+          typename ABlockCopySrcAccessOrder,
           index_t ABlockCopySrcVectorReadDim,
           index_t ABlockCopySrcDataPerRead,
           index_t ABlockCopyDstDataPerWrite_M,
           typename BBlockCopyThreadSliceLengths_K_N,
           typename BBlockCopyThreadClusterLengths_K_N,
           typename BBlockCopyThreadClusterArrangeOrder,
+          typename BBlockCopySrcAccessOrder,
           index_t BBlockCopySrcVectorReadDim,
           index_t BBlockCopySrcDataPerRead,
           index_t BBlockCopyDstDataPerWrite_N,
-          index_t CThreadCopyVectorReadWriteDim,
+          typename CThreadCopySrcDstAccessOrder,
+          index_t CThreadCopySrcDstVectorReadWriteDim,
           index_t CThreadCopyDstDataPerWrite>
 struct GridwiseGemmTransposedANormalBNormalC_v1
 {

@@ -96,7 +99,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
             ABlockCopyThreadSliceLengths_K_M,
             ABlockCopyThreadClusterLengths_K_M,
             ABlockCopyThreadClusterArrangeOrder,
-            Sequence<0, 1>,
+            ABlockCopySrcAccessOrder,
             Sequence<0, 1>,
             ABlockCopySrcVectorReadDim,
             1,

@@ -122,7 +125,7 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
             BBlockCopyThreadSliceLengths_K_N,
             BBlockCopyThreadClusterLengths_K_N,
             BBlockCopyThreadClusterArrangeOrder,
-            Sequence<0, 1>,
+            BBlockCopySrcAccessOrder,
             Sequence<0, 1>,
             BBlockCopySrcVectorReadDim,
             1,

@@ -311,8 +314,8 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
         ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_m0_m1_n0_n1_thread_desc),
                                               decltype(c_m0_m1_n0_n1_global_desc),
                                               decltype(c_m0_m1_n0_n1_thread_desc.GetLengths()),
-                                              Sequence<0, 1, 2, 3>,
-                                              CThreadCopyVectorReadWriteDim,
+                                              CThreadCopySrcDstAccessOrder,
+                                              CThreadCopySrcDstVectorReadWriteDim,
                                               1,
                                               CThreadCopyDstDataPerWrite,
                                               AddressSpace::vgpr,
```
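Editorial note: the net effect of this change is that the A/B blockwise copies' source access order and the C threadwise copy's access order and vector dimension become caller-supplied instead of hardcoded. A hypothetical instantiation sketch, mirroring the call the new v4r4 kernel above makes (it assumes the composable_kernel headers on the include path, and `AGlobalDesc`/`BGlobalDesc`/`CGlobalDesc` are descriptor types built elsewhere); the three arguments introduced by this commit are marked NEW:

```cpp
using Gemm = ck::GridwiseGemmTransposedANormalBNormalC_v1<
    /*GridSize*/ 16, /*BlockSize*/ 256, float, float,
    AGlobalDesc, BGlobalDesc, CGlobalDesc,
    ck::InMemoryDataOperation::none,
    128, 128, 8,                              // GemmM/N/KPerBlock
    4, 4, 4, 4, 4, 4, 1,                      // per-thread / cluster decomposition
    4, 4,                                     // ThreadGemmDataPerReadM/N
    ck::Sequence<4, 1>, ck::Sequence<2, 128>, // A slice / cluster lengths
    ck::Sequence<1, 0>,                       // A thread-cluster arrange order
    ck::Sequence<1, 0>,                       // NEW: ABlockCopySrcAccessOrder
    0, 4, 1,                                  // A src vector dim, src read, dst write
    ck::Sequence<4, 1>, ck::Sequence<2, 128>, // B slice / cluster lengths
    ck::Sequence<0, 1>,                       // B thread-cluster arrange order
    ck::Sequence<0, 1>,                       // NEW: BBlockCopySrcAccessOrder
    1, 1, 1,                                  // B src vector dim, src read, dst write
    ck::Sequence<0, 1, 2, 3>,                 // NEW: CThreadCopySrcDstAccessOrder
    3,                                        // CThreadCopySrcDstVectorReadWriteDim
    1>;                                       // CThreadCopyDstDataPerWrite
```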
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp

```diff
@@ -8,20 +8,17 @@
 namespace ck {

-// This version uses multi-index transformation
 // This threadwise copy allows vector access of src and dst.
 // It allows the vector size to be different on src and dst.
 // The dimensions of vector access should be the same on src and dst.
 // The dimension access order should be the same on src and dst.
-// It is designed for cases where one of src and dst is register, and
-// the other is device memory or LDS
 // Will do valid mapping check on src data: read 0 if src data has an invalid mapping.
 // Will do valid mapping check on dst data: no write if dst data has an invalid mapping.
 template <typename SrcDesc,
           typename DstDesc,
           typename SliceLengths,
-          typename DimAccessOrder,
-          index_t VectorReadWriteDim,
+          typename SrcDstDimAccessOrder,
+          index_t SrcDstVectorReadWriteDim,
           index_t SrcDataPerRead,
           index_t DstDataPerWrite,
           AddressSpace SrcAddressSpace = AddressSpace::generic,

@@ -41,14 +38,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
     {
         static_assert(nDim == SrcDesc::GetNumOfDimension() &&
                           nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::Size() &&
-                          nDim == DimAccessOrder::Size(),
+                          nDim == SrcDstDimAccessOrder::Size(),
                       "wrong! # of dimensions not the same");

-        static_assert(is_valid_sequence_map<DimAccessOrder>{}, "wrong! map is not valid");
+        static_assert(is_valid_sequence_map<SrcDstDimAccessOrder>{}, "wrong! map is not valid");

-        static_assert(SliceLengths{}[VectorReadWriteDim] %
-                          math::lcm(SrcDataPerRead, DstDataPerWrite) == 0,
-                      "wrong! cannot evenly divide");
+        static_assert(
+            SliceLengths{}[SrcDstVectorReadWriteDim] % math::lcm(SrcDataPerRead, DstDataPerWrite) ==
+                0,
+            "wrong! cannot evenly divide");

         // TODO:: sanity-check if vectorized memory read/write is allowed on src and dst
     }

@@ -72,7 +70,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
     template <typename SrcData, typename DstData>
     __device__ void Run(const SrcData* p_src, DstData* p_dst) const
     {
-        constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
+        constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};

         constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
         constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};

@@ -82,7 +80,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         constexpr auto long_vector_access_lengths = SliceLengths::Modify(
             vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);

-        ford<decltype(long_vector_access_lengths), DimAccessOrder>{}([&](
+        ford<decltype(long_vector_access_lengths), SrcDstDimAccessOrder>{}([&](
             auto long_vector_access_id) {

             // data id w.r.t slicing-window

@@ -173,7 +171,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
     __device__ void Run_optimized_src_address_calculation(const SrcData* p_src,
                                                           DstData* p_dst) const
     {
-        constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
+        constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};

         constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
         constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};

@@ -187,10 +185,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         constexpr auto src_linear_dim_mask    = SrcDesc::GetLinearDimensionMask();
         constexpr auto src_nonlinear_dim_mask = SrcDesc::GetNonLinearDimensionMask();

-        static_assert(src_linear_dim_mask.At(VectorReadWriteDim) ||
-                          long_vector_size == SrcDataPerRead,
-                      "Warning! VectorReadWriteDim is not SrcDesc's linear dimension, performance "
-                      "would drop");
+        static_assert(src_linear_dim_mask.At(SrcDstVectorReadWriteDim) ||
+                          long_vector_size == SrcDataPerRead,
+                      "Warning! SrcDstVectorReadWriteDim is not SrcDesc's linear dimension, "
+                      "performance would drop");

         // separate steps into linear and non-linear components, according to src tensor
         constexpr auto linear_long_vector_access_lengths =

@@ -230,13 +228,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 p_src_long_vector[i] = 0;
             }

-            // Loop over VectorReadWriteDim, and load data from src to the
+            // Loop over SrcDstVectorReadWriteDim, and load data from src to the
             // long-vector buffer.
-            // If VectorReadWriteDim is src's linear dimension, then src's
+            // If SrcDstVectorReadWriteDim is src's linear dimension, then src's
             // offset-diff due to this looping is known at compile-time. If
-            // VectorReadWriteDim is src's nonlinear dimension, then src's
+            // SrcDstVectorReadWriteDim is src's nonlinear dimension, then src's
             // offset-diff due to this looping is only known at run-time. For best
-            // performance, VectorReadWriteDim should be src's linear dimension
+            // performance, SrcDstVectorReadWriteDim should be src's linear dimension
             for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
             {
                 auto scalar_id = make_zero_array<index_t, nDim>();

@@ -321,7 +319,7 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
     __device__ void Run_optimized_dst_address_calculation(const SrcData* p_src,
                                                           DstData* p_dst) const
     {
-        constexpr auto vector_access_dim = Number<VectorReadWriteDim>{};
+        constexpr auto vector_access_dim = Number<SrcDstVectorReadWriteDim>{};

         constexpr auto src_data_per_access = Number<SrcDataPerRead>{};
         constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};

@@ -335,10 +333,10 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
         constexpr auto dst_linear_dim_mask    = DstDesc::GetLinearDimensionMask();
         constexpr auto dst_nonlinear_dim_mask = DstDesc::GetNonLinearDimensionMask();

-        static_assert(dst_linear_dim_mask.At(VectorReadWriteDim) ||
-                          long_vector_size == DstDataPerWrite,
-                      "Warning! VectorReadWriteDim is not DstDesc's linear dimension, performance "
-                      "would drop");
+        static_assert(dst_linear_dim_mask.At(SrcDstVectorReadWriteDim) ||
+                          long_vector_size == DstDataPerWrite,
+                      "Warning! SrcDstVectorReadWriteDim is not DstDesc's linear dimension, "
+                      "performance would drop");

         // separate steps into linear and non-linear components, according to dst tensor
         constexpr auto linear_long_vector_access_lengths =

@@ -378,13 +376,13 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
                 p_src_long_vector[i] = 0;
             }

-            // Loop over VectorReadWriteDim, and load data from src to the
+            // Loop over SrcDstVectorReadWriteDim, and load data from src to the
             // long-vector buffer.
-            // If VectorReadWriteDim is dst's linear dimension, then dst's
+            // If SrcDstVectorReadWriteDim is dst's linear dimension, then dst's
             // offset-diff due to this looping is known at compile-time. If
-            // VectorReadWriteDim is dst's nonlinear dimension, then dst's
+            // SrcDstVectorReadWriteDim is dst's nonlinear dimension, then dst's
             // offset-diff due to this looping is only known at run-time. For best
-            // performance, VectorReadWriteDim should be dst's linear dimension
+            // performance, SrcDstVectorReadWriteDim should be dst's linear dimension
             for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
             {
                 auto scalar_id = make_zero_array<index_t, nDim>();
```
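Editorial note: the "cannot evenly divide" check above exists because each copy step moves one "long vector" of lcm(SrcDataPerRead, DstDataPerWrite) elements along SrcDstVectorReadWriteDim, so the slice length in that dimension must be a multiple of it. A worked example of that arithmetic (illustrative values):

```cpp
#include <numeric> // std::lcm (C++17)

int main()
{
    constexpr int SrcDataPerRead  = 4; // e.g. 4-wide global loads
    constexpr int DstDataPerWrite = 2; // e.g. 2-wide LDS stores
    constexpr int slice_len       = 8; // SliceLengths[SrcDstVectorReadWriteDim]

    constexpr int long_vector_size = std::lcm(SrcDataPerRead, DstDataPerWrite); // 4
    static_assert(slice_len % long_vector_size == 0, "wrong! cannot evenly divide");

    // per long vector: one read of 4 elements, then two writes of 2 elements
    constexpr int reads_per_long_vector  = long_vector_size / SrcDataPerRead;  // 1
    constexpr int writes_per_long_vector = long_vector_size / DstDataPerWrite; // 2
    static_assert(reads_per_long_vector == 1 && writes_per_long_vector == 2, "");
    return 0;
}
```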
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp

```diff
@@ -83,13 +83,11 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #endif

-    // TODO: this algo supports any stride and dilation. But for now, let's fix them to be 1 for
-    // simplicity
     constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
     constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);

-    constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; // may be wrong
-    constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; // may be wrong
+    constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
+    constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;

     constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
     constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
```
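Editorial note: `math::hcf` here is the highest common factor (gcd), so stride 2 with dilation 1 gives Ytilda = 2, and a 3-tap filter gives Ydot = ceil(3/2) = 2. A small check of that arithmetic using standard-library equivalents (illustrative, not the repo's `math` helpers):

```cpp
#include <numeric> // std::gcd

int main()
{
    constexpr int ConvStrideH = 2, ConvDilationH = 1, Y = 3;

    constexpr int hcf_stride_dilation_h = std::gcd(ConvStrideH, ConvDilationH); // 1
    constexpr int Ytilda = ConvStrideH / hcf_stride_dilation_h;                 // 2
    constexpr int Ydot   = (Y + Ytilda - 1) / Ytilda;                           // ceil(3/2) = 2

    static_assert(Ytilda == 2 && Ydot == 2, "backward-data tiling arithmetic");
    return 0;
}
```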
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp

```diff
@@ -3,7 +3,7 @@
 #include "device.hpp"
 #include "tensor.hpp"
 #include "gridwise_convolution_kernel_wrapper.hpp"
-#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
+#include "gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"

 template <class T,
           class InDesc,

@@ -32,9 +32,12 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     constexpr auto I2 = Number<2>{};
     constexpr auto I3 = Number<3>{};

-    constexpr auto in_nchw_desc  = InDesc{};
-    constexpr auto wei_kcyx_desc = WeiDesc{};
-    constexpr auto out_nkhw_desc = OutDesc{};
+    constexpr auto in_nchw_desc =
+        make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
+    constexpr auto wei_kcyx_desc =
+        make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
+    constexpr auto out_nkhw_desc =
+        make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());

     constexpr index_t N = out_nkhw_desc.GetLength(I0);
     constexpr index_t K = out_nkhw_desc.GetLength(I1);

@@ -51,198 +54,173 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
     out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

 #if 1
-    // BlockSize = 256, EPerBlock = 8
+    // BlockSize = 256, GemmKPerBlock = 8
     constexpr index_t BlockSize = 256;

-    constexpr index_t BPerBlock = 128;
-    constexpr index_t KPerBlock = 128;
-    constexpr index_t EPerBlock = 8;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 8;

     constexpr index_t GemmMPerThreadSubC = 4;
     constexpr index_t GemmNPerThreadSubC = 4;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
     constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA   = 4;
-    constexpr index_t GemmDataPerReadB   = 4;
+    constexpr index_t ThreadGemmDataPerReadM = 4;
+    constexpr index_t ThreadGemmDataPerReadN = 4;

-    using InBlockCopySubLengths_E_B            = Sequence<4, 1>;
-    using InBlockCopyClusterLengths_E_B        = Sequence<2, 128>;
-    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopySrcAccessOrder            = Sequence<0, 1>; // [E, B]
-    using InBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, B]
-
-    constexpr index_t InBlockCopyDataPerAccess_B = 1;
-
-    using WeiBlockCopySubLengths_E_K            = Sequence<4, 1>;
-    using WeiBlockCopyClusterLengths_E_K        = Sequence<2, 128>;
-    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
-
-    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-
-    constexpr index_t OutThreadCopyDataPerAccess_B = 1;
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK  = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<4, 1>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 1;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
 #elif 0
-    // BlockSize = 256, EPerBlock = 8
+    // BlockSize = 256, GemmKPerBlock = 8
     // 1x1 filter, 8x8 image
     constexpr index_t BlockSize = 256;

-    constexpr index_t BPerBlock = 128;
-    constexpr index_t KPerBlock = 128;
-    constexpr index_t EPerBlock = 8;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 8;

     constexpr index_t GemmMPerThreadSubC = 4;
     constexpr index_t GemmNPerThreadSubC = 4;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
     constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA   = 4;
-    constexpr index_t GemmDataPerReadB   = 4;
+    constexpr index_t ThreadGemmDataPerReadM = 4;
+    constexpr index_t ThreadGemmDataPerReadN = 4;

-    using InBlockCopySubLengths_E_B            = Sequence<1, 4>;
-    using InBlockCopyClusterLengths_E_B        = Sequence<8, 32>;
-    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopySrcAccessOrder            = Sequence<0, 1>; // [E, B]
-    using InBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, B]
-
-    constexpr index_t InBlockCopyDataPerAccess_B = 4;
-
-    using WeiBlockCopySubLengths_E_K            = Sequence<4, 1>;
-    using WeiBlockCopyClusterLengths_E_K        = Sequence<2, 128>;
-    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
-
-    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-
-    constexpr index_t OutThreadCopyDataPerAccess_B = 4;
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK  = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<1, 4>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 4;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #elif 0
-    // BlockSize = 256, EPerBlock = 16
+    // BlockSize = 256, GemmKPerBlock = 16
     // 1x1 filter, 8x8 image
     constexpr index_t BlockSize = 256;

-    constexpr index_t BPerBlock = 128;
-    constexpr index_t KPerBlock = 128;
-    constexpr index_t EPerBlock = 16;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 16;

     constexpr index_t GemmMPerThreadSubC = 4;
     constexpr index_t GemmNPerThreadSubC = 4;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
     constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA   = 4;
-    constexpr index_t GemmDataPerReadB   = 4;
+    constexpr index_t ThreadGemmDataPerReadM = 4;
+    constexpr index_t ThreadGemmDataPerReadN = 4;

-    using InBlockCopySubLengths_E_B            = Sequence<2, 4>;
-    using InBlockCopyClusterLengths_E_B        = Sequence<8, 32>;
-    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopySrcAccessOrder            = Sequence<0, 1>; // [E, B]
-    using InBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, B]
-
-    constexpr index_t InBlockCopyDataPerAccess_B = 4;
-
-    using WeiBlockCopySubLengths_E_K            = Sequence<4, 2>;
-    using WeiBlockCopyClusterLengths_E_K        = Sequence<4, 64>;
-    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
-
-    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-
-    constexpr index_t OutThreadCopyDataPerAccess_B = 4;
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<4, 2>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 64>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK  = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<2, 4>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 4;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
 #elif 1
     // 1x1 filter, 14x14 image
     constexpr index_t BlockSize = 256;

-    constexpr index_t BPerBlock = 128;
-    constexpr index_t KPerBlock = 128;
-    constexpr index_t EPerBlock = 8;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 8;

     constexpr index_t GemmMPerThreadSubC = 4;
     constexpr index_t GemmNPerThreadSubC = 4;
     constexpr index_t GemmMLevel0Cluster = 4;
     constexpr index_t GemmNLevel0Cluster = 4;
     constexpr index_t GemmMLevel1Cluster = 4;
     constexpr index_t GemmNLevel1Cluster = 4;
     constexpr index_t GemmKPerThreadLoop = 1;
-    constexpr index_t GemmDataPerReadA   = 4;
-    constexpr index_t GemmDataPerReadB   = 4;
+    constexpr index_t ThreadGemmDataPerReadM = 4;
+    constexpr index_t ThreadGemmDataPerReadN = 4;

-    using InBlockCopySubLengths_E_B            = Sequence<2, 2>;
-    using InBlockCopyClusterLengths_E_B        = Sequence<4, 64>;
-    using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1>; // [E, B]
-    using InBlockCopySrcAccessOrder            = Sequence<0, 1>; // [E, B]
-    using InBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, B]
-
-    constexpr index_t InBlockCopyDataPerAccess_B = 2;
-
-    using WeiBlockCopySubLengths_E_K            = Sequence<4, 1>;
-    using WeiBlockCopyClusterLengths_E_K        = Sequence<2, 128>;
-    using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopySrcAccessOrder            = Sequence<1, 0>; // [K, E]
-    using WeiBlockCopyDstAccessOrder            = Sequence<0, 1>; // [E, K]
-
-    constexpr index_t WeiBlockCopySrcDataPerRead_E  = 4;
-    constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
-
-    constexpr index_t OutThreadCopyDataPerAccess_B = 2;
+    using GemmABlockCopyThreadSliceLengths_GemmK_GemmM   = Sequence<4, 1>;
+    using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<2, 128>;
+
+    constexpr index_t GemmABlockCopySrcDataPerRead_GemmK  = 4;
+    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
+
+    using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN   = Sequence<2, 2>;
+    using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<4, 64>;
+
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN  = 2;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 2;
+
+    constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 2;
 #endif

     constexpr index_t B = N * Ho * Wo;

     constexpr index_t GridSize =
-        ((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
+        ((B + GemmNPerBlock - 1) / GemmNPerBlock) * ((K + GemmMPerBlock - 1) / GemmMPerBlock);

     printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);

     constexpr auto gridwise_conv =
-        GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer<
-            GridSize,
-            BlockSize,
-            T,
-            decltype(in_nchw_desc),
-            decltype(wei_kcyx_desc),
-            decltype(out_nkhw_desc),
-            ConvStrides,
-            ConvDilations,
-            LeftPads,
-            RightPads,
-            BPerBlock,
-            KPerBlock,
-            EPerBlock,
-            GemmMPerThreadSubC,
-            GemmNPerThreadSubC,
-            GemmMLevel0Cluster,
-            GemmNLevel0Cluster,
-            GemmMLevel1Cluster,
-            GemmNLevel1Cluster,
-            GemmKPerThreadLoop,
-            GemmDataPerReadA,
-            GemmDataPerReadB,
-            InBlockCopySubLengths_E_B,
-            InBlockCopyClusterLengths_E_B,
-            InBlockCopyThreadClusterArrangeOrder,
-            InBlockCopySrcAccessOrder,
-            InBlockCopyDstAccessOrder,
-            InBlockCopyDataPerAccess_B,
-            WeiBlockCopySubLengths_E_K,
-            WeiBlockCopyClusterLengths_E_K,
-            WeiBlockCopyThreadClusterArrangeOrder,
-            WeiBlockCopySrcAccessOrder,
-            WeiBlockCopyDstAccessOrder,
-            WeiBlockCopySrcDataPerRead_E,
-            WeiBlockCopyDstDataPerWrite_K,
-            OutThreadCopyDataPerAccess_B>{};
+        GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw<
+            GridSize,
+            BlockSize,
+            T,
+            T,
+            decltype(in_nchw_desc),
+            decltype(wei_kcyx_desc),
+            decltype(out_nkhw_desc),
+            ConvStrides,
+            ConvDilations,
+            LeftPads,
+            RightPads,
+            GemmMPerBlock,
+            GemmNPerBlock,
+            GemmKPerBlock,
+            GemmMPerThreadSubC,
+            GemmNPerThreadSubC,
+            GemmMLevel0Cluster,
+            GemmNLevel0Cluster,
+            GemmMLevel1Cluster,
+            GemmNLevel1Cluster,
+            GemmKPerThreadLoop,
+            ThreadGemmDataPerReadM,
+            ThreadGemmDataPerReadN,
+            GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
+            GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
+            GemmABlockCopySrcDataPerRead_GemmK,
+            GemmABlockCopyDstDataPerWrite_GemmM,
+            GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
+            GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
+            GemmBBlockCopySrcDataPerRead_GemmN,
+            GemmBBlockCopyDstDataPerWrite_GemmN,
+            GemmCThreadCopyDstDataPerWrite_GemmN1>{};

     for(index_t i = 0; i < nrepeat; ++i)
     {
```
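Editorial note: the GridSize formula above launches one workgroup per GemmMPerBlock x GemmNPerBlock tile of the implicit output GEMM. A worked example with a hypothetical problem size (the dimensions below are assumptions, not taken from this driver):

```cpp
int main()
{
    constexpr int N = 8, K = 128, Ho = 16, Wo = 16; // assumed problem size
    constexpr int GemmMPerBlock = 128, GemmNPerBlock = 128;

    constexpr int B = N * Ho * Wo; // GemmN dimension = 2048
    constexpr int GridSize =
        ((B + GemmNPerBlock - 1) / GemmNPerBlock) * ((K + GemmMPerBlock - 1) / GemmMPerBlock);

    static_assert(B == 2048 && GridSize == 16, "16 x 1 tiles -> 16 workgroups");
    return 0;
}
```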
driver/src/conv_bwd_data_driver.cpp

```diff
@@ -21,7 +21,7 @@ int main(int argc, char* argv[])
 {
     using namespace ck;

-#if 0
+#if 1
     constexpr index_t N  = 8;
     constexpr index_t C  = 128;
     constexpr index_t HI = 16;
```
driver/src/conv_driver.cpp

```diff
@@ -43,7 +43,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 3x3, 34x34
     constexpr index_t N = 64;
     constexpr index_t C = 256;

@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
     // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
     // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
     constexpr index_t N = 128;

@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<3, 0>;
     using RightPads = Sequence<3, 0>;
-#elif 0
+#elif 1
     // 1x7 filter, 0x3 pad, 17x17 input
     constexpr index_t N = 128;
     constexpr index_t C = 128;

@@ -403,7 +403,7 @@ int main(int argc, char* argv[])
                                                          ConvStrides{},
                                                          ConvDilations{},
                                                          nrepeat);
-#elif 1
+#elif 0
     device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
                                                          in_nchw,
                                                          wei_kcyx_desc,
```