gaoqiong/composable_kernel

Commit 157491ab, authored Dec 03, 2019 by Chao Liu (parent b7992190)

added bwd data v2r1: no need for atomic

Showing 7 changed files with 140 additions and 114 deletions
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp  (+50 -40)
composable_kernel/include/tensor_description/multi_index_transform.hpp  (+6 -27)
composable_kernel/include/tensor_operation/gridwise_gemm.hpp  (+3 -8)
composable_kernel/include/utility/math.hpp  (+50 -5)
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp  (+19 -6)
driver/src/conv_bwd_data_driver.cpp  (+9 -25)
driver/src/conv_driver.cpp  (+3 -3)
composable_kernel/include/kernel_algorithm/gridwise_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp

@@ -8,6 +8,9 @@
 namespace ck {
 
+// GemmK = K * Ydot * Xdot;
+// GemmM = C * Ytilda * Xtilda;
+// GemmN = N * Htilda * Wtilda;
 template <index_t GridSize,
           index_t BlockSize,
           typename Float,
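The three comments added at the top of the file record how the backward-data convolution is laid out as an implicit GEMM. Purely as an illustration (the sizes below are invented for this note, loosely following the 2x2-filter, stride-4, dilation-2 test case added to the driver later in this commit; they are not part of the diff):

    // Example sizes: K = 8, C = 128, N = 8, Ydot = Xdot = 1, Ytilda = Xtilda = 2,
    //                Htilda = Wtilda = 4
    //   GemmK = K * Ydot   * Xdot   = 8   * 1 * 1 = 8
    //   GemmM = C * Ytilda * Xtilda = 128 * 2 * 2 = 512
    //   GemmN = N * Htilda * Wtilda = 8   * 4 * 4 = 128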
@@ -73,33 +76,41 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
             "wrong! aligment requirement for vectorized global load of input tensor will "
             "be violated");
 
-        // TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for
-        // simplicity
-        static_assert(ConvStrideH == 1 && ConvStrideW == 1 && ConvDilationH == 1 &&
-                          ConvDilationW == 1,
-                      "wrong! not supported yet");
-
-        // TODO: these logic are only for stride = 1, dilation = 1
-        constexpr index_t Ydot   = Y;
-        constexpr index_t Ytilda = 1;
-        constexpr index_t Htilda = Ho + Y - 1;
-
-        constexpr index_t Xdot   = X;
-        constexpr index_t Xtilda = 1;
-        constexpr index_t Wtilda = Wo + X - 1;
-
-        constexpr index_t GemmK = K * Ydot * Xdot;
-        constexpr index_t GemmM = C * Ytilda * Xtilda;
-        constexpr index_t GemmN = N * Htilda * Wtilda;
+        constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
+        constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+
+        constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h;
+        constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w;
+
+        constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
+        constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
+
+        constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
+        constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
+
+        constexpr index_t Htilda = Ho + right_pad_ho;
+        constexpr index_t Wtilda = Wo + right_pad_wo;
 
         // weight tensor
-        constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
+        constexpr auto wei_k_c_yp_xp_global_desc = transform_tensor_descriptor(
             wei_k_c_y_x_global_desc,
             make_tuple(PassThrough<K>{},
                        PassThrough<C>{},
-                       Embed<Sequence<Ydot, Ytilda>, Sequence<1, 1, 0>>{},  // coefficient may be wrong
-                       Embed<Sequence<Xdot, Xtilda>, Sequence<1, 1, 0>>{}), // coefficient may be wrong
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
+                       Pad<Sequence<Y, X>,
+                           Sequence<0, 0>,
+                           Sequence<Ydot * Ytilda - Y, Xdot * Xtilda - X>,
+                           true>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
+
+        constexpr auto wei_k_c_ydot_ytilda_xdot_xtilda_global_desc = transform_tensor_descriptor(
+            wei_k_c_yp_xp_global_desc,
+            make_tuple(PassThrough<K>{},
+                       PassThrough<C>{},
+                       Embed<Sequence<Ydot, Ytilda>,
+                             Sequence<ConvStrideH / hcf_stride_dilation_h, 1, 0>>{},
+                       Embed<Sequence<Xdot, Xtilda>,
+                             Sequence<ConvStrideW / hcf_stride_dilation_w, 1, 0>>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
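As a worked instance of the new index arithmetic, take the stride-4, dilation-2, 2x2-filter configuration added to conv_bwd_data_driver.cpp in this commit (the numbers are shown only as an illustration):

    // ConvStrideH = 4, ConvDilationH = 2, Y = 2:
    //   hcf_stride_dilation_h = hcf(4, 2)                          = 2
    //   Ytilda                = ConvStrideH / 2                    = 2
    //   Ydot                  = integer_divide_ceil(Y, Ytilda)     = 1
    //   right_pad_ho          = (ConvDilationH / 2) * (Y - Ytilda) = 0
    //   Htilda                = Ho + right_pad_ho                  = Ho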
@@ -110,23 +121,25 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
             make_tuple(Sequence<0>{}, Sequence<1>{}));
 
         // output tensor
         constexpr auto out_n_k_hop_wop_global_desc = transform_tensor_descriptor(
             out_n_k_ho_wo_global_desc,
             make_tuple(PassThrough<N>{},
                        PassThrough<K>{},
-                       Pad<Sequence<Ho, Wo>, Sequence<0, 0>, Sequence<Y - 1, X - 1>>{}), // coefficient may
-                                                                                         // be wrong
+                       Pad<Sequence<Ho, Wo>,
+                           Sequence<0, 0>,
+                           Sequence<right_pad_ho, right_pad_wo>,
+                           true>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
 
         constexpr auto out_n_k_ydot_htilda_xdot_wtilda_global_desc = transform_tensor_descriptor(
             out_n_k_hop_wop_global_desc,
             make_tuple(PassThrough<N>{},
                        PassThrough<K>{},
-                       Embed<Sequence<Ydot, Htilda>, Sequence<0, 1, 0>>{},  // coefficient may be wrong
-                       Embed<Sequence<Xdot, Wtilda>, Sequence<0, 1, 0>>{}), // coefficient may be wrong
+                       Embed<Sequence<Ydot, Htilda>,
+                             Sequence<-ConvDilationH / hcf_stride_dilation_h, 1, 0>>{},
+                       Embed<Sequence<Xdot, Wtilda>,
+                             Sequence<-ConvDilationW / hcf_stride_dilation_w, 1, 0>>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
@@ -137,14 +150,11 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
             make_tuple(Sequence<0>{}, Sequence<1>{}));
 
         // input tensor
-        constexpr auto eff_left_pads  = LeftPads{} + Sequence<Y - 1, X - 1>{};
-        constexpr auto eff_right_pads = RightPads{} + Sequence<Y - 1, X - 1>{};
-
         constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
             in_n_c_hi_wi_global_desc,
             make_tuple(PassThrough<N>{},
                        PassThrough<C>{},
-                       Pad<Sequence<Hi, Wi>, decltype(eff_left_pads), decltype(eff_right_pads)>{}),
+                       Pad<Sequence<Hi, Wi>, LeftPads, RightPads, true>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
@@ -160,7 +170,7 @@ struct GridwiseConvolutionBackwardDataImplicitGemm_v2r1_nchw_kcyx_nkhw
         constexpr auto in_gemmm_gemmn_global_desc = transform_tensor_descriptor(
             in_n_c_ytilda_htilda_xtilda_wtilda_global_desc,
             make_tuple(Merge<Sequence<C, Ytilda, Xtilda>>{}, Merge<Sequence<N, Htilda, Wtilda>>{}),
-            make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
+            make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));
 
         // GEMM
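The last hunk only changes which source dimensions feed the two merged GEMM dimensions. Reading the dimension order off the descriptor name (an inference from the identifier, not something stated elsewhere in the diff):

    // in_n_c_ytilda_htilda_xtilda_wtilda_global_desc dimensions:
    //   0: N, 1: C, 2: Ytilda, 3: Htilda, 4: Xtilda, 5: Wtilda
    // so after this change
    //   GemmM = Merge<C, Ytilda, Xtilda> reads dims {1, 2, 4}
    //   GemmN = Merge<N, Htilda, Wtilda> reads dims {0, 3, 5}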
composable_kernel/include/tensor_description/multi_index_transform.hpp

@@ -83,37 +83,16 @@ struct Pad
     __host__ __device__ constexpr bool
     IsUpperIndexMappedToValidLowerIndex(const UpperIndex& idx_up) const
     {
-#if 0
-        struct lambda_no_pad
-        {
-            __host__ __device__ constexpr bool operator()(index_t x) const { return x == 0; }
-        };
-
-        if(sequence_all_of(LeftPads{}, lambda_no_pad{}) &&
-           sequence_all_of(RightPads{}, lambda_no_pad{}))
-        {
-            return true;
-        }
-        else
-#endif
-        {
-            bool flag = true;
-
-            static_for<0, nDim, 1>{}([&](auto idim) {
-                // only check if there is left-padding
-                static_if<(LeftPads::At(idim) != 0)>{}(
-                    [&](auto) { flag = flag && idx_up[idim] >= LeftPads::At(idim); });
-
-                // only check if there is right-padding
-                static_if<(RightPads::At(idim) != 0)>{}([&](auto) {
-                    flag = flag && (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
-                });
-            });
-
-            return flag;
-        }
+        bool flag = true;
+
+        static_for<0, nDim, 1>{}([&](auto idim) {
+            flag = flag && (idx_up[idim] >= LeftPads::At(idim)) &&
+                   (idx_up[idim] < LeftPads::At(idim) + LowerLengths::At(idim));
+        });
+
+        return flag;
     }
 };
 
 // LowerLengths: Sequence<...>
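For reference, a small host-side sketch of what the simplified predicate computes (plain C++14 with illustrative names, not the ck template machinery): an upper index is valid only if, in every dimension, it lies inside [left_pad, left_pad + lower_length), i.e. it does not point into the padding.

    #include <array>
    #include <cstddef>

    // Illustrative stand-in for Pad::IsUpperIndexMappedToValidLowerIndex.
    template <std::size_t NDim>
    constexpr bool is_valid_padded_index(const std::array<int, NDim>& idx_up,
                                         const std::array<int, NDim>& left_pads,
                                         const std::array<int, NDim>& lower_lengths)
    {
        bool flag = true;
        for(std::size_t d = 0; d < NDim; ++d)
        {
            flag = flag && (idx_up[d] >= left_pads[d]) &&
                   (idx_up[d] < left_pads[d] + lower_lengths[d]);
        }
        return flag;
    }

    static_assert(is_valid_padded_index<2>({1, 1}, {1, 1}, {4, 4}), "inside the unpadded region");
    static_assert(!is_valid_padded_index<2>({0, 1}, {1, 1}, {4, 4}), "points into the left padding");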
composable_kernel/include/tensor_operation/gridwise_gemm.hpp

@@ -46,18 +46,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v1r1
     {
         constexpr auto True = integral_constant<bool, true>{};
 
-        constexpr auto I0 = Number<0>{};
-        constexpr auto I1 = Number<1>{};
-        constexpr auto I2 = Number<2>{};
-        constexpr auto I3 = Number<3>{};
-
         constexpr auto a_k_m_global_desc = AGlobalDesc{};
         constexpr auto b_k_n_global_desc = BGlobalDesc{};
         constexpr auto c_m_n_global_desc = CGlobalDesc{};
 
-        constexpr auto K = a_k_m_global_desc.GetLength(I0);
-        constexpr auto M = a_k_m_global_desc.GetLength(I1);
-        constexpr auto N = b_k_n_global_desc.GetLength(I1);
+        constexpr auto K = a_k_m_global_desc.GetLengths()[0];
+        constexpr auto M = a_k_m_global_desc.GetLengths()[1];
+        constexpr auto N = b_k_n_global_desc.GetLengths()[1];
 
         // lds max alignment
         constexpr index_t max_lds_align = math::lcm(ABlockCopyDataPerAccess_M,
composable_kernel/include/utility/math.hpp

@@ -97,12 +97,57 @@ __host__ __device__ constexpr T min(T x, Ts... xs)
     return x < y ? x : y;
 }
 
-// this is WRONG
-// TODO: implement least common multiple properly, instead of calling max()
-template <class T, class... Ts>
-__host__ __device__ constexpr T lcm(T x, Ts... xs)
+// highest common factor
+template <typename T>
+__host__ __device__ constexpr T hcf(T x, T y)
+{
+    if(x == 0)
+    {
+        return y;
+    }
+
+    if(y == 0)
+    {
+        return x;
+    }
+
+    if(x == y)
+    {
+        return x;
+    }
+
+    if(x > y)
+    {
+        return hcf(x - y, y);
+    }
+
+    return hcf(x, y - x);
+}
+
+template <index_t X, index_t Y>
+__host__ __device__ constexpr auto hcf(Number<X>, Number<Y>)
+{
+    constexpr auto result = hcf(X, Y);
+
+    return Number<result>{};
+}
+
+template <typename X, typename... Ys>
+__host__ __device__ constexpr auto hcf(X x, Ys... ys)
+{
+    return hcf(x, ys...);
+}
+
+// least common multiple
+template <typename T>
+__host__ __device__ constexpr T lcm(T x, T y)
+{
+    return (x * y) / hcf(x, y);
+}
+
+template <typename X, typename Y, typename... Zs>
+__host__ __device__ constexpr auto lcm(X x, Y y, Zs... zs)
 {
-    return max(x, xs...);
+    return lcm(x, lcm(y, zs...));
 }
 
 template <class T>
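A quick host-side sanity check of the new helpers (a standalone sketch that restates the same recurrence with plain int; the in-tree versions live in ck::math and add the Number<> and variadic overloads):

    constexpr int hcf(int x, int y)
    {
        return x == 0 ? y : y == 0 ? x : x == y ? x : x > y ? hcf(x - y, y) : hcf(x, y - x);
    }

    constexpr int lcm(int x, int y) { return (x * y) / hcf(x, y); }

    // The pair this commit cares about: stride 4 with dilation 2.
    static_assert(hcf(4, 2) == 2, "");
    static_assert(lcm(4, 2) == 4, "");
    static_assert(hcf(3, 2) == 1, "");
    static_assert(lcm(3, 2) == 6, "");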
driver/include/device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw.hpp

@@ -36,6 +36,12 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
     constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
     constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
 
+    constexpr index_t ConvStrideH = ConvStrides{}[0];
+    constexpr index_t ConvStrideW = ConvStrides{}[1];
+
+    constexpr index_t ConvDilationH = ConvDilations{}[0];
+    constexpr index_t ConvDilationW = ConvDilations{}[1];
+
     std::size_t data_sz = sizeof(T);
     DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
     DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
@@ -105,13 +111,20 @@ void device_convolution_backward_data_implicit_gemm_v2r1_nchw_kcyx_nkhw(InDesc i
     // TODO: this algo support any stride and dilation. But for now, let's fix them to be 1 for
     // simplicity
-    constexpr index_t Ydot   = 1;
-    constexpr index_t Ytilda = Y;
-    constexpr index_t Htilda = Ho + Y - 1;
-
-    constexpr index_t Xdot   = 1;
-    constexpr index_t Xtilda = X;
-    constexpr index_t Wtilda = Wo + X - 1;
+    constexpr index_t hcf_stride_dilation_h = math::hcf(ConvStrideH, ConvDilationH);
+    constexpr index_t hcf_stride_dilation_w = math::hcf(ConvStrideW, ConvDilationW);
+
+    constexpr index_t Ytilda = ConvStrideH / hcf_stride_dilation_h; // may be wrong
+    constexpr index_t Xtilda = ConvStrideW / hcf_stride_dilation_w; // may be wrong
+
+    constexpr index_t Ydot = math::integer_divide_ceil(Y, Ytilda);
+    constexpr index_t Xdot = math::integer_divide_ceil(X, Xtilda);
+
+    constexpr index_t right_pad_ho = (ConvDilationH / hcf_stride_dilation_h) * (Y - Ytilda);
+    constexpr index_t right_pad_wo = (ConvDilationW / hcf_stride_dilation_w) * (X - Xtilda);
+
+    constexpr index_t Htilda = Ho + right_pad_ho;
+    constexpr index_t Wtilda = Wo + right_pad_wo;
 
     constexpr index_t GemmK = K * Ydot * Xdot;
     constexpr index_t GemmM = C * Ytilda * Xtilda;
driver/src/conv_bwd_data_driver.cpp

@@ -22,20 +22,20 @@ int main(int argc, char* argv[])
     using namespace ck;
 
 #if 0
-    constexpr index_t N = 4;
-    constexpr index_t C = 8;
-    constexpr index_t HI = 11;
-    constexpr index_t WI = 11;
+    constexpr index_t N = 8;
+    constexpr index_t C = 128;
+    constexpr index_t HI = 16;
+    constexpr index_t WI = 16;
     constexpr index_t K = 8;
-    constexpr index_t Y = 4;
-    constexpr index_t X = 4;
+    constexpr index_t Y = 2;
+    constexpr index_t X = 2;
 
-    using ConvStrides   = Sequence<1, 1>;
-    using ConvDilations = Sequence<1, 1>;
+    using ConvStrides   = Sequence<4, 4>;
+    using ConvDilations = Sequence<2, 2>;
 
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
     // 3x3, 34x34
     constexpr index_t N = 64;
     constexpr index_t C = 256;
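The reworked first test case pairs a 2x2 filter with stride 4 and dilation 2, which is the kind of configuration the new hcf-based index math is meant to handle. As a sanity check of the geometry (standard convolution output-size arithmetic, not part of the diff):

    // Output size for HI = WI = 16, Y = X = 2, stride 4, dilation 2, no padding:
    constexpr int conv_out_size(int in, int pad, int filter, int dilation, int stride)
    {
        return (in + 2 * pad - dilation * (filter - 1) - 1) / stride + 1;
    }
    static_assert(conv_out_size(16, 0, 2, 2, 4) == 4, "Ho = Wo = 4");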
@@ -52,7 +52,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 8x8 image
-    // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
     constexpr index_t N = 64;
     constexpr index_t C = 1536;
     constexpr index_t HI = 8;

@@ -68,7 +67,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 8x8 image
-    // cudnn@V100 77%, ck@V100 76%, ck@P100 79%, ck@VII 51%
     constexpr index_t N = 128;
     constexpr index_t C = 2048;
     constexpr index_t HI = 8;

@@ -84,7 +82,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 7x7 image
-    // cudnn@V100 82%, ck@V100 76%, ck@P100 67%, ck@VII 64%
     constexpr index_t N = 128;
     constexpr index_t C = 832;
     constexpr index_t HI = 7;

@@ -100,7 +97,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 8x8 image
-    // cudnn@V100 83%, ck@V100 75%, ck@P100 78%, ck@VII 65%
     constexpr index_t N = 128;
     constexpr index_t C = 1280;
     constexpr index_t HI = 8;

@@ -116,7 +112,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 14x14 image
-    // cudnn@V100 62%, ck@V100 68%, ck@P100 70%, ck@VII 50%
     constexpr index_t N = 128;
     constexpr index_t C = 512;
     constexpr index_t HI = 14;

@@ -132,7 +127,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 8x8 image
-    // cudnn@V100 74%, ck@V100 57%, ck@P100 78%, ck@VII 61%
     constexpr index_t N = 64;
     constexpr index_t C = 1536;
     constexpr index_t HI = 8;

@@ -148,7 +142,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 28x28 image
-    // cudnn@V100 86%, ck@V100 84%, ck@P100 80%, ck@VII 69%
     constexpr index_t N = 128;
     constexpr index_t C = 256;
     constexpr index_t HI = 28;

@@ -164,7 +157,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 7x7 image
-    // cudnn@V100 71%, ck@V100 55%, ck@P100 70%, ck@VII 62%
     constexpr index_t N = 128;
     constexpr index_t C = 832;
     constexpr index_t HI = 7;

@@ -180,7 +172,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 17x17 input
-    // cudnn@V100 81%, ck@V100 76%, ck@P100 70%, ck@VII 76%
     constexpr index_t N = 128;
     constexpr index_t C = 768;
     constexpr index_t HI = 17;

@@ -196,7 +187,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 14x14 image
-    // cudnn@V100 73%, ck@V100 71%, ck@P100 70%, ck@VII 64%
     constexpr index_t N = 128;
     constexpr index_t C = 528;
     constexpr index_t HI = 14;

@@ -212,7 +202,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 14x14 image
-    // cudnn@V100 73%, ck@V100 72%, ck@P100 79%, ck@VII 75%
     constexpr index_t N = 128;
     constexpr index_t C = 528;
     constexpr index_t HI = 14;

@@ -228,7 +217,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 1x1 filter, 7x7 image
-    // cudnn@V100 49%, ck@V100 50%, ck@P100 61%, ck@VII 52%
     constexpr index_t N = 128;
     constexpr index_t C = 832;
     constexpr index_t HI = 7;

@@ -244,7 +232,6 @@ int main(int argc, char* argv[])
     using RightPads = Sequence<0, 0>;
 #elif 0
     // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
-    // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
     constexpr index_t N = 128;
     constexpr index_t C = 288;
     constexpr index_t HI = 35;

@@ -340,9 +327,6 @@ int main(int argc, char* argv[])
 #if 0
     wei_kcyx.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
     out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
-#elif 0
-    wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
-    out_nkhw.GenerateTensorValue(GeneratorTensor_1{1}, num_thread);
 #else
     wei_kcyx.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
     out_nkhw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
driver/src/conv_driver.cpp

@@ -58,7 +58,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 1
+#elif 0
     // 1x1 filter, 8x8 image
     // cudnn@V100 68%, ck@V100 72%, ck@P100 52%, ck@VII 42%
     constexpr index_t N = 64;

@@ -250,7 +250,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;
-#elif 0
+#elif 1
     // 3x3 filter, 2x2 stride, 35x35 input, 17x17 output
     // cudnn@V100 90%, ck@V100 93%, ck@P100 83%, ck@VII 81%
     constexpr index_t N = 128;

@@ -296,7 +296,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<3, 0>;
     using RightPads = Sequence<3, 0>;
-#elif 1
+#elif 0
     // 1x7 filter, 0x3 pad, 17x17 input
     constexpr index_t N = 128;
     constexpr index_t C = 128;