Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bd4c668e
Commit
bd4c668e
authored
Jul 02, 2020
by
Chao Liu
Browse files
experiment dummy static transform
parent
435f5f91
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
289 additions
and
53 deletions
+289
-53
composable_kernel/include/kernel_algorithm/dummy_static_transform.hpp
...ernel/include/kernel_algorithm/dummy_static_transform.hpp
+129
-0
composable_kernel/include/utility/amd_buffer_addressing.hpp
composable_kernel/include/utility/amd_buffer_addressing.hpp
+15
-0
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+2
-2
driver/include/device_dummy_transform.hpp
driver/include/device_dummy_transform.hpp
+97
-0
driver/src/conv_bwd_data_driver.cpp
driver/src/conv_bwd_data_driver.cpp
+13
-13
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+33
-38
No files found.
composable_kernel/include/kernel_algorithm/dummy_static_transform.hpp
0 → 100644
View file @
bd4c668e
#ifndef CK_DUMMY_STATIC_TRANSFORM_HPP
#define CK_DUMMY_STATIC_TRANSFORM_HPP
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
namespace
ck
{
// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template
<
index_t
GridSize
,
index_t
BlockSize
,
typename
Float
,
typename
InGlobalDesc
,
typename
WeiGlobalDesc
,
typename
OutGlobalDesc
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InRightPads
>
struct
DummyStaticTransform
{
__device__
void
Run
(
Float
*
const
__restrict__
p_in_global
,
Float
*
const
__restrict__
p_wei_global
,
Float
*
const
__restrict__
p_out_global
)
const
{
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_n_c_hi_wi_global_desc
=
InGlobalDesc
{};
constexpr
auto
wei_k_c_y_x_global_desc
=
WeiGlobalDesc
{};
constexpr
auto
out_n_k_ho_wo_global_desc
=
OutGlobalDesc
{};
constexpr
index_t
N
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
0
];
constexpr
index_t
C
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Hi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wi
=
in_n_c_hi_wi_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
K
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
1
];
constexpr
index_t
Ho
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wo
=
out_n_k_ho_wo_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
Y
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
X
=
wei_k_c_y_x_global_desc
.
GetLengths
()[
3
];
constexpr
index_t
ConvStrideH
=
ConvStrides
{}[
0
];
constexpr
index_t
ConvStrideW
=
ConvStrides
{}[
1
];
constexpr
index_t
ConvDilationH
=
ConvDilations
{}[
0
];
constexpr
index_t
ConvDilationW
=
ConvDilations
{}[
1
];
// weight tensor
constexpr
auto
wei_gemmk_gemmm_global_desc
=
reorder_tensor_descriptor_given_upper2lower
(
unfold_tensor_descriptor
(
wei_k_c_y_x_global_desc
,
I1
,
I3
),
Sequence
<
1
,
0
>
{});
// input tensor
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
in_n_c_hi_wi_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Pad
<
Sequence
<
Hi
,
Wi
>
,
InLeftPads
,
InRightPads
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{}));
constexpr
index_t
Hip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
2
];
constexpr
index_t
Wip
=
in_n_c_hip_wip_global_desc
.
GetLengths
()[
3
];
constexpr
auto
in_n_c_y_ho_x_wo_global_desc
=
transform_tensor_descriptor
(
in_n_c_hip_wip_global_desc
,
make_tuple
(
PassThrough
<
N
>
{},
PassThrough
<
C
>
{},
Embed
<
Hip
,
Sequence
<
Y
,
Ho
>
,
Sequence
<
ConvDilationH
,
ConvStrideH
,
0
>>
{},
Embed
<
Wip
,
Sequence
<
X
,
Wo
>
,
Sequence
<
ConvDilationW
,
ConvStrideW
,
0
>>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
,
3
>
{},
Sequence
<
4
,
5
>
{}));
constexpr
auto
in_gemmk_gemmn_global_desc
=
transform_tensor_descriptor
(
in_n_c_y_ho_x_wo_global_desc
,
make_tuple
(
Merge
<
Sequence
<
C
,
Y
,
X
>>
{},
Merge
<
Sequence
<
N
,
Ho
,
Wo
>>
{}),
make_tuple
(
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
0
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// output tensor
constexpr
auto
out_gemmm_gemmn_global_desc
=
transform_tensor_descriptor
(
unfold_tensor_descriptor
(
out_n_k_ho_wo_global_desc
,
I2
,
I3
),
make_tuple
(
PassThrough
<
K
>
{},
Merge
<
Sequence
<
N
,
Ho
*
Wo
>>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// input
const
index_t
k0
=
p_in_global
[
get_thread_local_1d_id
()];
const
index_t
n0
=
p_in_global
[
get_thread_local_1d_id
()];
auto
coord
=
typename
TensorCoordinate
<
decltype
(
in_gemmk_gemmn_global_desc
)
>::
type
(
k0
,
n0
);
if
(
get_block_1d_id
()
<
coord
.
GetOffset
())
{
for
(
index_t
k
=
0
;
k
<
1
;
++
k
)
{
for
(
index_t
n
=
0
;
n
<
4
;
++
n
)
{
auto
tmp
=
coord
+
Array
<
index_t
,
2
>
{
k
,
n
};
Float
value
=
1
;
transfer_data
<
Float
,
1
,
AddressSpace
::
Vgpr
,
AddressSpace
::
Global
,
InMemoryDataOperation
::
Set
,
1
,
1
>
(
&
value
,
0
,
true
,
1
,
p_in_global
,
tmp
.
GetOffset
(),
tmp
.
IsOffsetValidAssumingUpperIndexIsValid
(),
in_gemmk_gemmn_global_desc
.
GetElementSpace
());
}
}
}
}
};
}
// namespace ck
#endif
composable_kernel/include/utility/amd_buffer_addressing.hpp
View file @
bd4c668e
...
@@ -196,6 +196,7 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave,
...
@@ -196,6 +196,7 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave,
index_t
src_thread_addr_offset
=
src_thread_data_offset
*
sizeof
(
float
);
index_t
src_thread_addr_offset
=
src_thread_data_offset
*
sizeof
(
float
);
#if 1 // debug
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
return
__llvm_amdgcn_buffer_load_f32
(
src_wave_buffer_resource
.
data
,
return
__llvm_amdgcn_buffer_load_f32
(
src_wave_buffer_resource
.
data
,
0
,
0
,
...
@@ -209,6 +210,12 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave,
...
@@ -209,6 +210,12 @@ __device__ float amd_buffer_load<float, 1>(const float* p_src_wave,
return
__llvm_amdgcn_buffer_load_f32
(
return
__llvm_amdgcn_buffer_load_f32
(
src_wave_buffer_resource
.
data
,
0
,
src_addr_base
+
src_thread_addr_offset
,
false
,
false
);
src_wave_buffer_resource
.
data
,
0
,
src_addr_base
+
src_thread_addr_offset
,
false
,
false
);
#endif
#endif
#else
return
src_thread_data_valid
?
__llvm_amdgcn_buffer_load_f32
(
src_wave_buffer_resource
.
data
,
0
,
src_thread_addr_offset
,
false
,
false
)
:
0
;
#endif
}
}
template
<
>
template
<
>
...
@@ -570,6 +577,7 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread,
...
@@ -570,6 +577,7 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread,
index_t
dst_thread_addr_offset
=
dst_thread_data_offset
*
sizeof
(
float
);
index_t
dst_thread_addr_offset
=
dst_thread_data_offset
*
sizeof
(
float
);
#if 1 // debug
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
#if !CK_EXPERIMENTAL_AMD_BUFFER_ADDRESSING_USE_OFFSET_TRICK
__llvm_amdgcn_buffer_store_f32
(
*
p_src_thread
,
__llvm_amdgcn_buffer_store_f32
(
*
p_src_thread
,
dst_wave_buffer_resource
.
data
,
dst_wave_buffer_resource
.
data
,
...
@@ -587,6 +595,13 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread,
...
@@ -587,6 +595,13 @@ __device__ void amd_buffer_store<float, 1>(const float* p_src_thread,
false
,
false
,
false
);
false
);
#endif
#endif
#else
if
(
dst_thread_data_valid
)
{
__llvm_amdgcn_buffer_store_f32
(
*
p_src_thread
,
dst_wave_buffer_resource
.
data
,
0
,
dst_thread_addr_offset
,
false
,
false
);
}
#endif
}
}
template
<
>
template
<
>
...
...
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
bd4c668e
...
@@ -133,7 +133,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -133,7 +133,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
2
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
2
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif
1
#elif
0
// cdata = 64, BlockSize = 256, 128x128x8
// cdata = 64, BlockSize = 256, 128x128x8
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -290,7 +290,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -290,7 +290,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
2
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
2
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif
0
#elif
1
// cdata = 64, BlockSize = 128, 64x128x8
// cdata = 64, BlockSize = 128, 64x128x8
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
...
...
driver/include/device_dummy_transform.hpp
0 → 100644
View file @
bd4c668e
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "dummy_static_transform.hpp"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
,
class
ConvStrides
,
class
ConvDilations
,
class
InLeftPads
,
class
InRightPads
>
void
device_dummy_transform
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvDilations
,
InLeftPads
,
InRightPads
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
using
TDevice
=
typename
conditional
<
is_same
<
half_float
::
half
,
T
>::
value
,
half_t
,
T
>::
type
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
in_nchw_desc
=
make_native_tensor_descriptor
(
InDesc
::
GetLengths
(),
InDesc
::
GetStrides
());
constexpr
auto
wei_kcyx_desc
=
make_native_tensor_descriptor
(
WeiDesc
::
GetLengths
(),
WeiDesc
::
GetStrides
());
constexpr
auto
out_nkhw_desc
=
make_native_tensor_descriptor
(
OutDesc
::
GetLengths
(),
OutDesc
::
GetStrides
());
constexpr
index_t
N
=
out_nkhw_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
out_nkhw_desc
.
GetLength
(
I1
);
constexpr
index_t
Ho
=
out_nkhw_desc
.
GetLength
(
I2
);
constexpr
index_t
Wo
=
out_nkhw_desc
.
GetLength
(
I3
);
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
DeviceMem
out_nkhw_device_buf
(
data_sz
*
out_nkhw
.
mDesc
.
GetElementSpace
());
in_nchw_device_buf
.
ToDevice
(
in_nchw
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GridSize
=
1
;
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
using
dummy_transform
=
DummyStaticTransform
<
GridSize
,
BlockSize
,
float
,
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
out_nkhw_desc
),
ConvStrides
,
ConvDilations
,
InLeftPads
,
InRightPads
>
;
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
launch_kernel
(
run_gridwise_operation
<
dummy_transform
,
float
*
const
__restrict__
,
float
*
const
__restrict__
,
float
*
const
__restrict__
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
static_cast
<
float
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
float
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
float
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
}
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
}
driver/src/conv_bwd_data_driver.cpp
View file @
bd4c668e
...
@@ -52,7 +52,7 @@ int main(int argc, char* argv[])
...
@@ -52,7 +52,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
0
#elif
1
// 3x3, 28x28
// 3x3, 28x28
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
C
=
1024
;
...
@@ -245,7 +245,7 @@ int main(int argc, char* argv[])
...
@@ -245,7 +245,7 @@ int main(int argc, char* argv[])
#endif
#endif
}
}
#if
0
#if
1
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw
#elif 1
#elif 1
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw
...
@@ -256,17 +256,17 @@ int main(int argc, char* argv[])
...
@@ -256,17 +256,17 @@ int main(int argc, char* argv[])
#elif 1
#elif 1
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw
#endif
#endif
(
in_nchw_desc
,
(
in_nchw_desc
,
in_nchw_device
,
in_nchw_device
,
wei_kcyx_desc
,
wei_kcyx_desc
,
wei_kcyx
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_desc
,
out_nkhw
,
out_nkhw
,
ConvStrides
{},
ConvStrides
{},
ConvDilations
{},
ConvDilations
{},
LeftPads
{},
LeftPads
{},
RightPads
{},
RightPads
{},
nrepeat
);
nrepeat
);
if
(
do_verification
)
if
(
do_verification
)
{
{
...
...
driver/src/conv_driver.cpp
View file @
bd4c668e
...
@@ -14,26 +14,27 @@
...
@@ -14,26 +14,27 @@
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dummy_transform.hpp"
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
using
namespace
ck
;
using
namespace
ck
;
#if 0
#if 0
//
1x1, 17x17
//
3x3, 71x71
constexpr index_t N = 128;
constexpr index_t N = 128;
constexpr index_t C = 1
024
;
constexpr index_t C = 1
92
;
constexpr index_t HI =
1
7;
constexpr index_t HI = 7
1
;
constexpr index_t WI =
1
7;
constexpr index_t WI = 7
1
;
constexpr index_t K =
256
;
constexpr index_t K =
128
;
constexpr index_t Y =
1
;
constexpr index_t Y =
3
;
constexpr index_t X =
1
;
constexpr index_t X =
3
;
using ConvStrides = Sequence<
1
,
1
>;
using ConvStrides = Sequence<
2
,
2
>;
using ConvDilations = Sequence<1, 1>;
using ConvDilations = Sequence<1, 1>;
using LeftPads = Sequence<
0
,
0
>;
using LeftPads = Sequence<
1
,
1
>;
using RightPads = Sequence<
0
,
0
>;
using RightPads = Sequence<
1
,
1
>;
#elif
0
#elif
0
// 1x1, 8x8
// 1x1, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
...
@@ -109,7 +110,7 @@ int main(int argc, char* argv[])
...
@@ -109,7 +110,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif
1
#elif
0
// 1x7, 17x17
// 1x7, 17x17
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
128
;
constexpr
index_t
C
=
128
;
...
@@ -141,7 +142,6 @@ int main(int argc, char* argv[])
...
@@ -141,7 +142,6 @@ int main(int argc, char* argv[])
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
#elif 0
// 3x3, 147x147
// 3x3, 147x147
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
32
;
constexpr
index_t
C
=
32
;
constexpr
index_t
HI
=
147
;
constexpr
index_t
HI
=
147
;
...
@@ -157,7 +157,6 @@ int main(int argc, char* argv[])
...
@@ -157,7 +157,6 @@ int main(int argc, char* argv[])
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif 0
#elif 0
// 3x3, 149x149
// 3x3, 149x149
// v4r4@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
32
;
constexpr
index_t
C
=
32
;
constexpr
index_t
HI
=
149
;
constexpr
index_t
HI
=
149
;
...
@@ -201,7 +200,7 @@ int main(int argc, char* argv[])
...
@@ -201,7 +200,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
1
#elif
0
// 3x3, 35x35, stride 2
// 3x3, 35x35, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
288
;
constexpr
index_t
C
=
288
;
...
@@ -244,21 +243,6 @@ int main(int argc, char* argv[])
...
@@ -244,21 +243,6 @@ int main(int argc, char* argv[])
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
0
>
;
using
RightPads
=
Sequence
<
1
,
0
>
;
#elif 0
// 3x1, 8x8
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
448
;
constexpr
index_t
HI
=
8
;
constexpr
index_t
WI
=
8
;
constexpr
index_t
K
=
512
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
1
;
using
ConvStrides
=
Sequence
<
1
,
1
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
0
>
;
using
LeftPads
=
Sequence
<
1
,
0
>
;
using
RightPads
=
Sequence
<
1
,
0
>
;
using
RightPads
=
Sequence
<
1
,
0
>
;
#elif 0
#elif 0
...
@@ -278,7 +262,6 @@ int main(int argc, char* argv[])
...
@@ -278,7 +262,6 @@ int main(int argc, char* argv[])
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif 0
#elif 0
// 7x1, 73x73
// 7x1, 73x73
// v44@v100 xx.xx%, cudnn@v100 xx.xx%
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
73
;
constexpr
index_t
HI
=
73
;
...
@@ -352,10 +335,10 @@ int main(int argc, char* argv[])
...
@@ -352,10 +335,10 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
0
#elif
1
// 3x3, 28x28
// 3x3, 28x28
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
12
8
;
constexpr
index_t
C
=
1
9
2
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
128
;
constexpr
index_t
K
=
128
;
...
@@ -367,7 +350,7 @@ int main(int argc, char* argv[])
...
@@ -367,7 +350,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
1
#elif
0
// 3x3, 14x14
// 3x3, 14x14
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
256
;
...
@@ -382,7 +365,7 @@ int main(int argc, char* argv[])
...
@@ -382,7 +365,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
1
#elif
0
// 1x1, 56x56, stride 2
// 1x1, 56x56, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
256
;
...
@@ -472,7 +455,7 @@ int main(int argc, char* argv[])
...
@@ -472,7 +455,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
LeftPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
using
RightPads
=
Sequence
<
1
,
1
>
;
#elif
1
#elif
0
// 1x1, 56x56
// 1x1, 56x56
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
C
=
64
;
...
@@ -487,7 +470,7 @@ int main(int argc, char* argv[])
...
@@ -487,7 +470,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
1
#elif
0
// 3x3, 56x56
// 3x3, 56x56
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
C
=
64
;
...
@@ -565,7 +548,7 @@ int main(int argc, char* argv[])
...
@@ -565,7 +548,7 @@ int main(int argc, char* argv[])
#endif
#endif
}
}
#if
1
#if
0
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
in_nchw,
in_nchw,
wei_kcyx_desc,
wei_kcyx_desc,
...
@@ -577,7 +560,7 @@ int main(int argc, char* argv[])
...
@@ -577,7 +560,7 @@ int main(int argc, char* argv[])
LeftPads{},
LeftPads{},
RightPads{},
RightPads{},
nrepeat);
nrepeat);
#elif
1
#elif
0
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
@@ -589,6 +572,18 @@ int main(int argc, char* argv[])
...
@@ -589,6 +572,18 @@ int main(int argc, char* argv[])
LeftPads
{},
LeftPads
{},
RightPads
{},
RightPads
{},
nrepeat
);
nrepeat
);
#elif 1
device_dummy_transform
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
ConvStrides
{},
ConvDilations
{},
LeftPads
{},
RightPads
{},
nrepeat
);
#endif
#endif
if
(
do_verification
)
if
(
do_verification
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment