Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
74603261
Commit
74603261
authored
Jul 18, 2022
by
Chao Liu
Browse files
fix initialization issue
parent
360184cd
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
142 additions
and
239 deletions
+142
-239
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
+2
-2
example/20_convnd_bwd_weight/convnd_bwd_weight_xdl_bf16.cpp
example/20_convnd_bwd_weight/convnd_bwd_weight_xdl_bf16.cpp
+2
-2
profiler/include/profile_conv_bwd_data_impl.hpp
profiler/include/profile_conv_bwd_data_impl.hpp
+2
-2
profiler/include/profile_conv_bwd_weight_impl.hpp
profiler/include/profile_conv_bwd_weight_impl.hpp
+4
-4
profiler/src/profile_conv_bwd_weight.cpp
profiler/src/profile_conv_bwd_weight.cpp
+6
-3
test/convnd_bwd_data/convnd_bwd_data.cpp
test/convnd_bwd_data/convnd_bwd_data.cpp
+1
-1
test/convnd_bwd_weight/CMakeLists.txt
test/convnd_bwd_weight/CMakeLists.txt
+1
-1
test/convnd_bwd_weight/convnd_bwd_weight.cpp
test/convnd_bwd_weight/convnd_bwd_weight.cpp
+124
-224
No files found.
example/17_convnd_bwd_data/convnd_bwd_data_common.hpp
View file @
74603261
...
@@ -197,8 +197,8 @@ int run_conv_bwd_data(bool do_verification,
...
@@ -197,8 +197,8 @@ int run_conv_bwd_data(bool do_verification,
wei
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
wei
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
break
;
break
;
default:
default:
out
.
GenerateTensorValue
(
GeneratorTensor_
1
<
OutDataType
>
{
1
});
out
.
GenerateTensorValue
(
GeneratorTensor_
3
<
OutDataType
>
{
0.0
,
1.0
});
wei
.
GenerateTensorValue
(
GeneratorTensor_
1
<
WeiDataType
>
{
1
});
wei
.
GenerateTensorValue
(
GeneratorTensor_
3
<
WeiDataType
>
{
-
0.5
,
0.5
});
}
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_device
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_device
.
mDesc
.
GetElementSpace
());
...
...
example/20_convnd_bwd_weight/convnd_bwd_weight_xdl_bf16.cpp
View file @
74603261
...
@@ -6,8 +6,8 @@
...
@@ -6,8 +6,8 @@
#include "ck/tensor_operation/gpu/device/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp"
using
InDataType
=
ck
::
bhalf_t
;
using
InDataType
=
ck
::
bhalf_t
;
using
WeiDataType
=
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
float
;
// bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
using
WeiDataType
=
float
;
using
OutDataType
=
ck
::
bhalf_t
;
using
OutDataType
=
ck
::
bhalf_t
;
using
AccDataType
=
float
;
using
AccDataType
=
float
;
...
...
profiler/include/profile_conv_bwd_data_impl.hpp
View file @
74603261
...
@@ -154,8 +154,8 @@ bool profile_conv_bwd_data_impl(int do_verification,
...
@@ -154,8 +154,8 @@ bool profile_conv_bwd_data_impl(int do_verification,
weight
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
weight
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
break
;
break
;
default:
default:
output
.
GenerateTensorValue
(
GeneratorTensor_
1
<
OutDataType
>
{
1
});
output
.
GenerateTensorValue
(
GeneratorTensor_
3
<
OutDataType
>
{
0.0
,
1.0
});
weight
.
GenerateTensorValue
(
GeneratorTensor_
1
<
WeiDataType
>
{
1
});
weight
.
GenerateTensorValue
(
GeneratorTensor_
3
<
WeiDataType
>
{
-
0.5
,
0.5
});
}
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
input_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
input_device_result
.
mDesc
.
GetElementSpace
());
...
...
profiler/include/profile_conv_bwd_weight_impl.hpp
View file @
74603261
...
@@ -156,12 +156,12 @@ bool profile_conv_bwd_weight_impl(int do_verification,
...
@@ -156,12 +156,12 @@ bool profile_conv_bwd_weight_impl(int do_verification,
{
{
case
0
:
break
;
case
0
:
break
;
case
1
:
case
1
:
input
.
GenerateTensorValue
(
GeneratorTensor_2
<
Out
DataType
>
{
-
2
,
2
});
input
.
GenerateTensorValue
(
GeneratorTensor_2
<
In
DataType
>
{
-
5
,
5
});
output
.
GenerateTensorValue
(
GeneratorTensor_2
<
Wei
DataType
>
{
-
2
,
2
});
output
.
GenerateTensorValue
(
GeneratorTensor_2
<
Out
DataType
>
{
-
5
,
5
});
break
;
break
;
default:
default:
input
.
GenerateTensorValue
(
GeneratorTensor_
1
<
Out
DataType
>
{
1
});
input
.
GenerateTensorValue
(
GeneratorTensor_
3
<
In
DataType
>
{
0.0
,
1.0
});
output
.
GenerateTensorValue
(
GeneratorTensor_
1
<
Wei
DataType
>
{
1
});
output
.
GenerateTensorValue
(
GeneratorTensor_
3
<
Out
DataType
>
{
-
0.5
,
0.5
});
}
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
input
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
input
.
mDesc
.
GetElementSpace
());
...
...
profiler/src/profile_conv_bwd_weight.cpp
View file @
74603261
...
@@ -197,7 +197,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
...
@@ -197,7 +197,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
}
}
else
if
(
data_type
==
ConvDataType
::
BF16_F32_BF16
)
else
if
(
data_type
==
ConvDataType
::
BF16_F32_BF16
)
{
{
return
profile
(
I1
,
NWC
{},
KXC
{},
NWK
{},
BF16
{},
BF16
{},
BF16
{});
// fp32 atomic add is used for weight tensor in bf16 kernel
return
profile
(
I1
,
NWC
{},
KXC
{},
NWK
{},
BF16
{},
F32
{},
BF16
{});
}
}
}
}
else
if
(
num_dim_spatial
==
2
&&
layout
==
ConvLayout
::
NHWC_KYXC_NHWK
)
else
if
(
num_dim_spatial
==
2
&&
layout
==
ConvLayout
::
NHWC_KYXC_NHWK
)
...
@@ -212,7 +213,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
...
@@ -212,7 +213,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
}
}
else
if
(
data_type
==
ConvDataType
::
BF16_F32_BF16
)
else
if
(
data_type
==
ConvDataType
::
BF16_F32_BF16
)
{
{
return
profile
(
I2
,
NHWC
{},
KYXC
{},
NHWK
{},
BF16
{},
BF16
{},
BF16
{});
// fp32 atomic add is used for weight tensor in bf16 kernel
return
profile
(
I2
,
NHWC
{},
KYXC
{},
NHWK
{},
BF16
{},
F32
{},
BF16
{});
}
}
}
}
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
NHWC_KYXC_NHWK
)
else
if
(
num_dim_spatial
==
3
&&
layout
==
ConvLayout
::
NHWC_KYXC_NHWK
)
...
@@ -227,7 +229,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
...
@@ -227,7 +229,8 @@ int profile_conv_bwd_weight(int argc, char* argv[])
}
}
else
if
(
data_type
==
ConvDataType
::
BF16_F32_BF16
)
else
if
(
data_type
==
ConvDataType
::
BF16_F32_BF16
)
{
{
return
profile
(
I3
,
NDHWC
{},
KZYXC
{},
NDHWK
{},
BF16
{},
BF16
{},
BF16
{});
// fp32 atomic add is used for weight tensor in bf16 kernel
return
profile
(
I3
,
NDHWC
{},
KZYXC
{},
NDHWK
{},
BF16
{},
F32
{},
BF16
{});
}
}
}
}
...
...
test/convnd_bwd_data/convnd_bwd_data.cpp
View file @
74603261
...
@@ -196,6 +196,6 @@ int main()
...
@@ -196,6 +196,6 @@ int main()
else
else
{
{
std
::
cout
<<
"test convnd bwd: Fail "
<<
std
::
endl
;
std
::
cout
<<
"test convnd bwd: Fail "
<<
std
::
endl
;
return
-
1
;
return
1
;
}
}
}
}
test/convnd_bwd_weight/CMakeLists.txt
View file @
74603261
add_test_executable
(
test_convnd_bwd_weight convnd_bwd_weight.cpp
)
add_test_executable
(
test_convnd_bwd_weight convnd_bwd_weight.cpp
)
target_link_libraries
(
test_convnd_bwd_weight PRIVATE utility device_conv
n
d_bwd_weight_instance
)
target_link_libraries
(
test_convnd_bwd_weight PRIVATE utility device_conv
1d_bwd_weight_instance device_conv2d_bwd_weight_instance device_conv3
d_bwd_weight_instance
)
test/convnd_bwd_weight/convnd_bwd_weight.cpp
View file @
74603261
...
@@ -7,90 +7,61 @@
...
@@ -7,90 +7,61 @@
#include <cstdlib>
#include <cstdlib>
#include <vector>
#include <vector>
#include "test/convnd_fwd/conv_util.hpp"
#include "profiler/include/profile_conv_bwd_weight_impl.hpp"
#include "profiler/include/profile_convnd_bwd_weight_impl.hpp"
int
test_self
()
int
main
()
{
{
bool
pass
=
true
;
bool
pass
=
true
;
std
::
vector
<
ck
::
utils
::
conv
::
ConvParams
>
params
;
std
::
vector
<
ck
::
tensor_operation
::
device
::
ConvParams
>
params
;
// check 1d
params
.
push_back
({
1
,
128
,
256
,
256
,
{
1
},
{
7
},
{
2
},
{
1
},
{
0
},
{
0
}});
params
.
push_back
({
1
,
128
,
256
,
256
,
{
1
},
{
7
},
{
2
},
{
1
},
{
0
},
{
0
}});
params
.
push_back
({
1
,
128
,
256
,
256
,
{
3
},
{
14
},
{
1
},
{
1
},
{
1
},
{
1
}});
params
.
push_back
({
1
,
128
,
256
,
256
,
{
3
},
{
14
},
{
1
},
{
1
},
{
1
},
{
1
}});
params
.
push_back
({
1
,
128
,
256
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
params
.
push_back
({
1
,
128
,
256
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
for
(
auto
&
param
:
params
)
for
(
auto
&
param
:
params
)
{
{
// f32
// fp32
pass
&=
ck
::
profiler
::
profile_convnd_bwd_weight_impl
<
1
,
pass
&=
ck
::
profiler
::
profile_conv_bwd_weight_impl
<
1
,
float
,
float
,
float
,
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
ck
::
tensor_layout
::
convolution
::
NWK
,
true
,
// do_verification
float
,
float
,
float
>
(
true
,
// do_verification
1
,
// init_method
1
,
// init_method
false
,
// do_log
false
,
// do_log
true
,
// time_kernel
false
,
// time_kernel
param
.
N_
,
param
,
param
.
K_
,
param
.
C_
,
param
.
input_spatial_lengths_
,
param
.
filter_spatial_lengths_
,
param
.
GetOutputSpatialLengths
(),
param
.
conv_filter_strides_
,
param
.
conv_filter_dilations_
,
param
.
input_left_pads_
,
param
.
input_right_pads_
,
2
);
2
);
// fp16
// fp16
pass
&=
ck
::
profiler
::
profile_convnd_bwd_weight_impl
<
1
,
pass
&=
ck
::
profiler
::
profile_conv_bwd_weight_impl
<
1
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
ck
::
tensor_layout
::
convolution
::
NWK
,
true
,
// do_verification
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
>
(
true
,
// do_verification
1
,
// init_method
1
,
// init_method
false
,
// do_log
false
,
// do_log
true
,
// time_kernel
false
,
// time_kernel
param
.
N_
,
param
,
param
.
K_
,
param
.
C_
,
param
.
input_spatial_lengths_
,
param
.
filter_spatial_lengths_
,
param
.
GetOutputSpatialLengths
(),
param
.
conv_filter_strides_
,
param
.
conv_filter_dilations_
,
param
.
input_left_pads_
,
param
.
input_right_pads_
,
2
);
2
);
// bf16
// bf16, wei is f32
pass
&=
ck
::
profiler
::
profile_convnd_bwd_weight_impl
<
1
,
pass
&=
ck
::
profiler
::
profile_conv_bwd_weight_impl
<
1
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
NWC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
NWK
>
(
ck
::
tensor_layout
::
convolution
::
NWK
,
true
,
// do_verification
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
>
(
true
,
// do_verification
1
,
// init_method
1
,
// init_method
false
,
// do_log
false
,
// do_log
true
,
// time_kernel
false
,
// time_kernel
param
.
N_
,
param
,
param
.
K_
,
param
.
C_
,
param
.
input_spatial_lengths_
,
param
.
filter_spatial_lengths_
,
param
.
GetOutputSpatialLengths
(),
param
.
conv_filter_strides_
,
param
.
conv_filter_dilations_
,
param
.
input_left_pads_
,
param
.
input_right_pads_
,
2
);
2
);
}
}
...
@@ -102,80 +73,50 @@ int test_self()
...
@@ -102,80 +73,50 @@ int test_self()
for
(
auto
&
param
:
params
)
for
(
auto
&
param
:
params
)
{
{
// f32
// fp32
pass
&=
ck
::
profiler
::
profile_convnd_bwd_weight_impl
<
2
,
pass
&=
ck
::
profiler
::
profile_conv_bwd_weight_impl
<
2
,
float
,
float
,
float
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
ck
::
tensor_layout
::
convolution
::
NHWK
,
true
,
// do_verification
float
,
float
,
float
>
(
true
,
// do_verification
1
,
// init_method
1
,
// init_method
false
,
// do_log
false
,
// do_log
true
,
// time_kernel
false
,
// time_kernel
param
.
N_
,
param
,
param
.
K_
,
param
.
C_
,
param
.
input_spatial_lengths_
,
param
.
filter_spatial_lengths_
,
param
.
GetOutputSpatialLengths
(),
param
.
conv_filter_strides_
,
param
.
conv_filter_dilations_
,
param
.
input_left_pads_
,
param
.
input_right_pads_
,
2
);
2
);
// fp16
// fp16
pass
&=
ck
::
profiler
::
profile_convnd_bwd_weight_impl
<
2
,
pass
&=
ck
::
profiler
::
profile_conv_bwd_weight_impl
<
2
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
ck
::
tensor_layout
::
convolution
::
NHWK
,
true
,
// do_verification
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
>
(
true
,
// do_verification
1
,
// init_method
1
,
// init_method
false
,
// do_log
false
,
// do_log
true
,
// time_kernel
false
,
// time_kernel
param
.
N_
,
param
,
param
.
K_
,
param
.
C_
,
param
.
input_spatial_lengths_
,
param
.
filter_spatial_lengths_
,
param
.
GetOutputSpatialLengths
(),
param
.
conv_filter_strides_
,
param
.
conv_filter_dilations_
,
param
.
input_left_pads_
,
param
.
input_right_pads_
,
2
);
2
);
// bf16
// bf16, wei is f32
pass
&=
ck
::
profiler
::
profile_convnd_bwd_weight_impl
<
2
,
pass
&=
ck
::
profiler
::
profile_conv_bwd_weight_impl
<
2
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
ck
::
tensor_layout
::
convolution
::
NHWK
,
true
,
// do_verification
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
>
(
true
,
// do_verification
1
,
// init_method
1
,
// init_method
false
,
// do_log
false
,
// do_log
true
,
// time_kernel
false
,
// time_kernel
param
.
N_
,
param
,
param
.
K_
,
param
.
C_
,
param
.
input_spatial_lengths_
,
param
.
filter_spatial_lengths_
,
param
.
GetOutputSpatialLengths
(),
param
.
conv_filter_strides_
,
param
.
conv_filter_dilations_
,
param
.
input_left_pads_
,
param
.
input_right_pads_
,
2
);
2
);
}
}
// check
2
d
// check
3
d
params
.
clear
();
params
.
clear
();
params
.
push_back
(
params
.
push_back
(
{
3
,
128
,
256
,
256
,
{
1
,
1
,
1
},
{
4
,
4
,
4
},
{
2
,
2
,
2
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
{
3
,
128
,
256
,
256
,
{
1
,
1
,
1
},
{
4
,
4
,
4
},
{
2
,
2
,
2
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
...
@@ -186,90 +127,49 @@ int test_self()
...
@@ -186,90 +127,49 @@ int test_self()
for
(
auto
&
param
:
params
)
for
(
auto
&
param
:
params
)
{
{
// f32
// fp32
pass
&=
ck
::
profiler
::
profile_convnd_bwd_weight_impl
<
3
,
pass
&=
ck
::
profiler
::
profile_conv_bwd_weight_impl
<
3
,
float
,
float
,
float
,
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
ck
::
tensor_layout
::
convolution
::
NDHWK
,
true
,
// do_verification
float
,
float
,
float
>
(
true
,
// do_verification
1
,
// init_method
1
,
// init_method
false
,
// do_log
false
,
// do_log
true
,
// time_kernel
false
,
// time_kernel
param
.
N_
,
param
,
param
.
K_
,
param
.
C_
,
param
.
input_spatial_lengths_
,
param
.
filter_spatial_lengths_
,
param
.
GetOutputSpatialLengths
(),
param
.
conv_filter_strides_
,
param
.
conv_filter_dilations_
,
param
.
input_left_pads_
,
param
.
input_right_pads_
,
2
);
2
);
// fp16
// fp16
pass
&=
ck
::
profiler
::
profile_convnd_bwd_weight_impl
<
3
,
pass
&=
ck
::
profiler
::
profile_conv_bwd_weight_impl
<
3
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
ck
::
tensor_layout
::
convolution
::
NDHWK
,
true
,
// do_verification
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
>
(
true
,
// do_verification
1
,
// init_method
1
,
// init_method
false
,
// do_log
false
,
// do_log
true
,
// time_kernel
false
,
// time_kernel
param
.
N_
,
param
,
param
.
K_
,
param
.
C_
,
param
.
input_spatial_lengths_
,
param
.
filter_spatial_lengths_
,
param
.
GetOutputSpatialLengths
(),
param
.
conv_filter_strides_
,
param
.
conv_filter_dilations_
,
param
.
input_left_pads_
,
param
.
input_right_pads_
,
2
);
2
);
// bf16
// bf16, wei is f32
pass
&=
ck
::
profiler
::
profile_convnd_bwd_weight_impl
<
3
,
pass
&=
ck
::
profiler
::
profile_conv_bwd_weight_impl
<
3
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
NDHWC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
(
ck
::
tensor_layout
::
convolution
::
NDHWK
,
true
,
// do_verification
ck
::
bhalf_t
,
float
,
ck
::
bhalf_t
>
(
true
,
// do_verification
1
,
// init_method
1
,
// init_method
false
,
// do_log
false
,
// do_log
true
,
// time_kernel
false
,
// time_kernel
param
.
N_
,
param
,
param
.
K_
,
param
.
C_
,
param
.
input_spatial_lengths_
,
param
.
filter_spatial_lengths_
,
param
.
GetOutputSpatialLengths
(),
param
.
conv_filter_strides_
,
param
.
conv_filter_dilations_
,
param
.
input_left_pads_
,
param
.
input_right_pads_
,
2
);
2
);
}
}
return
pass
;
}
int
main
()
{
// int data_type = 1;
// int init_method = 1;
bool
pass
=
true
;
pass
=
test_self
();
if
(
pass
)
if
(
pass
)
{
{
std
::
cout
<<
"test conv2d bwd weight : Pass"
<<
std
::
endl
;
std
::
cout
<<
"test conv2d bwd weight : Pass"
<<
std
::
endl
;
...
@@ -278,6 +178,6 @@ int main()
...
@@ -278,6 +178,6 @@ int main()
else
else
{
{
std
::
cout
<<
"test conv2d bwd weight: Fail "
<<
std
::
endl
;
std
::
cout
<<
"test conv2d bwd weight: Fail "
<<
std
::
endl
;
return
-
1
;
return
1
;
}
}
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment