Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel_onnx
Commits
b1cb48a0
You need to sign in or sign up before continuing.
Commit
b1cb48a0
authored
Jun 13, 2019
by
Chao Liu
Browse files
added strides and dilations suppport to implicit gemm v4
parent
1566b317
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
96 additions
and
36 deletions
+96
-36
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp
...ion_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp
+11
-7
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
...l/include/tensor_description/ConstantTensorDescriptor.hpp
+12
-0
driver/include/conv_common.hpp
driver/include/conv_common.hpp
+24
-18
driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
...de/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
+10
-1
driver/src/driver.cpp
driver/src/driver.cpp
+39
-10
No files found.
composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp
View file @
b1cb48a0
...
@@ -22,6 +22,8 @@ template <index_t GridSize,
...
@@ -22,6 +22,8 @@ template <index_t GridSize,
class
InGlobalDesc
,
class
InGlobalDesc
,
class
WeiGlobalDesc
,
class
WeiGlobalDesc
,
class
OutGlobalDesc
,
class
OutGlobalDesc
,
class
ConvStrides
,
class
ConvDilations
,
index_t
BPerBlock
,
index_t
BPerBlock
,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
EPerBlock
,
index_t
EPerBlock
,
...
@@ -117,15 +119,17 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
...
@@ -117,15 +119,17 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer
// input tensor
// input tensor
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
in_n_c_h_w_global_desc
.
Slice
(
I2
,
Number
<
Ho
>
{})
constexpr
auto
in_n0_n1_n2_h_w_global_desc
=
.
Slice
(
I3
,
Number
<
Wo
>
{})
in_n_c_h_w_global_desc
.
StridedSlice
(
I2
,
Number
<
Ho
>
{},
Number
<
ConvStrides
::
Get
(
I0
)
>
{})
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
StridedSlice
(
I3
,
Number
<
Wo
>
{},
Number
<
ConvStrides
::
Get
(
I1
)
>
{})
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
.
Fold
(
I0
,
Number
<
N1
>
{},
Number
<
N2
>
{})
.
Extract
(
Sequence
<
0
,
1
,
2
,
4
,
5
>
{});
// batch descritpor for device memory
// batch descritpor for device memory
constexpr
auto
in_c_y_x_global_desc
=
in_n_c_h_w_global_desc
.
Slice
(
I2
,
Number
<
Y
>
{})
constexpr
auto
in_c_y_x_global_desc
=
.
Slice
(
I3
,
Number
<
X
>
{})
in_n_c_h_w_global_desc
.
StridedSlice
(
I2
,
Number
<
Y
>
{},
Number
<
ConvDilations
::
Get
(
I0
)
>
{})
.
Extract
(
Sequence
<
1
,
2
,
3
>
{});
.
StridedSlice
(
I3
,
Number
<
X
>
{},
Number
<
ConvDilations
::
Get
(
I1
)
>
{})
.
Extract
(
Sequence
<
1
,
2
,
3
>
{});
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
constexpr
auto
in_e_n1_b_n2_global_merged_desc
=
make_ConstantMergedTensorDescriptor
(
constexpr
auto
in_e_n1_b_n2_global_merged_desc
=
make_ConstantMergedTensorDescriptor
(
...
...
composable_kernel/include/tensor_description/ConstantTensorDescriptor.hpp
View file @
b1cb48a0
...
@@ -320,6 +320,18 @@ struct ConstantTensorDescriptor
...
@@ -320,6 +320,18 @@ struct ConstantTensorDescriptor
return
ConstantTensorDescriptor
<
slice_lengths
,
Strides
>
{};
return
ConstantTensorDescriptor
<
slice_lengths
,
Strides
>
{};
}
}
template
<
index_t
IDim
,
index_t
SliceLength
,
index_t
SliceStride
>
__host__
__device__
static
constexpr
auto
StridedSlice
(
Number
<
IDim
>
,
Number
<
SliceLength
>
,
Number
<
SliceStride
>
)
{
constexpr
index_t
new_stride
=
Strides
::
Get
(
Number
<
IDim
>
{})
*
SliceStride
;
using
new_lengths
=
decltype
(
Lengths
::
Modify
(
Number
<
IDim
>
{},
Number
<
SliceLength
>
{}));
using
new_strides
=
decltype
(
Strides
::
Modify
(
Number
<
IDim
>
{},
Number
<
new_stride
>
{}));
return
ConstantTensorDescriptor
<
new_lengths
,
new_strides
>
{};
}
template
<
index_t
IDim
,
index_t
...
FoldIntervals
>
template
<
index_t
IDim
,
index_t
...
FoldIntervals
>
__host__
__device__
static
constexpr
auto
Fold
(
Number
<
IDim
>
,
Number
<
FoldIntervals
>
...)
__host__
__device__
static
constexpr
auto
Fold
(
Number
<
IDim
>
,
Number
<
FoldIntervals
>
...)
{
{
...
...
driver/include/conv_common.hpp
View file @
b1cb48a0
...
@@ -36,11 +36,14 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDe
...
@@ -36,11 +36,14 @@ constexpr auto get_convolution_output_default_4d_tensor_descriptor(InDesc, WeiDe
return
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
K
,
HO
,
WO
>
{});
return
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
K
,
HO
,
WO
>
{});
}
}
template
<
class
InDesc
,
class
WeiDesc
,
class
LowerPads
,
class
UpperPads
>
template
<
class
InDesc
,
constexpr
auto
get_convolution_with_padding_output_default_4d_tensor_descriptor
(
InDesc
,
class
WeiDesc
,
WeiDesc
,
class
ConvStrides
,
LowerPads
,
class
ConvDilations
,
UpperPads
)
class
LowerPads
,
class
UpperPads
>
constexpr
auto
get_convolution_with_padding_output_default_4d_tensor_descriptor
(
InDesc
,
WeiDesc
,
ConvStrides
,
ConvDilations
,
LowerPads
,
UpperPads
)
{
{
constexpr
auto
in_desc
=
InDesc
{};
constexpr
auto
in_desc
=
InDesc
{};
constexpr
auto
wei_desc
=
WeiDesc
{};
constexpr
auto
wei_desc
=
WeiDesc
{};
...
@@ -55,24 +58,27 @@ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
...
@@ -55,24 +58,27 @@ constexpr auto get_convolution_with_padding_output_default_4d_tensor_descriptor(
static_assert
(
in_desc
.
GetLength
(
I1
)
==
wei_desc
.
GetLength
(
I1
),
static_assert
(
in_desc
.
GetLength
(
I1
)
==
wei_desc
.
GetLength
(
I1
),
"input & weight dimension not consistent"
);
"input & weight dimension not consistent"
);
constexpr
auto
N
=
in_desc
.
GetLength
(
I0
);
constexpr
index_t
N
=
in_desc
.
GetLength
(
I0
);
constexpr
auto
H
I
=
in_desc
.
GetLength
(
I2
);
constexpr
index_t
H
i
=
in_desc
.
GetLength
(
I2
);
constexpr
auto
W
I
=
in_desc
.
GetLength
(
I3
);
constexpr
index_t
W
i
=
in_desc
.
GetLength
(
I3
);
constexpr
auto
K
=
wei_desc
.
GetLength
(
I0
);
constexpr
index_t
K
=
wei_desc
.
GetLength
(
I0
);
constexpr
auto
Y
=
wei_desc
.
GetLength
(
I2
);
constexpr
index_t
Y
=
wei_desc
.
GetLength
(
I2
);
constexpr
auto
X
=
wei_desc
.
GetLength
(
I3
);
constexpr
index_t
X
=
wei_desc
.
GetLength
(
I3
);
constexpr
auto
HPadLow
=
LowerPads
{}.
Get
(
I0
);
constexpr
index_t
HPadLow
=
LowerPads
{}.
Get
(
I0
);
constexpr
auto
WPadLow
=
LowerPads
{}.
Get
(
I1
);
constexpr
index_t
WPadLow
=
LowerPads
{}.
Get
(
I1
);
constexpr
auto
HPadUp
=
UpperPads
{}.
Get
(
I0
);
constexpr
index_t
HPadUp
=
UpperPads
{}.
Get
(
I0
);
constexpr
auto
WPadUp
=
UpperPads
{}.
Get
(
I1
);
constexpr
index_t
WPadUp
=
UpperPads
{}.
Get
(
I1
);
constexpr
auto
HO
=
HI
+
HPadLow
+
HPadUp
+
1
-
Y
;
constexpr
index_t
YEff
=
(
Y
-
1
)
*
ConvDilations
{}[
0
]
+
1
;
constexpr
auto
WO
=
WI
+
WPadLow
+
WPadUp
+
1
-
X
;
constexpr
index_t
XEff
=
(
X
-
1
)
*
ConvDilations
{}[
1
]
+
1
;
return
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
K
,
HO
,
WO
>
{});
constexpr
index_t
Ho
=
(
Hi
+
HPadLow
+
HPadUp
-
YEff
)
/
ConvStrides
{}[
0
]
+
1
;
constexpr
index_t
Wo
=
(
Wi
+
WPadLow
+
WPadUp
-
XEff
)
/
ConvStrides
{}[
1
]
+
1
;
return
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
K
,
Ho
,
Wo
>
{});
}
}
template
<
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
...
...
driver/include/device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw.hpp
View file @
b1cb48a0
...
@@ -8,13 +8,20 @@
...
@@ -8,13 +8,20 @@
using
namespace
ck
;
using
namespace
ck
;
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
,
class
ConvStrides
,
class
ConvDilations
>
void
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
(
InDesc
,
void
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcyx
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvDilations
,
index_t
nrepeat
)
index_t
nrepeat
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -107,6 +114,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
...
@@ -107,6 +114,8 @@ void device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw(InDesc,
decltype
(
in_nchw_desc
),
decltype
(
in_nchw_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
wei_kcyx_desc
),
decltype
(
out_nkhw_desc
),
decltype
(
out_nkhw_desc
),
ConvStrides
,
ConvDilations
,
BPerBlock
,
BPerBlock
,
KPerBlock
,
KPerBlock
,
CPerBlock
,
CPerBlock
,
...
...
driver/src/driver.cpp
View file @
b1cb48a0
...
@@ -103,10 +103,18 @@ auto make_TensorDescriptor(TConstTensorDesc)
...
@@ -103,10 +103,18 @@ auto make_TensorDescriptor(TConstTensorDesc)
return
TensorDescriptor
(
lengths
,
strides
);
return
TensorDescriptor
(
lengths
,
strides
);
}
}
template
<
class
TIn
,
class
TWei
,
class
TOut
,
class
LowerPads
,
class
UpperPads
>
template
<
class
TIn
,
class
TWei
,
class
TOut
,
class
ConvStrides
,
class
ConvDilations
,
class
LowerPads
,
class
UpperPads
>
void
host_direct_convolution
(
const
Tensor
<
TIn
>&
in_nchw
,
void
host_direct_convolution
(
const
Tensor
<
TIn
>&
in_nchw
,
const
Tensor
<
TWei
>&
wei_kcyx
,
const
Tensor
<
TWei
>&
wei_kcyx
,
Tensor
<
TOut
>&
out_nkhw
,
Tensor
<
TOut
>&
out_nkhw
,
ConvStrides
,
ConvDilations
,
LowerPads
,
LowerPads
,
UpperPads
)
UpperPads
)
{
{
...
@@ -122,10 +130,10 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
...
@@ -122,10 +130,10 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
{
{
for
(
int
y
=
0
;
y
<
wei_kcyx
.
mDesc
.
GetLengths
()[
2
];
++
y
)
for
(
int
y
=
0
;
y
<
wei_kcyx
.
mDesc
.
GetLengths
()[
2
];
++
y
)
{
{
int
hi
=
ho
+
y
-
h_pad_low
;
int
hi
=
ho
*
ConvStrides
{}[
0
]
+
y
*
ConvDilations
{}[
0
]
-
h_pad_low
;
for
(
int
x
=
0
;
x
<
wei_kcyx
.
mDesc
.
GetLengths
()[
3
];
++
x
)
for
(
int
x
=
0
;
x
<
wei_kcyx
.
mDesc
.
GetLengths
()[
3
];
++
x
)
{
{
int
wi
=
wo
+
x
-
w_pad_low
;
int
wi
=
wo
*
ConvStrides
{}[
1
]
+
x
*
ConvDilations
{}[
1
]
-
w_pad_low
;
if
(
hi
>=
0
&&
hi
<
in_nchw
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
if
(
hi
>=
0
&&
hi
<
in_nchw
.
mDesc
.
GetLengths
()[
2
]
&&
wi
>=
0
&&
wi
<
in_nchw
.
mDesc
.
GetLengths
()[
3
])
wi
<
in_nchw
.
mDesc
.
GetLengths
()[
3
])
{
{
...
@@ -419,9 +427,9 @@ int main(int argc, char* argv[])
...
@@ -419,9 +427,9 @@ int main(int argc, char* argv[])
constexpr index_t HPad = 0;
constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
constexpr index_t WPad = 0;
#elif
1
#elif
0
// 3x3, 34x34
// 3x3, 34x34
constexpr
index_t
N
=
64
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
34
;
constexpr
index_t
HI
=
34
;
constexpr
index_t
WI
=
34
;
constexpr
index_t
WI
=
34
;
...
@@ -429,6 +437,9 @@ int main(int argc, char* argv[])
...
@@ -429,6 +437,9 @@ int main(int argc, char* argv[])
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 0
#elif 0
...
@@ -453,6 +464,9 @@ int main(int argc, char* argv[])
...
@@ -453,6 +464,9 @@ int main(int argc, char* argv[])
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
using
ConvStrides
=
Sequence
<
2
,
2
>
;
using
ConvDilations
=
Sequence
<
1
,
1
>
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif 0
#elif 0
...
@@ -583,7 +597,7 @@ int main(int argc, char* argv[])
...
@@ -583,7 +597,7 @@ int main(int argc, char* argv[])
auto
in_nchw_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
in_nchw_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
N
,
C
,
HI
,
WI
>
{});
auto
wei_kcyx_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
K
,
C
,
Y
,
X
>
{});
auto
wei_kcyx_desc
=
make_ConstantTensorDescriptor_packed
(
Sequence
<
K
,
C
,
Y
,
X
>
{});
auto
out_nkhw_desc
=
get_convolution_with_padding_output_default_4d_tensor_descriptor
(
auto
out_nkhw_desc
=
get_convolution_with_padding_output_default_4d_tensor_descriptor
(
in_nchw_desc
,
wei_kcyx_desc
,
lower_pads
,
upper_pads
);
in_nchw_desc
,
wei_kcyx_desc
,
ConvStrides
{},
ConvDilations
{},
lower_pads
,
upper_pads
);
ostream_ConstantTensorDescriptor
(
in_nchw_desc
,
std
::
cout
<<
"in_nchw_desc: "
);
ostream_ConstantTensorDescriptor
(
in_nchw_desc
,
std
::
cout
<<
"in_nchw_desc: "
);
ostream_ConstantTensorDescriptor
(
wei_kcyx_desc
,
std
::
cout
<<
"wei_kcyx_desc: "
);
ostream_ConstantTensorDescriptor
(
wei_kcyx_desc
,
std
::
cout
<<
"wei_kcyx_desc: "
);
...
@@ -645,9 +659,17 @@ int main(int argc, char* argv[])
...
@@ -645,9 +659,17 @@ int main(int argc, char* argv[])
#elif 1
#elif 1
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
#endif
#endif
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
ConvStrides
{},
ConvDilations
{},
nrepeat
);
#elif
1
#elif
0
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded
(
in_nchw_desc
,
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded
(
in_nchw_desc
,
in_nchw
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx_desc
,
...
@@ -662,14 +684,21 @@ int main(int argc, char* argv[])
...
@@ -662,14 +684,21 @@ int main(int argc, char* argv[])
if
(
do_verification
)
if
(
do_verification
)
{
{
#if 1
#if 1
if
(
Y
==
3
&&
X
==
3
)
if
(
Y
==
3
&&
X
==
3
&&
ConvStrides
{}[
0
]
==
1
&&
ConvStrides
{}[
1
]
==
1
&&
ConvDilations
{}[
0
]
==
1
&&
ConvDilations
{}[
1
]
==
1
)
{
{
host_winograd_3x3_convolution
(
in_nchw
,
wei_kcyx
,
out_nkhw_host
,
lower_pads
,
upper_pads
);
host_winograd_3x3_convolution
(
in_nchw
,
wei_kcyx
,
out_nkhw_host
,
lower_pads
,
upper_pads
);
}
}
else
else
#endif
#endif
{
{
host_direct_convolution
(
in_nchw
,
wei_kcyx
,
out_nkhw_host
,
lower_pads
,
upper_pads
);
host_direct_convolution
(
in_nchw
,
wei_kcyx
,
out_nkhw_host
,
ConvStrides
{},
ConvDilations
{},
lower_pads
,
upper_pads
);
}
}
check_error
(
out_nkhw_host
,
out_nkhw_device
);
check_error
(
out_nkhw_host
,
out_nkhw_device
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment