Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d78fe365
Commit
d78fe365
authored
Nov 19, 2019
by
Chao Liu
Browse files
initial impl of bwd data
parent
3b3b9623
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
1037 additions
and
53 deletions
+1037
-53
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1_nchw_kcyx_nkhw_lds_double_buffer.hpp
...ata_implicit_gemm_v1_nchw_kcyx_nkhw_lds_double_buffer.hpp
+401
-0
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
+9
-19
composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp
...l/include/tensor_description/ConstantMatrixDescriptor.hpp
+1
-1
composable_kernel/include/tensor_description/tensor_coordinate.hpp
...e_kernel/include/tensor_description/tensor_coordinate.hpp
+2
-2
composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp
...clude/tensor_description/tensor_coordinate_deprecated.hpp
+2
-2
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
...l/include/tensor_description/tensor_descriptor_helper.hpp
+6
-6
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
...tensor_operation/threadwise_generic_tensor_slice_copy.hpp
+0
-9
composable_kernel/include/utility/in_memory_operation.amd.hpp.in
...ble_kernel/include/utility/in_memory_operation.amd.hpp.in
+8
-7
composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in
..._kernel/include/utility/in_memory_operation.nvidia.hpp.in
+8
-7
driver/CMakeLists.txt
driver/CMakeLists.txt
+4
-0
driver/include/device_convolution_bwd_data_implicit_gemm_v1_nchw_kcyx_nkhw.hpp
..._convolution_bwd_data_implicit_gemm_v1_nchw_kcyx_nkhw.hpp
+145
-0
driver/include/host_conv_bwd_data.hpp
driver/include/host_conv_bwd_data.hpp
+71
-0
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+379
-0
driver/src/conv_bwd_data_driver.cu
driver/src/conv_bwd_data_driver.cu
+1
-0
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v1_nchw_kcyx_nkhw_lds_double_buffer.hpp
0 → 100644
View file @
d78fe365
#ifndef CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_BACKWARD_DATA_IMPLICIT_GEMM_V1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"

namespace ck {

// Backward-data convolution (NCHW input, KCYX weight, NKHW output) as an
// implicit GEMM with LDS double buffering.
//
// GEMM view:  in[E, B] += transpose(wei[K, E]) * out[K, B]
// where E = C * Y * X and B = N * Ho * Wo. Each block computes an
// [EPerBlock, BPerBlock] tile, iterating over K in steps of KPerBlock.
// The per-thread result is scattered back to the input tensor with
// atomic adds, because overlapping filter windows map several (E, B)
// entries to the same input element.
template <index_t GridSize,
          index_t BlockSize,
          typename Float,
          typename AccDataType,
          typename InGlobalDesc,
          typename WeiGlobalDesc,
          typename OutGlobalDesc,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads,
          index_t EPerBlock,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          typename OutBlockCopySubLengths_K_B,
          typename OutBlockCopyClusterLengths_K_B,
          index_t OutBlockCopyDataPerAccess_B,
          typename WeiBlockCopySubLengths_K_E,
          typename WeiBlockCopyClusterLengths_K_E,
          index_t WeiBlockCopyDataPerAccess_E,
          index_t InThreadCopyDataPerAccess_B>
struct GridwiseConvolutionBackwardDataImplicitGemm_v1_nchw_kcyx_nkhw_lds_double_buffer
{
    // p_in_global  : destination (input-gradient), accumulated via atomic add
    // p_wei_global : filter tensor (read-only)
    // p_out_global : output-gradient tensor (read-only)
    __device__ void Run(Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        const Float* const __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto in_n_c_hi_wi_global_desc  = InGlobalDesc{};
        constexpr auto wei_k_c_y_x_global_desc   = WeiGlobalDesc{};
        constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};

        constexpr index_t N  = in_n_c_hi_wi_global_desc.GetLengths()[0];
        constexpr index_t C  = in_n_c_hi_wi_global_desc.GetLengths()[1];
        constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLengths()[2];
        constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLengths()[3];

        constexpr index_t K  = out_n_k_ho_wo_global_desc.GetLengths()[1];
        constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLengths()[2];
        constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLengths()[3];

        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

        constexpr index_t ConvStrideH = ConvStrides{}[0];
        constexpr index_t ConvStrideW = ConvStrides{}[1];

        constexpr index_t ConvDilationH = ConvDilations{}[0];
        constexpr index_t ConvDilationW = ConvDilations{}[1];

        // GEMM dimensions: E is the "M" axis, B is the "N" axis
        constexpr index_t E = C * Y * X;
        constexpr index_t B = N * Ho * Wo;

        // sanity-check for vectorized memory load
        static_assert((Wo == 1 || (ConvStrideW == 1 || InThreadCopyDataPerAccess_B == 1)) &&
                          (X == 1 || ConvDilationW % InThreadCopyDataPerAccess_B == 0),
                      "wrong! aligment requirement for vectorized global load of input tensor will "
                      "be violated");

        // lds max alignment
        constexpr index_t max_lds_align = math::lcm(WeiBlockCopyDataPerAccess_E,
                                                    OutBlockCopyDataPerAccess_B,
                                                    GemmDataPerReadA,
                                                    GemmDataPerReadB);

        // divide block work by [K, B]
        static_assert(E % EPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t EBlockWork = E / EPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr auto block_work_desc =
            make_cluster_descriptor(Sequence<EBlockWork, BBlockWork>{});

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t e_block_data_on_global = block_work_id[0] * EPerBlock;
        const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;

        // output tensor
        //   global tensor in global memory
        constexpr auto out_n_k_howo_global_desc =
            unfold_tensor_descriptor(out_n_k_ho_wo_global_desc, I2, I3);

        //   global tensor in global memory, src of blockwise copy
        constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
            out_n_k_howo_global_desc,
            make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho * Wo>>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        //   block tensor in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto out_k_b_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, BPerBlock>{}, Number<max_lds_align>{});

        //   output tensor blockwise copy (global -> LDS)
        auto blockwise_out_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(out_k_b_global_desc),
                                               decltype(out_k_b_block_desc),
                                               decltype(out_k_b_block_desc.GetLengths()),
                                               OutBlockCopySubLengths_K_B,
                                               OutBlockCopyClusterLengths_K_B,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               OutBlockCopyDataPerAccess_B,
                                               OutBlockCopyDataPerAccess_B,
                                               AddressSpace::global,
                                               AddressSpace::vgpr,
                                               AddressSpace::lds,
                                               InMemoryDataOperation::none>(
                {0, b_block_data_on_global}, {0, 0});

        // weight tensor
        //   global tensor in global memory, src of blockwise copy
        constexpr auto wei_k_e_global_desc =
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3);

        //   block tensor in LDS memory, dst of blockwise copy
        //   be careful of LDS alignment
        constexpr auto wei_k_e_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, EPerBlock>{}, Number<max_lds_align>{});

        //   weight tensor blockwise copy (global -> LDS)
        auto blockwise_wei_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(wei_k_e_global_desc),
                                               decltype(wei_k_e_block_desc),
                                               decltype(wei_k_e_block_desc.GetLengths()),
                                               WeiBlockCopySubLengths_K_E,
                                               WeiBlockCopyClusterLengths_K_E,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               WeiBlockCopyDataPerAccess_E,
                                               WeiBlockCopyDataPerAccess_E,
                                               AddressSpace::global,
                                               AddressSpace::vgpr,
                                               AddressSpace::lds,
                                               InMemoryDataOperation::none>(
                {0, e_block_data_on_global}, {0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[KPerBlock, EPerBlock] is in LDS
        //   b_mtx[KPerBlock, BPerBlock] is in LDS
        //   c_mtx[EPerBlock, BPerBlock] is distributed among threads, and saved in register
        constexpr auto a_k_e_block_mtx_desc = make_ConstantMatrixDescriptor(wei_k_e_block_desc);
        constexpr auto b_k_b_block_mtx_desc = make_ConstantMatrixDescriptor(out_k_b_block_desc);

        // sanity check
        static_assert(
            EPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) == 0 &&
                BPerBlock % (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) == 0,
            "wrong!");

        constexpr index_t GemmMRepeat =
            EPerBlock / (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster);

        constexpr index_t GemmNRepeat =
            BPerBlock / (GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster);

        // c_thread_mtx definition: this is a mess
        // TODO:: more elegent way of defining c_thread_mtx
        constexpr auto c_e0e1_b0b1_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
            Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});

        const auto blockwise_gemm =
            BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
                BlockSize,
                decltype(a_k_e_block_mtx_desc),
                decltype(b_k_b_block_mtx_desc),
                decltype(c_e0e1_b0b1_thread_mtx_desc),
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
                GemmNLevel0Cluster,
                GemmMLevel1Cluster,
                GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                GemmDataPerReadA,
                GemmDataPerReadB>{};

        // LDS allocation for output and weight: be careful of alignment
        constexpr index_t out_block_space =
            math::integer_least_multiple(out_k_b_block_desc.GetElementSpace(), max_lds_align);

        constexpr index_t wei_block_space =
            math::integer_least_multiple(wei_k_e_block_desc.GetElementSpace(), max_lds_align);

        // two buffers each for double buffering
        __shared__ Float p_out_block_double[2 * out_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];

        // register allocation for the input-gradient accumulator
        AccDataType p_in_thread[c_e0e1_b0b1_thread_mtx_desc.GetElementSpace()];

        // zero out threadwise accumulator
        threadwise_matrix_set_zero(c_e0e1_b0b1_thread_mtx_desc, p_in_thread);

        // LDS double buffer: preload data into LDS
        {
            blockwise_out_copy.Run(p_out_global, p_out_block_double);
            blockwise_wei_copy.Run(p_wei_global, p_wei_block_double);
        }

        // LDS double buffer: main body
        // Each outer iteration consumes 2 * KPerBlock of the reduction axis,
        // ping-ponging between the two halves of each LDS buffer.
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                Float* p_out_block_now =
                    even_loop ? p_out_block_double : p_out_block_double + out_block_space;
                Float* p_wei_block_now =
                    even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                Float* p_out_block_next =
                    even_loop ? p_out_block_double + out_block_space : p_out_block_double;
                Float* p_wei_block_next =
                    even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(p_wei_block_now, p_out_block_now, p_in_thread);

                // LDS double buffer: store next data to LDS
                blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer, p_out_block_next);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer, p_wei_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iteration left
            {
                Float p_out_thread_buffer[blockwise_out_copy.GetThreadBufferSize()];
                Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

                blockwise_out_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);
                blockwise_wei_copy.MoveSrcSliceWindow(Sequence<KPerBlock, 0>{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                blockwise_out_copy.RunLoadThreadBuffer(p_out_global, p_out_thread_buffer);
                blockwise_wei_copy.RunLoadThreadBuffer(p_wei_global, p_wei_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);

                // LDS double buffer: store last data to LDS
                blockwise_out_copy.RunStoreThreadBuffer(p_out_thread_buffer,
                                                        p_out_block_double + out_block_space);
                blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
                                                        p_wei_block_double + wei_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double + wei_block_space,
                                   p_out_block_double + out_block_space,
                                   p_in_thread);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                blockwise_gemm.Run(p_wei_block_double, p_out_block_double, p_in_thread);
            }
        }

        // input: register to global memory, atomic add
        {
            constexpr index_t E1 =
                GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
            constexpr index_t E0 = E / E1;

            constexpr index_t B1 =
                GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
            constexpr index_t B0 = B / B1;

            // define input tensor descriptor for threadwise copy
            //   thread input tensor, src of threadwise copy
            constexpr auto in_e0_e1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
                Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});

            //   global input tensor, dst of threadwise copy
            //   pad spatial dims so the filter window math never leaves the descriptor
            constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
                in_n_c_hi_wi_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

            constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
                in_n_c_hip_wip_global_desc,
                make_tuple(PassThrough<N>{},
                           PassThrough<C>{},
                           Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
                           Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

            constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
                in_n_c_y_ho_x_wo_global_desc,
                make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
                make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            constexpr auto in_e0_e1_b0_b1_global_desc = transform_tensor_descriptor(
                in_e_b_global_desc,
                make_tuple(UnMerge<Sequence<E0, E1>>{}, UnMerge<Sequence<B0, B1>>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

            // calculate origin of thread input tensor on global memory
            //   blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

            const index_t e_thread_data_on_global =
                e_block_data_on_global + c_thread_mtx_on_block.row;

            const index_t b_thread_data_on_global =
                b_block_data_on_global + c_thread_mtx_on_block.col;

            ThreadwiseGenericTensorSliceCopy_v4r2<
                decltype(in_e0_e1_b0_b1_thread_desc),
                decltype(in_e0_e1_b0_b1_global_desc),
                decltype(in_e0_e1_b0_b1_thread_desc.GetLengths()),
                Sequence<0, 1, 2, 3>,
                3,
                InThreadCopyDataPerAccess_B,
                InThreadCopyDataPerAccess_B,
                AddressSpace::vgpr,
                AddressSpace::global,
                InMemoryDataOperation::atomic_add>({0, 0, 0, 0},
                                                   {e_thread_data_on_global / E1,
                                                    e_thread_data_on_global % E1,
                                                    b_thread_data_on_global / B1,
                                                    b_thread_data_on_global % B1})
                .Run(p_in_thread, p_in_global);
        }
    }
};

} // namespace ck
#endif
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
d78fe365
...
...
@@ -107,16 +107,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
constexpr
auto
True
=
integral_constant
<
bool
,
true
>
{};
constexpr
auto
global_address_space
=
integral_constant
<
AddressSpace
,
AddressSpace
::
global
>
{};
constexpr
auto
lds_address_space
=
integral_constant
<
AddressSpace
,
AddressSpace
::
lds
>
{};
constexpr
auto
vgpr_address_space
=
integral_constant
<
AddressSpace
,
AddressSpace
::
vgpr
>
{};
constexpr
auto
no_inmem_op
=
integral_constant
<
InMemoryDataOperation
,
InMemoryDataOperation
::
none
>
{};
static_assert
(
ConvDirection
==
ConvolutionDirection
::
Forward
||
ConvDirection
==
ConvolutionDirection
::
BackwardWeight
,
"wrong! this kernel only support convolution forward and backward-weight"
);
...
...
@@ -135,17 +125,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_ho_wo_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I0
)
;
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I1
)
;
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I2
)
;
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
(
I3
)
;
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLength
s
()[
0
]
;
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLength
s
()[
1
]
;
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLength
s
()[
2
]
;
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLength
s
()[
3
]
;
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I1
)
;
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I2
)
;
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
(
I3
)
;
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLength
s
()[
1
]
;
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLength
s
()[
2
]
;
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLength
s
()[
3
]
;
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
(
I2
)
;
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
(
I3
)
;
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLength
s
()[
2
]
;
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLength
s
()[
3
]
;
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
...
...
composable_kernel/include/tensor_description/ConstantMatrixDescriptor.hpp
View file @
d78fe365
...
...
@@ -60,7 +60,7 @@ __host__ __device__ constexpr auto
template
<
typename
...
Ts
>
__host__
__device__
constexpr
auto
make_ConstantMatrixDescriptor
(
ConstantTensorDescriptor_deprecated
<
Ts
...
>
)
make_ConstantMatrixDescriptor
(
ConstantTensorDescriptor_deprecated
<
Ts
...
>
)
{
using
TDesc
=
ConstantTensorDescriptor_deprecated
<
Ts
...
>
;
static_assert
(
TDesc
::
GetNumOfDimension
()
==
2
,
"wrong"
);
...
...
composable_kernel/include/tensor_description/tensor_coordinate.hpp
View file @
d78fe365
...
...
@@ -228,7 +228,7 @@ struct TensorCoordinate
private:
template
<
typename
...
Ts
>
__host__
__device__
static
constexpr
auto
MakeDummyTensorCoordinate
(
NativeTensorDescriptor
<
Ts
...
>
)
MakeDummyTensorCoordinate
(
NativeTensorDescriptor
<
Ts
...
>
)
{
return
NativeTensorCoordinate
<
NativeTensorDescriptor
<
Ts
...
>>
(
make_zero_array
<
index_t
,
TensorDesc
::
GetNumOfDimension
()
>
());
...
...
@@ -236,7 +236,7 @@ struct TensorCoordinate
template
<
typename
...
Ts
>
__host__
__device__
static
constexpr
auto
MakeDummyTensorCoordinate
(
TransformedTensorDescriptor
<
Ts
...
>
)
MakeDummyTensorCoordinate
(
TransformedTensorDescriptor
<
Ts
...
>
)
{
return
TransformedTensorCoordinate
<
TransformedTensorDescriptor
<
Ts
...
>>
(
make_zero_array
<
index_t
,
TensorDesc
::
GetNumOfDimension
()
>
());
...
...
composable_kernel/include/tensor_description/tensor_coordinate_deprecated.hpp
View file @
d78fe365
...
...
@@ -327,14 +327,14 @@ struct TensorCoordinate_deprecated
private:
template
<
class
...
Ts
>
__host__
__device__
static
constexpr
auto
MakeDummyTensorCoordinate
(
ConstantTensorDescriptor_deprecated
<
Ts
...
>
)
MakeDummyTensorCoordinate
(
ConstantTensorDescriptor_deprecated
<
Ts
...
>
)
{
return
NormalTensorCoordinate_deprecated
<
ConstantTensorDescriptor_deprecated
<
Ts
...
>>
();
}
template
<
class
...
Ts
>
__host__
__device__
static
constexpr
auto
MakeDummyTensorCoordinate
(
ConstantMergedTensorDescriptor_deprecated
<
Ts
...
>
)
MakeDummyTensorCoordinate
(
ConstantMergedTensorDescriptor_deprecated
<
Ts
...
>
)
{
return
MergedTensorCoordinate_deprecated
<
ConstantMergedTensorDescriptor_deprecated
<
Ts
...
>>
();
...
...
composable_kernel/include/tensor_description/tensor_descriptor_helper.hpp
View file @
d78fe365
...
...
@@ -64,10 +64,10 @@ template <typename LowerTensorDescriptor,
index_t
...
LowerDimensionIds
,
index_t
...
UpperDimensionIds
>
__host__
__device__
constexpr
auto
reorder_transformed_tensor_descriptor_impl
(
LowerTensorDescriptor
,
Sequence
<
LowerLengths
...
>
,
Sequence
<
LowerDimensionIds
...
>
,
Sequence
<
UpperDimensionIds
...
>
)
reorder_transformed_tensor_descriptor_impl
(
LowerTensorDescriptor
,
Sequence
<
LowerLengths
...
>
,
Sequence
<
LowerDimensionIds
...
>
,
Sequence
<
UpperDimensionIds
...
>
)
{
return
TransformedTensorDescriptor
<
LowerTensorDescriptor
,
Tuple
<
PassThrough
<
LowerLengths
>
...
>
,
...
...
@@ -78,7 +78,7 @@ __host__ __device__ constexpr auto
// reorder a NativeTensorDescriptor
template
<
typename
...
Ts
,
typename
MapLower2Upper
>
__host__
__device__
constexpr
auto
reorder_tensor_descriptor_given_lower2upper
(
NativeTensorDescriptor
<
Ts
...
>
,
MapLower2Upper
)
reorder_tensor_descriptor_given_lower2upper
(
NativeTensorDescriptor
<
Ts
...
>
,
MapLower2Upper
)
{
static_assert
(
is_valid_sequence_map
<
MapLower2Upper
>
{},
"wrong! MapLower2Upper is not a valid map"
);
...
...
@@ -96,7 +96,7 @@ __host__ __device__ constexpr auto
// reorder a TransformedTensorDescriptor
template
<
typename
...
Ts
,
typename
MapLower2Upper
>
__host__
__device__
constexpr
auto
reorder_tensor_descriptor_given_lower2upper
(
TransformedTensorDescriptor
<
Ts
...
>
,
MapLower2Upper
)
reorder_tensor_descriptor_given_lower2upper
(
TransformedTensorDescriptor
<
Ts
...
>
,
MapLower2Upper
)
{
static_assert
(
is_valid_sequence_map
<
MapLower2Upper
>
{},
"wrong! MapLower2Upper is not a valid map"
);
...
...
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp
View file @
d78fe365
...
...
@@ -72,9 +72,6 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
template
<
typename
SrcData
,
typename
DstData
>
__device__
void
Run
(
const
SrcData
*
p_src
,
DstData
*
p_dst
)
const
{
using
src_vector_t
=
typename
vector_type
<
SrcData
,
SrcDataPerAccess
>::
MemoryType
;
using
dst_vector_t
=
typename
vector_type
<
DstData
,
DstDataPerAccess
>::
MemoryType
;
constexpr
auto
vector_access_dim
=
Number
<
VectorAccessDim
>
{};
constexpr
auto
src_data_per_access
=
Number
<
SrcDataPerAccess
>
{};
...
...
@@ -176,9 +173,6 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
__device__
void
Run_optimized_src_address_calculation
(
const
SrcData
*
p_src
,
DstData
*
p_dst
)
const
{
using
src_vector_t
=
typename
vector_type
<
SrcData
,
SrcDataPerAccess
>::
MemoryType
;
using
dst_vector_t
=
typename
vector_type
<
DstData
,
DstDataPerAccess
>::
MemoryType
;
constexpr
auto
vector_access_dim
=
Number
<
VectorAccessDim
>
{};
constexpr
auto
src_data_per_access
=
Number
<
SrcDataPerAccess
>
{};
...
...
@@ -327,9 +321,6 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
__device__
void
Run_optimized_dst_address_calculation
(
const
SrcData
*
p_src
,
DstData
*
p_dst
)
const
{
using
src_vector_t
=
typename
vector_type
<
SrcData
,
SrcDataPerAccess
>::
MemoryType
;
using
dst_vector_t
=
typename
vector_type
<
DstData
,
DstDataPerAccess
>::
MemoryType
;
constexpr
auto
vector_access_dim
=
Number
<
VectorAccessDim
>
{};
constexpr
auto
src_data_per_access
=
Number
<
SrcDataPerAccess
>
{};
...
...
composable_kernel/include/utility/in_memory_operation.amd.hpp.in
View file @
d78fe365
...
...
@@ -50,13 +50,14 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in
{
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
static_if<SrcAddressSpace == AddressSpace::vgpr &&
DstAddressSpace == AddressSpace::global>{}([&](auto) {
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
}).Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");
});
static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
[&](auto) {
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
})
.Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");
});
}
template <typename T,
...
...
composable_kernel/include/utility/in_memory_operation.nvidia.hpp.in
View file @
d78fe365
...
...
@@ -23,13 +23,14 @@ __device__ void atomic_add_data(const T* p_src, index_t src_offset, T* p_dst, in
{
using vector_t = typename vector_type<T, DataPerAccess>::MemoryType;
static_if<SrcAddressSpace == AddressSpace::vgpr &&
DstAddressSpace == AddressSpace::global>{}([&](auto) {
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
}).Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");
});
static_if<SrcAddressSpace == AddressSpace::vgpr && DstAddressSpace == AddressSpace::global>{}(
[&](auto) {
atomicAdd(reinterpret_cast<vector_t*>(&p_dst[dst_offset]),
*reinterpret_cast<const vector_t*>(&p_src[src_offset]));
})
.Else([&](auto fwd) {
static_assert(fwd(false), "atomic_add doesn't support this memory space");
});
}
template <typename T,
...
...
driver/CMakeLists.txt
View file @
d78fe365
...
...
@@ -17,12 +17,16 @@ install(TARGETS host LIBRARY DESTINATION lib)
if
(
DEVICE_BACKEND STREQUAL
"AMD"
)
set
(
CONV_SOURCE src/conv_driver.cpp
)
set
(
COL2IM_SOURCE src/col2im_driver.cpp
)
set
(
CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cpp
)
elseif
(
DEVICE_BACKEND STREQUAL
"NVIDIA"
)
set
(
CONV_SOURCE src/conv_driver.cu
)
set
(
COL2IM_SOURCE src/col2im_driver.cu
)
set
(
CONV_BWD_DATA_SOURCE src/conv_bwd_data_driver.cu
)
endif
()
add_executable
(
conv
${
CONV_SOURCE
}
)
add_executable
(
col2im
${
COL2IM_SOURCE
}
)
add_executable
(
conv_bwd_data
${
CONV_BWD_DATA_SOURCE
}
)
target_link_libraries
(
conv PRIVATE host
)
target_link_libraries
(
col2im PRIVATE host
)
target_link_libraries
(
conv_bwd_data PRIVATE host
)
driver/include/device_convolution_bwd_data_implicit_gemm_v1_nchw_kcyx_nkhw.hpp
0 → 100644
View file @
d78fe365
#pragma once
#include <unistd.h>
#include "device.hpp"
#include "tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_backward_data_implicit_gemm_v1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
template
<
typename
T
,
typename
InDesc
,
typename
WeiDesc
,
typename
OutDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
LeftPads
,
typename
RightPads
>
void
device_convolution_bwd_data_implicit_gemm_v1_nchw_kcyx_nkhw
(
InDesc
in_nchw_desc
,
Tensor
<
T
>&
in_nchw
,
WeiDesc
wei_kcyx_desc
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
out_nkhw_desc
,
const
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvDilations
,
LeftPads
,
RightPads
,
std
::
size_t
nrepeat
)
{
using
namespace
ck
;
constexpr
index_t
N
=
out_nkhw_desc
.
GetLengths
()[
0
];
constexpr
index_t
K
=
out_nkhw_desc
.
GetLengths
()[
1
];
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLengths
()[
3
];
constexpr
index_t
C
=
wei_kcyx_desc
.
GetLengths
()[
1
];
constexpr
index_t
Y
=
wei_kcyx_desc
.
GetLengths
()[
2
];
constexpr
index_t
X
=
wei_kcyx_desc
.
GetLengths
()[
3
];
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
DeviceMem
out_nkhw_device_buf
(
data_sz
*
out_nkhw
.
mDesc
.
GetElementSpace
());
in_nchw_device_buf
.
ToDevice
(
in_nchw
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if 1
// BlockSize = 256, each thread hold 64 data
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
EPerBlock
=
128
;
constexpr
index_t
BPerBlock
=
128
;
constexpr
index_t
KPerBlock
=
8
;
constexpr
index_t
GemmMPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
using
OutBlockCopySubLengths_K_B
=
Sequence
<
4
,
1
>
;
using
OutBlockCopyClusterLengths_K_B
=
Sequence
<
2
,
128
>
;
constexpr
index_t
OutBlockCopyDataPerAccess_B
=
1
;
using
WeiBlockCopySubLengths_K_E
=
Sequence
<
1
,
4
>
;
using
WeiBlockCopyClusterLengths_K_E
=
Sequence
<
8
,
32
>
;
constexpr
index_t
WeiBlockCopyDataPerAccess_E
=
4
;
constexpr
index_t
InThreadCopyDataPerAccess_B
=
1
;
#endif
constexpr
index_t
E
=
C
*
Y
*
X
;
constexpr
index_t
B
=
(
N
*
Ho
*
Wo
);
constexpr
index_t
GridSize
=
((
E
+
EPerBlock
-
1
)
/
EPerBlock
)
*
((
B
+
BPerBlock
-
1
)
/
BPerBlock
);
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
constexpr
auto
gridwise_conv
=
GridwiseConvolutionBackwardDataImplicitGemm_v1_nchw_kcyx_nkhw_lds_double_buffer
<
GridSize
,
BlockSize
,
T
,
T
,
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
out_nkhw_desc
),
ConvStrides
,
ConvDilations
,
LeftPads
,
RightPads
,
EPerBlock
,
BPerBlock
,
KPerBlock
,
GemmMPerThreadSubC
,
GemmNPerThreadSubC
,
GemmMLevel0Cluster
,
GemmNLevel0Cluster
,
GemmMLevel1Cluster
,
GemmNLevel1Cluster
,
GemmKPerThreadLoop
,
GemmDataPerReadA
,
GemmDataPerReadB
,
OutBlockCopySubLengths_K_B
,
OutBlockCopyClusterLengths_K_B
,
OutBlockCopyDataPerAccess_B
,
WeiBlockCopySubLengths_K_E
,
WeiBlockCopyClusterLengths_K_E
,
WeiBlockCopyDataPerAccess_E
,
InThreadCopyDataPerAccess_B
>
{};
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
float
time
=
launch_kernel
(
run_gridwise_operation
<
decltype
(
gridwise_conv
),
T
*
const
__restrict__
,
const
T
*
const
__restrict__
,
const
T
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
gridwise_conv
,
const_cast
<
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
())),
const_cast
<
const
T
*
const
__restrict__
>
(
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
())));
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
time
,
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
);
usleep
(
std
::
min
(
time
*
1000
,
float
(
10000
)));
}
in_nchw_device_buf
.
FromDevice
(
in_nchw
.
mData
.
data
());
}
driver/include/host_conv_bwd_data.hpp
0 → 100644
View file @
d78fe365
#pragma once
#include "tensor.hpp"
// Reference (host-side, CPU) implementation of 2D convolution backward-data.
//
// Reconstructs the input-gradient tensor in_nchw (NCHW layout) from the
// output-gradient tensor out_nkhw (NKHW layout) and the filter wei_kcyx
// (KCYX layout). For each input pixel (n, c, hi, wi) it accumulates, over
// every filter tap (y, x) and every output channel k:
//
//     v += out_nkhw(n, k, ho, wo) * wei_kcyx(k, c, y, x)
//
// where (ho, wo) is the output coordinate that read this input pixel in the
// forward pass, i.e. the integer solution of
//
//     hi + LeftPads[0] - y * ConvDilations[0] == ho * ConvStrides[0]
//     wi + LeftPads[1] - x * ConvDilations[1] == wo * ConvStrides[1]
//
// A tap contributes only when that solution exists (remainder is zero) and
// lies inside the output tensor: 0 <= ho < HO and 0 <= wo < WO.
//
// BUG FIX vs. original: the original validated the candidate against the
// *input* extents (h_tmp < HI, w_tmp < WI) instead of the *output* extents.
// For strided cases (e.g. HI=35, stride 2, HO=17) that passes ho == 17 and
// reads out_nkhw out of bounds, and for heavily padded cases it can reject
// taps that are actually valid. The check is now ho < HO / wo < WO.
//
// Parameters:
//   in_nchw   [out] input gradient, overwritten element-wise
//   wei_kcyx  [in]  filter weights
//   out_nkhw  [in]  output gradient
//   ConvStrides/ConvDilations/LeftPads/RightPads: compile-time Sequence<>
//   tag parameters (values read via {}[i]); RightPads only affects shapes,
//   which are taken from the tensors themselves, so it is unused here.
template <typename TIn,
          typename TWei,
          typename TOut,
          typename ConvStrides,
          typename ConvDilations,
          typename LeftPads,
          typename RightPads>
void host_direct_convolution_bwd_data(Tensor<TIn>& in_nchw,
                                      const Tensor<TWei>& wei_kcyx,
                                      const Tensor<TOut>& out_nkhw,
                                      ConvStrides,
                                      ConvDilations,
                                      LeftPads,
                                      RightPads)
{
    using namespace ck;

    const int K = wei_kcyx.mDesc.GetLengths()[0];
    const int Y = wei_kcyx.mDesc.GetLengths()[2];
    const int X = wei_kcyx.mDesc.GetLengths()[3];

    const int HO = out_nkhw.mDesc.GetLengths()[2];
    const int WO = out_nkhw.mDesc.GetLengths()[3];

    // Per-pixel accumulation kernel, applied to every (n, c, hi, wi).
    auto f = [&](auto n, auto c, auto hi, auto wi) {
        // accumulate in double to limit rounding error regardless of TIn
        double v = 0;

        for(int y = 0; y < Y; ++y)
        {
            // candidate ho * stride; may be negative (pixel above the
            // receptive field of any output row for this tap)
            int h_tmp = hi + LeftPads{}[0] - y * ConvDilations{}[0];

            if(h_tmp >= 0 && h_tmp % ConvStrides{}[0] == 0)
            {
                int ho = h_tmp / ConvStrides{}[0];

                if(ho < HO) // bound against the OUTPUT extent (bug fix)
                {
                    for(int x = 0; x < X; ++x)
                    {
                        int w_tmp = wi + LeftPads{}[1] - x * ConvDilations{}[1];

                        if(w_tmp >= 0 && w_tmp % ConvStrides{}[1] == 0)
                        {
                            int wo = w_tmp / ConvStrides{}[1];

                            if(wo < WO) // bound against the OUTPUT extent (bug fix)
                            {
                                for(int k = 0; k < K; ++k)
                                {
                                    v += out_nkhw(n, k, ho, wo) * wei_kcyx(k, c, y, x);
                                }
                            }
                        }
                    }
                }
            }
        }

        in_nchw(n, c, hi, wi) = v;
    };

    // Run f over the whole NCHW index space, one work item per element,
    // distributed across all hardware threads.
    auto f_par = make_ParallelTensorFunctor(f,
                                            in_nchw.mDesc.GetLengths()[0],
                                            in_nchw.mDesc.GetLengths()[1],
                                            in_nchw.mDesc.GetLengths()[2],
                                            in_nchw.mDesc.GetLengths()[3]);

    f_par(std::thread::hardware_concurrency());
}
driver/src/conv_bwd_data_driver.cpp
0 → 100644
View file @
d78fe365
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include "config.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "print_array.hpp"
#include "print_sequence.hpp"
#include "device.hpp"
#include "tensor_generator.hpp"
#include "device_tensor.hpp"
#include "conv_common.hpp"
#include "host_conv_bwd_data.hpp"
#include "device_convolution_bwd_data_implicit_gemm_v1_nchw_kcyx_nkhw.hpp"
// Driver for the backward-data implicit-GEMM convolution kernel.
//
// Usage: <exe> <do_verification> <nrepeat>
//   arg1: nonzero -> also run the host reference implementation and compare
//   arg2: number of timed kernel launches
//
// Exactly one problem configuration below is enabled via the #if/#elif
// chain (the percentages in the comments are measured efficiencies on the
// named GPUs). The input gradient is computed on the device, optionally
// re-computed on the host, and the two results are compared.
int main(int argc, char* argv[])
{
    using namespace ck;

#if 0
    constexpr index_t N = 2;
    constexpr index_t C = 8;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K = 128;
    constexpr index_t Y = 4;
    constexpr index_t X = 4;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<1, 1>;
    using RightPads = Sequence<2, 2>;
#elif 0
    // 3x3, 34x34
    constexpr index_t N  = 64;
    constexpr index_t C  = 256;
    constexpr index_t HI = 34;
    constexpr index_t WI = 34;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 3;
    constexpr index_t X  = 3;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 8x8 image
    // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
    constexpr index_t N  = 64;
    constexpr index_t C  = 1536;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K  = 256;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 8x8 image
    // cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
    constexpr index_t N  = 128;
    constexpr index_t C  = 2048;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K  = 384;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 7x7 image
    // cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64%
    constexpr index_t N  = 128;
    constexpr index_t C  = 832;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K  = 384;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 8x8 image
    // cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65%
    constexpr index_t N  = 128;
    constexpr index_t C  = 1280;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K  = 384;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 14x14 image
    // cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50%
    constexpr index_t N  = 128;
    constexpr index_t C  = 512;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 8x8 image
    // cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61%
    constexpr index_t N  = 64;
    constexpr index_t C  = 1536;
    constexpr index_t HI = 8;
    constexpr index_t WI = 8;
    constexpr index_t K  = 384;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 28x28 image
    // cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69%
    constexpr index_t N  = 128;
    constexpr index_t C  = 256;
    constexpr index_t HI = 28;
    constexpr index_t WI = 28;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 7x7 image
    // cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62%
    constexpr index_t N  = 128;
    constexpr index_t C  = 832;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K  = 256;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 17x17 input
    // cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76%
    constexpr index_t N  = 128;
    constexpr index_t C  = 768;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 14x14 image
    // cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64%
    constexpr index_t N  = 128;
    constexpr index_t C  = 528;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 14x14 image
    // cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75%
    constexpr index_t N  = 128;
    constexpr index_t C  = 528;
    constexpr index_t HI = 14;
    constexpr index_t WI = 14;
    constexpr index_t K  = 256;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 1x1 filter, 7x7 image
    // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
    constexpr index_t N  = 128;
    constexpr index_t C  = 832;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
    // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
    constexpr index_t N  = 128;
    constexpr index_t C  = 288;
    constexpr index_t HI = 35;
    constexpr index_t WI = 35;
    constexpr index_t K  = 384;
    constexpr index_t Y  = 3;
    constexpr index_t X  = 3;

    using ConvStrides   = Sequence<2, 2>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 0>;
    using RightPads = Sequence<0, 0>;
#elif 0
    // 5x5 filter, 2x2 pad, 7x7 input
    constexpr index_t N  = 128;
    constexpr index_t C  = 48;
    constexpr index_t HI = 7;
    constexpr index_t WI = 7;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 5;
    constexpr index_t X  = 5;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<2, 2>;
    using RightPads = Sequence<2, 2>;
#elif 0
    // 7x1 filter, 3x0 pad, 17x17 input
    constexpr index_t N  = 128;
    constexpr index_t C  = 128;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 7;
    constexpr index_t X  = 1;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<3, 0>;
    using RightPads = Sequence<3, 0>;
#elif 1
    // 1x7 filter, 0x3 pad, 17x17 input
    constexpr index_t N  = 128;
    constexpr index_t C  = 128;
    constexpr index_t HI = 17;
    constexpr index_t WI = 17;
    constexpr index_t K  = 128;
    constexpr index_t Y  = 1;
    constexpr index_t X  = 7;

    using ConvStrides   = Sequence<1, 1>;
    using ConvDilations = Sequence<1, 1>;

    using LeftPads  = Sequence<0, 3>;
    using RightPads = Sequence<0, 3>;
#endif

    // Build compile-time tensor descriptors; the output shape is derived
    // from the input/filter shapes and the convolution parameters.
    constexpr auto in_nchw_desc  = make_native_tensor_descriptor_packed(Sequence<N, C, HI, WI>{});
    constexpr auto wei_kcyx_desc = make_native_tensor_descriptor_packed(Sequence<K, C, Y, X>{});
    constexpr auto out_nkhw_desc = get_convolution_output_default_4d_tensor_descriptor(
        in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, LeftPads{}, RightPads{});

    ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
    ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
    ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: ");

    // BUG FIX vs. original: "LeftPads" was printed twice (copy-paste
    // duplicate); print each parameter sequence exactly once.
    print_sequence("LeftPads", LeftPads{});
    print_sequence("RightPads", RightPads{});
    print_sequence("ConvStrides", ConvStrides{});
    print_sequence("ConvDilations", ConvDilations{});

    // Host-side tensors: device result, host reference, filter, output grad.
    Tensor<float> in_nchw_device(make_TensorDescriptor(in_nchw_desc));
    Tensor<float> in_nchw_host(make_TensorDescriptor(in_nchw_desc));
    Tensor<float> wei_kcyx(make_TensorDescriptor(wei_kcyx_desc));
    Tensor<float> out_nkhw(make_TensorDescriptor(out_nkhw_desc));

    std::size_t num_thread = std::thread::hardware_concurrency();

    if(argc != 3)
    {
        printf("arg1: do_verification, arg2: nrepeat\n");
        exit(1);
    }

    bool do_verification = atoi(argv[1]);
    std::size_t nrepeat  = atoi(argv[2]);

    if(do_verification)
    {
#if 0
        out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
        wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
#else
        // random integer values in [-5, 5] keep float accumulation exact
        out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
        wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
#endif
    }

#if 1
    // Timed device run: writes the input gradient into in_nchw_device.
    device_convolution_bwd_data_implicit_gemm_v1_nchw_kcyx_nkhw(in_nchw_desc,
                                                                in_nchw_device,
                                                                wei_kcyx_desc,
                                                                wei_kcyx,
                                                                out_nkhw_desc,
                                                                out_nkhw,
                                                                ConvStrides{},
                                                                ConvDilations{},
                                                                LeftPads{},
                                                                RightPads{},
                                                                nrepeat);
#endif

    if(do_verification)
    {
        // CPU reference result, then element-wise comparison with the device.
        host_direct_convolution_bwd_data(in_nchw_host,
                                         wei_kcyx,
                                         out_nkhw,
                                         ConvStrides{},
                                         ConvDilations{},
                                         LeftPads{},
                                         RightPads{});

        check_error(in_nchw_host, in_nchw_device);

#if 0
        LogRange(std::cout << "col_eb : ", col_eb.mData, ",") << std::endl;
        LogRange(std::cout << "img_nchw_host : ", img_nchw_host.mData, ",") << std::endl;
        LogRange(std::cout << "img_nchw_device : ", img_nchw_device.mData, ",") << std::endl;
#endif
    }
}
driver/src/conv_bwd_data_driver.cu
0 → 120000
View file @
d78fe365
conv_bwd_data_driver
.
cpp
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment