Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f64fab12
Commit
f64fab12
authored
Apr 18, 2020
by
Chao Liu
Browse files
clean pu
parent
3317bfe2
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
51 additions
and
55 deletions
+51
-55
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
...n_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
+0
-1
composable_kernel/include/utility/float_type.nvidia.hpp.in
composable_kernel/include/utility/float_type.nvidia.hpp.in
+0
-1
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
...ution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+1
-1
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
+21
-22
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
.../device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+17
-18
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+12
-12
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
f64fab12
...
@@ -182,7 +182,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -182,7 +182,6 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
const
index_t
k_block_data_on_global
=
block_work_id
[
1
]
*
KPerBlock
;
const
index_t
k_block_data_on_global
=
block_work_id
[
1
]
*
KPerBlock
;
#endif
#endif
// input tensor
// input tensor
// global tensor in global memory
// global tensor in global memory
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
constexpr
auto
in_n_c_hip_wip_global_desc
=
transform_tensor_descriptor
(
...
...
composable_kernel/include/utility/float_type.nvidia.hpp.in
View file @
f64fab12
...
@@ -176,7 +176,6 @@ struct inner_product_with_conversion
...
@@ -176,7 +176,6 @@ struct inner_product_with_conversion
}
}
return acc;
return acc;
}
}
};
};
} // namespace ck
} // namespace ck
...
...
driver/include/device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
f64fab12
1111
gma
once
#pra
gma once
#include <unistd.h>
#include <unistd.h>
#include "device.hpp"
#include "device.hpp"
#include "tensor.hpp"
#include "tensor.hpp"
...
...
driver/include/device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp
View file @
f64fab12
...
@@ -170,7 +170,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -170,7 +170,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
4
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif
0
#elif
1
// cdata = 64, BlockSize = 256, 128x128x16
// cdata = 64, BlockSize = 256, 128x128x16
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
@@ -184,10 +184,10 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -184,10 +184,10 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmNPerThreadSubC
=
4
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmKPerThreadLoop
=
1
;
constexpr
index_t
GemmMLevel0Cluster
=
2
;
constexpr
index_t
GemmMLevel0Cluster
=
4
;
constexpr
index_t
GemmNLevel0Cluster
=
2
;
constexpr
index_t
GemmNLevel0Cluster
=
4
;
constexpr
index_t
GemmMLevel1Cluster
=
8
;
constexpr
index_t
GemmMLevel1Cluster
=
4
;
constexpr
index_t
GemmNLevel1Cluster
=
8
;
constexpr
index_t
GemmNLevel1Cluster
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadA
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
constexpr
index_t
GemmDataPerReadB
=
4
;
...
@@ -561,7 +561,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -561,7 +561,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
1
;
constexpr
index_t
WeiBlockCopySrcDataPerRead_E
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
constexpr
index_t
WeiBlockCopyDstDataPerWrite_K
=
1
;
#elif
1
#elif
0
// cdata = 64, BlockSize = 64, 32x128x3
// cdata = 64, BlockSize = 64, 32x128x3
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
BlockSize
=
64
;
...
@@ -810,6 +810,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -810,6 +810,9 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopySrcDataPerRead_E
,
WeiBlockCopyDstDataPerWrite_K
>
{};
WeiBlockCopyDstDataPerWrite_K
>
{};
// warm up
std
::
cout
<<
"Warn up runs..."
<<
std
::
endl
;
for
(
index_t
i
=
0
;
i
<
10
;
++
i
)
for
(
index_t
i
=
0
;
i
<
10
;
++
i
)
{
{
float
time
=
float
time
=
...
@@ -822,14 +825,11 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -822,14 +825,11 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
float
perf
=
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
time
,
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
;
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
);
}
// warm up
std
::
cout
<<
"Elapsed time : "
<<
time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
printf
(
"Warn up running %d times...
\n
"
,
nrepeat
);
}
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
...
@@ -845,8 +845,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -845,8 +845,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
cudaDeviceSynchronize
()
;
KernelTimer
timer
;
auto
start
=
std
::
chrono
::
steady_clock
::
now
();
timer
.
Start
();
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
...
@@ -860,15 +860,14 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
...
@@ -860,15 +860,14 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
}
cudaDeviceSynchronize
();
timer
.
End
();
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
ave_time
=
std
::
chrono
::
duration
<
float
,
std
::
milli
>
(
end
-
start
).
count
()
/
nrepeat
;
float
perf
=
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
printf
(
"Average elapsed time : %f ms, %f TFlop/s
\n
"
,
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
ave_time
,
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
);
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
}
}
driver/include/device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
f64fab12
...
@@ -253,7 +253,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -253,7 +253,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
4
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
4
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
4
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
4
;
#elif
0
#elif
1
// cdata = 64, BlockSize = 256, 128x128x16
// cdata = 64, BlockSize = 256, 128x128x16
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmBBlockCopySrcDataPerRead_GemmN = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
// GemmCThreadCopyDstDataPerWrite_GemmN1 = 4
...
@@ -791,7 +791,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -791,7 +791,7 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
2
;
constexpr
index_t
GemmBBlockCopyDstDataPerWrite_GemmN
=
2
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
constexpr
index_t
GemmCThreadCopyDstDataPerWrite_GemmN1
=
1
;
#elif
1
#elif
0
// cdata = 64, BlockSize = 64, 32x128x3
// cdata = 64, BlockSize = 64, 32x128x3
constexpr
index_t
BlockSize
=
64
;
constexpr
index_t
BlockSize
=
64
;
...
@@ -1000,6 +1000,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -1000,6 +1000,9 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
GemmBBlockCopyDstDataPerWrite_GemmN
,
GemmBBlockCopyDstDataPerWrite_GemmN
,
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
GemmCThreadCopyDstDataPerWrite_GemmN1
>
{};
// warm up
std
::
cout
<<
"Warn up runs..."
<<
std
::
endl
;
for
(
index_t
i
=
0
;
i
<
10
;
++
i
)
for
(
index_t
i
=
0
;
i
<
10
;
++
i
)
{
{
float
time
=
float
time
=
...
@@ -1012,14 +1015,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -1012,14 +1015,11 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
printf
(
"Elapsed time : %f ms, %f TFlop/s
\n
"
,
float
perf
=
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
time
,
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
;
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
time
);
}
// warm up
std
::
cout
<<
"Elapsed time : "
<<
time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
printf
(
"Warn up running %d times...
\n
"
,
nrepeat
);
}
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
...
@@ -1035,8 +1035,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -1035,8 +1035,8 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
cudaDeviceSynchronize
()
;
KernelTimer
timer
;
auto
start
=
std
::
chrono
::
steady_clock
::
now
();
timer
.
Start
();
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
for
(
index_t
i
=
0
;
i
<
nrepeat
;
++
i
)
{
{
...
@@ -1050,15 +1050,14 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
...
@@ -1050,15 +1050,14 @@ void device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc,
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
static_cast
<
T
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
}
cudaDeviceSynchronize
();
timer
.
End
();
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
ave_time
=
std
::
chrono
::
duration
<
float
,
std
::
milli
>
(
end
-
start
).
count
()
/
nrepeat
;
float
perf
=
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
printf
(
"Average elapsed time : %f ms, %f TFlop/s
\n
"
,
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
ave_time
,
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
);
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
}
}
driver/src/conv_driver.cpp
View file @
f64fab12
#include <iostream>
#include <numeric>
#include <numeric>
#include <initializer_list>
#include <initializer_list>
#include <cstdlib>
#include <cstdlib>
...
@@ -20,8 +20,8 @@
...
@@ -20,8 +20,8 @@
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
//#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_implicit_gemm_v4r4_xdlops_fp16_nchw_kcyx_nkhw.hpp"
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
@@ -132,7 +132,7 @@ int main(int argc, char* argv[])
...
@@ -132,7 +132,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
LeftPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
using
RightPads
=
Sequence
<
0
,
3
>
;
#elif
1
#elif
0
// 3x3, 299x299 stride=2
// 3x3, 299x299 stride=2
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
1024
;
constexpr
index_t
C
=
1024
;
...
@@ -209,10 +209,10 @@ int main(int argc, char* argv[])
...
@@ -209,10 +209,10 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
LeftPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
using
RightPads
=
Sequence
<
0
,
0
>
;
#elif
0
#elif
1
// 3x3, 35x35, stride 2
// 3x3, 35x35, stride 2
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
384
;
constexpr
index_t
C
=
288
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
HI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
WI
=
35
;
constexpr
index_t
K
=
384
;
constexpr
index_t
K
=
384
;
...
@@ -269,7 +269,7 @@ int main(int argc, char* argv[])
...
@@ -269,7 +269,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
1
,
0
>
;
using
LeftPads
=
Sequence
<
1
,
0
>
;
using
RightPads
=
Sequence
<
1
,
0
>
;
using
RightPads
=
Sequence
<
1
,
0
>
;
#elif
1
#elif
0
// 3x3, 147x147
// 3x3, 147x147
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
C
=
64
;
...
@@ -300,7 +300,7 @@ int main(int argc, char* argv[])
...
@@ -300,7 +300,7 @@ int main(int argc, char* argv[])
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
LeftPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
using
RightPads
=
Sequence
<
3
,
0
>
;
#elif
1
#elif
0
// 3x3, 73x73
// 3x3, 73x73
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
64
;
constexpr
index_t
C
=
64
;
...
@@ -525,8 +525,8 @@ int main(int argc, char* argv[])
...
@@ -525,8 +525,8 @@ int main(int argc, char* argv[])
print_sequence
(
"ConvStrides"
,
ConvStrides
{});
print_sequence
(
"ConvStrides"
,
ConvStrides
{});
print_sequence
(
"ConvDilations"
,
ConvDilations
{});
print_sequence
(
"ConvDilations"
,
ConvDilations
{});
using
in_data_t
=
half
;
using
in_data_t
=
float
;
using
out_data_t
=
half
;
using
out_data_t
=
float
;
Tensor
<
in_data_t
>
in_nchw
(
make_TensorDescriptor
(
in_nchw_desc
));
Tensor
<
in_data_t
>
in_nchw
(
make_TensorDescriptor
(
in_nchw_desc
));
Tensor
<
in_data_t
>
wei_kcyx
(
make_TensorDescriptor
(
wei_kcyx_desc
));
Tensor
<
in_data_t
>
wei_kcyx
(
make_TensorDescriptor
(
wei_kcyx_desc
));
Tensor
<
out_data_t
>
out_nkhw_host
(
make_TensorDescriptor
(
out_nkhw_desc
));
Tensor
<
out_data_t
>
out_nkhw_host
(
make_TensorDescriptor
(
out_nkhw_desc
));
...
@@ -592,7 +592,7 @@ int main(int argc, char* argv[])
...
@@ -592,7 +592,7 @@ int main(int argc, char* argv[])
#elif 0
#elif 0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
(
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
#elif
0
#elif
1
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
@@ -604,7 +604,7 @@ int main(int argc, char* argv[])
...
@@ -604,7 +604,7 @@ int main(int argc, char* argv[])
LeftPads
{},
LeftPads
{},
RightPads
{},
RightPads
{},
nrepeat
);
nrepeat
);
#elif
0
#elif
1
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
device_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment