Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3276a5e9
Commit
3276a5e9
authored
Jul 05, 2019
by
Chao Liu
Browse files
update build
parent
c15ff3c8
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
90 additions
and
88 deletions
+90
-88
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
...ion_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
+6
-14
composable_kernel/include/utility/config_amd.hpp.in
composable_kernel/include/utility/config_amd.hpp.in
+2
-1
composable_kernel/include/utility/config_nvidia.hpp.in
composable_kernel/include/utility/config_nvidia.hpp.in
+2
-1
driver/include/conv_common.hpp
driver/include/conv_common.hpp
+10
-4
driver/include/device.hpp
driver/include/device.hpp
+2
-4
driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
...de/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
+2
-4
driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
...de/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
+2
-2
driver/include/tensor.hpp
driver/include/tensor.hpp
+45
-2
driver/src/driver.cpp
driver/src/driver.cpp
+19
-56
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp
View file @
3276a5e9
...
@@ -209,15 +209,6 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -209,15 +209,6 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
GemmDataPerReadA
,
GemmDataPerReadA
,
GemmDataPerReadB
>
{};
GemmDataPerReadB
>
{};
// choose GEMM implementation here
const
auto
run_blockwise_gemm
=
[
&
](
auto
...
Xs
)
{
#if 1
return
blockwise_gemm
.
Run
(
Xs
...);
#else
return
blockwise_gemm
.
Run_amd_asm
(
Xs
...);
#endif
};
// LDS allocation for input and weight: be careful of alignment
// LDS allocation for input and weight: be careful of alignment
constexpr
index_t
max_align
=
math
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
constexpr
index_t
max_align
=
math
::
lcm
(
InBlockCopyDstDataPerWrite_N2
,
WeiBlockCopyDataPerAccess_K
,
WeiBlockCopyDataPerAccess_K
,
...
@@ -225,9 +216,10 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -225,9 +216,10 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
GemmDataPerReadB
);
GemmDataPerReadB
);
constexpr
index_t
in_block_space
=
constexpr
index_t
in_block_space
=
in_c_n1_b_n2_block_mem_desc
.
GetElementSpace
(
Number
<
max_align
>
{}
);
math
::
integer_least_multiple
(
in_c_n1_b_n2_block_mem_desc
.
GetElementSpace
(
),
max_align
);
constexpr
index_t
wei_block_space
=
wei_c_k_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
constexpr
index_t
wei_block_space
=
math
::
integer_least_multiple
(
wei_c_k_block_desc
.
GetElementSpace
(),
max_align
);
__shared__
Float
p_in_block_double
[
2
*
in_block_space
];
__shared__
Float
p_in_block_double
[
2
*
in_block_space
];
__shared__
Float
p_wei_block_double
[
2
*
wei_block_space
];
__shared__
Float
p_wei_block_double
[
2
*
wei_block_space
];
...
@@ -291,7 +283,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -291,7 +283,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
p_wei_register_clipboard
);
p_wei_register_clipboard
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
run_
blockwise_gemm
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
blockwise_gemm
.
Run
(
p_wei_block_now
,
p_in_block_now
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
...
@@ -319,7 +311,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -319,7 +311,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
p_wei_register_clipboard
);
p_wei_register_clipboard
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
run_
blockwise_gemm
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
blockwise_gemm
.
Run
(
p_wei_block_double
,
p_in_block_double
,
p_out_thread
);
// LDS double buffer: store next data to LDS
// LDS double buffer: store next data to LDS
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
blockwise_in_copy
.
RunStoreRegisterClipboard
(
p_in_register_clipboard
,
...
@@ -331,7 +323,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
...
@@ -331,7 +323,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
__syncthreads
();
__syncthreads
();
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
run_
blockwise_gemm
(
p_wei_block_double
+
wei_block_space
,
blockwise_gemm
.
Run
(
p_wei_block_double
+
wei_block_space
,
p_in_block_double
+
in_block_space
,
p_in_block_double
+
in_block_space
,
p_out_thread
);
p_out_thread
);
}
}
...
...
composable_kernel/include/utility/config_amd.hpp.in
View file @
3276a5e9
...
@@ -18,7 +18,8 @@ typedef float float4_t __attribute__((ext_vector_type(4)));
...
@@ -18,7 +18,8 @@ typedef float float4_t __attribute__((ext_vector_type(4)));
using index_t = uint32_t;
using index_t = uint32_t;
__device__ void fused_multiply_accumulate(float& d, const float& s0, const float& s1)
template <class T>
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
{
{
d += s0 * s1;
d += s0 * s1;
}
}
...
...
composable_kernel/include/utility/config_nvidia.hpp.in
View file @
3276a5e9
...
@@ -22,7 +22,8 @@ using float4_t = float4;
...
@@ -22,7 +22,8 @@ using float4_t = float4;
using index_t = uint32_t;
using index_t = uint32_t;
__device__ void fused_multiply_accumulate(float& d, const float& s0, const float& s1)
template <class T>
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
{
{
d += s0 * s1;
d += s0 * s1;
}
}
...
...
driver/include/conv_common.hpp
View file @
3276a5e9
#ifndef
CK_
CONV_COMMON_HPP
#ifndef CONV_COMMON_HPP
#define
CK_
CONV_COMMON_HPP
#define CONV_COMMON_HPP
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor.hpp"
using
namespace
ck
;
// this is ugly, only for 4d
// this is ugly, only for 4d
template
<
class
InDesc
,
class
WeiDesc
>
template
<
class
InDesc
,
class
WeiDesc
>
constexpr
auto
get_convolution_output_default_4d_tensor_descriptor
(
InDesc
,
WeiDesc
)
constexpr
auto
get_convolution_output_default_4d_tensor_descriptor
(
InDesc
,
WeiDesc
)
{
{
using
namespace
ck
;
constexpr
auto
in_desc
=
InDesc
{};
constexpr
auto
in_desc
=
InDesc
{};
constexpr
auto
wei_desc
=
WeiDesc
{};
constexpr
auto
wei_desc
=
WeiDesc
{};
...
@@ -45,6 +45,8 @@ template <class InDesc,
...
@@ -45,6 +45,8 @@ template <class InDesc,
constexpr
auto
get_convolution_with_padding_output_default_4d_tensor_descriptor
(
constexpr
auto
get_convolution_with_padding_output_default_4d_tensor_descriptor
(
InDesc
,
WeiDesc
,
ConvStrides
,
ConvDilations
,
LowerPads
,
UpperPads
)
InDesc
,
WeiDesc
,
ConvStrides
,
ConvDilations
,
LowerPads
,
UpperPads
)
{
{
using
namespace
ck
;
constexpr
auto
in_desc
=
InDesc
{};
constexpr
auto
in_desc
=
InDesc
{};
constexpr
auto
wei_desc
=
WeiDesc
{};
constexpr
auto
wei_desc
=
WeiDesc
{};
...
@@ -84,6 +86,8 @@ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
...
@@ -84,6 +86,8 @@ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
template
<
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
constexpr
std
::
size_t
calculate_convolution_flops
(
InDesc
,
WeiDesc
,
OutDesc
)
constexpr
std
::
size_t
calculate_convolution_flops
(
InDesc
,
WeiDesc
,
OutDesc
)
{
{
using
namespace
ck
;
constexpr
auto
wei_desc
=
WeiDesc
{};
constexpr
auto
wei_desc
=
WeiDesc
{};
constexpr
auto
out_desc
=
OutDesc
{};
constexpr
auto
out_desc
=
OutDesc
{};
...
@@ -107,6 +111,8 @@ constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
...
@@ -107,6 +111,8 @@ constexpr std::size_t calculate_convolution_flops(InDesc, WeiDesc, OutDesc)
template
<
class
Float
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
Float
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
constexpr
std
::
size_t
calculate_convolution_memory_size
(
Float
,
InDesc
,
WeiDesc
,
OutDesc
)
constexpr
std
::
size_t
calculate_convolution_memory_size
(
Float
,
InDesc
,
WeiDesc
,
OutDesc
)
{
{
using
namespace
ck
;
constexpr
auto
wei_desc
=
WeiDesc
{};
constexpr
auto
wei_desc
=
WeiDesc
{};
constexpr
auto
out_desc
=
OutDesc
{};
constexpr
auto
out_desc
=
OutDesc
{};
...
...
driver/include/device.hpp
View file @
3276a5e9
#ifndef
CK_
DEVICE_HPP
#ifndef DEVICE_HPP
#define
CK_
DEVICE_HPP
#define DEVICE_HPP
#include <memory>
#include <memory>
#include "config.hpp"
#include "config.hpp"
using
namespace
ck
;
struct
DeviceMem
struct
DeviceMem
{
{
DeviceMem
()
=
delete
;
DeviceMem
()
=
delete
;
...
...
driver/include/device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp
View file @
3276a5e9
...
@@ -6,8 +6,6 @@
...
@@ -6,8 +6,6 @@
#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp"
#include "gridwise_convolution_implicit_gemm_v3_nchw_cyxk_nkhw_lds_double_buffer.hpp"
using
namespace
ck
;
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
InDesc
,
void
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
in_nchw
,
...
@@ -17,6 +15,8 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
...
@@ -17,6 +15,8 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
Tensor
<
T
>&
out_nkhw
,
Tensor
<
T
>&
out_nkhw
,
index_t
nrepeat
)
index_t
nrepeat
)
{
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -135,7 +135,6 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
...
@@ -135,7 +135,6 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
WeiBlockCopyClusterLengths_C_K
,
WeiBlockCopyClusterLengths_C_K
,
WeiBlockCopyDataPerAccess_K
>
{};
WeiBlockCopyDataPerAccess_K
>
{};
#if 1
float
time
=
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
float
time
=
launch_kernel
(
run_gridwise_convolution_kernel
<
decltype
(
gridwise_conv
),
T
>
,
dim3
(
GridSize
),
dim3
(
GridSize
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
...
@@ -149,7 +148,6 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
...
@@ -149,7 +148,6 @@ void device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(InDesc,
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
);
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
);
usleep
(
std
::
min
(
time
*
1000
,
float
(
10000
)));
usleep
(
std
::
min
(
time
*
1000
,
float
(
10000
)));
#endif
}
}
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
...
...
driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
View file @
3276a5e9
...
@@ -6,8 +6,6 @@
...
@@ -6,8 +6,6 @@
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
#include "gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp"
using
namespace
ck
;
template
<
class
T
,
template
<
class
T
,
class
InDesc
,
class
InDesc
,
class
WeiDesc
,
class
WeiDesc
,
...
@@ -24,6 +22,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
...
@@ -24,6 +22,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
ConvDilations
,
ConvDilations
,
index_t
nrepeat
)
index_t
nrepeat
)
{
{
using
namespace
ck
;
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
...
...
driver/include/tensor.hpp
View file @
3276a5e9
#ifndef
CK_
TENSOR_HPP
#ifndef TENSOR_HPP
#define
CK_
TENSOR_HPP
#define TENSOR_HPP
#include <thread>
#include <thread>
#include <vector>
#include <vector>
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include <utility>
#include <utility>
#include <cassert>
#include <cassert>
#include <iostream>
#include <iostream>
#include "common_header.hpp"
template
<
class
Range
>
template
<
class
Range
>
std
::
ostream
&
LogRange
(
std
::
ostream
&
os
,
Range
&&
range
,
std
::
string
delim
)
std
::
ostream
&
LogRange
(
std
::
ostream
&
os
,
Range
&&
range
,
std
::
string
delim
)
...
@@ -269,4 +270,46 @@ struct Tensor
...
@@ -269,4 +270,46 @@ struct Tensor
std
::
vector
<
T
>
mData
;
std
::
vector
<
T
>
mData
;
};
};
// this is ugly, only for 4d
template
<
class
TConstTensorDesc
>
void
ostream_ConstantTensorDescriptor
(
TConstTensorDesc
,
std
::
ostream
&
os
=
std
::
cout
)
{
using
namespace
ck
;
static_assert
(
TConstTensorDesc
::
nDim
==
4
,
"nDim is not 4"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
desc
=
TConstTensorDesc
{};
os
<<
"Lengths: {"
<<
desc
.
GetLength
(
I0
)
<<
", "
<<
desc
.
GetLength
(
I1
)
<<
", "
<<
desc
.
GetLength
(
I2
)
<<
", "
<<
desc
.
GetLength
(
I3
)
<<
"}, "
<<
"Strides: {"
<<
desc
.
GetStride
(
I0
)
<<
", "
<<
desc
.
GetStride
(
I1
)
<<
", "
<<
desc
.
GetStride
(
I2
)
<<
", "
<<
desc
.
GetStride
(
I3
)
<<
"}"
<<
std
::
endl
;
}
// this is ugly, only for 4d
template
<
class
TConstTensorDesc
>
auto
make_TensorDescriptor
(
TConstTensorDesc
)
{
using
namespace
ck
;
static_assert
(
TConstTensorDesc
::
nDim
==
4
,
"nDim is not 4"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
desc
=
TConstTensorDesc
{};
std
::
initializer_list
<
index_t
>
lengths
=
{
desc
.
GetLength
(
I0
),
desc
.
GetLength
(
I1
),
desc
.
GetLength
(
I2
),
desc
.
GetLength
(
I3
)};
std
::
initializer_list
<
index_t
>
strides
=
{
desc
.
GetStride
(
I0
),
desc
.
GetStride
(
I1
),
desc
.
GetStride
(
I2
),
desc
.
GetStride
(
I3
)};
return
TensorDescriptor
(
lengths
,
strides
);
}
#endif
#endif
driver/src/driver.cpp
View file @
3276a5e9
...
@@ -14,8 +14,6 @@
...
@@ -14,8 +14,6 @@
#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp"
using
namespace
ck
;
struct
GeneratorTensor_1
struct
GeneratorTensor_1
{
{
template
<
class
...
Is
>
template
<
class
...
Is
>
...
@@ -65,44 +63,6 @@ struct GeneratorTensor_Checkboard
...
@@ -65,44 +63,6 @@ struct GeneratorTensor_Checkboard
}
}
};
};
// this is ugly, only for 4d
template
<
class
TConstTensorDesc
>
void
ostream_ConstantTensorDescriptor
(
TConstTensorDesc
,
std
::
ostream
&
os
=
std
::
cout
)
{
static_assert
(
TConstTensorDesc
::
nDim
==
4
,
"nDim is not 4"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
desc
=
TConstTensorDesc
{};
os
<<
"Lengths: {"
<<
desc
.
GetLength
(
I0
)
<<
", "
<<
desc
.
GetLength
(
I1
)
<<
", "
<<
desc
.
GetLength
(
I2
)
<<
", "
<<
desc
.
GetLength
(
I3
)
<<
"}, "
<<
"Strides: {"
<<
desc
.
GetStride
(
I0
)
<<
", "
<<
desc
.
GetStride
(
I1
)
<<
", "
<<
desc
.
GetStride
(
I2
)
<<
", "
<<
desc
.
GetStride
(
I3
)
<<
"}"
<<
std
::
endl
;
}
// this is ugly, only for 4d
template
<
class
TConstTensorDesc
>
auto
make_TensorDescriptor
(
TConstTensorDesc
)
{
static_assert
(
TConstTensorDesc
::
nDim
==
4
,
"nDim is not 4"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
desc
=
TConstTensorDesc
{};
std
::
initializer_list
<
index_t
>
lengths
=
{
desc
.
GetLength
(
I0
),
desc
.
GetLength
(
I1
),
desc
.
GetLength
(
I2
),
desc
.
GetLength
(
I3
)};
std
::
initializer_list
<
index_t
>
strides
=
{
desc
.
GetStride
(
I0
),
desc
.
GetStride
(
I1
),
desc
.
GetStride
(
I2
),
desc
.
GetStride
(
I3
)};
return
TensorDescriptor
(
lengths
,
strides
);
}
template
<
class
TIn
,
template
<
class
TIn
,
class
TWei
,
class
TWei
,
class
TOut
,
class
TOut
,
...
@@ -416,6 +376,8 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
...
@@ -416,6 +376,8 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
using
namespace
ck
;
#if 0
#if 0
constexpr index_t N = 8;
constexpr index_t N = 8;
constexpr index_t C = 16;
constexpr index_t C = 16;
...
@@ -611,7 +573,7 @@ int main(int argc, char* argv[])
...
@@ -611,7 +573,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif
0
#elif
1
// 1x1 filter, 8x8 image
// 1x1 filter, 8x8 image
// cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
// cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
...
@@ -787,7 +749,7 @@ int main(int argc, char* argv[])
...
@@ -787,7 +749,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif
1
#elif
0
// 1x1 filter, 7x7 image
// 1x1 filter, 7x7 image
// cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
// cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
...
@@ -859,21 +821,23 @@ int main(int argc, char* argv[])
...
@@ -859,21 +821,23 @@ int main(int argc, char* argv[])
#endif
#endif
}
}
#if 1
#if 0
#if 0
device_convolution_direct_v2_nchw_kcyx_nkhw
device_convolution_direct_v2_nchw_kcyx_nkhw
in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
#elif
0
#elif
0
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#elif 0
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#elif 0
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#elif 0
#elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#elif 1
#elif 1
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
(
in_nchw_desc
,
#endif
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
wei_kcyx
,
wei_kcyx
,
...
@@ -882,7 +846,6 @@ int main(int argc, char* argv[])
...
@@ -882,7 +846,6 @@ int main(int argc, char* argv[])
ConvStrides
{},
ConvStrides
{},
ConvDilations
{},
ConvDilations
{},
nrepeat
);
nrepeat
);
#elif 0
#elif 0
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded
(
in_nchw_desc
,
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment