Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cee6c981
Commit
cee6c981
authored
Oct 03, 2020
by
Chao Liu
Browse files
adding dynamic col2im
parent
506df423
Changes
3
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
736 deletions
+15
-736
composable_kernel/include/kernel_algorithm/dummy_dynamic_transform_v2.hpp
...l/include/kernel_algorithm/dummy_dynamic_transform_v2.hpp
+0
-541
driver/include/device_dummy_dynamic_transform_v2.hpp
driver/include/device_dummy_dynamic_transform_v2.hpp
+0
-195
driver/src/col2im_driver.cpp
driver/src/col2im_driver.cpp
+15
-0
No files found.
composable_kernel/include/kernel_algorithm/dummy_dynamic_transform_v2.hpp
deleted
100644 → 0
View file @
506df423
This diff is collapsed.
Click to expand it.
driver/include/device_dummy_dynamic_transform_v2.hpp
deleted
100644 → 0
View file @
506df423
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "dummy_dynamic_transform_v2.hpp"
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
,
class
ConvStrides
,
class
ConvDilations
,
class
InLeftPads
,
class
InRightPads
>
void
device_dummy_dynamic_transform_v2
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
ConvStrides
,
ConvDilations
,
InLeftPads
,
InRightPads
,
ck
::
index_t
nrepeat
)
{
using
namespace
ck
;
using
TDevice
=
typename
conditional
<
is_same
<
half_float
::
half
,
T
>::
value
,
half_t
,
T
>::
type
;
const
auto
in_nchw_desc
=
make_dynamic_native_tensor_descriptor
<
4
>
(
to_multi_index
(
InDesc
::
GetLengths
()),
to_multi_index
(
InDesc
::
GetStrides
()));
const
auto
wei_kcyx_desc
=
make_dynamic_native_tensor_descriptor
<
4
>
(
to_multi_index
(
WeiDesc
::
GetLengths
()),
to_multi_index
(
WeiDesc
::
GetStrides
()));
const
auto
out_nkhw_desc
=
make_dynamic_native_tensor_descriptor
<
4
>
(
to_multi_index
(
OutDesc
::
GetLengths
()),
to_multi_index
(
OutDesc
::
GetStrides
()));
const
auto
conv_strides
=
to_multi_index
(
ConvStrides
{});
const
auto
conv_dilations
=
to_multi_index
(
ConvDilations
{});
const
auto
in_left_pads
=
to_multi_index
(
InLeftPads
{});
const
auto
in_right_pads
=
to_multi_index
(
InRightPads
{});
const
auto
tensor_descs
=
map_convolution_into_gemm_fwd_v4r4
(
wei_kcyx_desc
,
in_nchw_desc
,
out_nkhw_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
);
const
auto
in_gemmk_gemmn_gemmkpack_global_desc
=
tensor_descs
.
At
(
Number
<
0
>
{});
// test on cpu
{
auto
in_gemmk_gemmn_gemmkpack_coord
=
make_dynamic_tensor_coordinate
(
in_gemmk_gemmn_gemmkpack_global_desc
,
make_multi_index
(
0
,
0
,
0
));
const
auto
in_gemmk_gemmn_gemmkpack_coord_step_0_0_1
=
make_dynamic_tensor_coordinate_step
(
in_gemmk_gemmn_gemmkpack_global_desc
,
make_multi_index
(
0
,
0
,
1
));
print_array_v2
(
"do_tansforms 0 0 1: "
,
in_gemmk_gemmn_gemmkpack_coord_step_0_0_1
.
do_transforms_
);
for
(
index_t
iter
=
0
;
iter
<
10
;
++
iter
)
{
printf
(
"iter %d
\n
"
,
iter
);
print_array_v2
(
"idx: "
,
in_gemmk_gemmn_gemmkpack_coord
.
GetIndex
());
print_array_v2
(
"hidden idx: "
,
in_gemmk_gemmn_gemmkpack_coord
.
GetHiddenIndex
());
printf
(
"offset: %d
\n
"
,
in_gemmk_gemmn_gemmkpack_coord
.
GetOffset
());
printf
(
"
\n
"
);
move_dynamic_tensor_coordinate
(
in_gemmk_gemmn_gemmkpack_global_desc
,
in_gemmk_gemmn_gemmkpack_coord
,
in_gemmk_gemmn_gemmkpack_coord_step_0_0_1
);
}
}
{
auto
in_gemmk_gemmn_gemmkpack_coord
=
make_dynamic_tensor_coordinate
(
in_gemmk_gemmn_gemmkpack_global_desc
,
make_multi_index
(
0
,
0
,
0
));
const
auto
in_gemmk_gemmn_gemmkpack_coord_step_0_1_0
=
make_dynamic_tensor_coordinate_step
(
in_gemmk_gemmn_gemmkpack_global_desc
,
make_multi_index
(
0
,
1
,
0
));
print_array_v2
(
"do_tansforms 0 1 0: "
,
in_gemmk_gemmn_gemmkpack_coord_step_0_1_0
.
do_transforms_
);
for
(
index_t
iter
=
0
;
iter
<
10
;
++
iter
)
{
printf
(
"iter %d
\n
"
,
iter
);
print_array_v2
(
"idx: "
,
in_gemmk_gemmn_gemmkpack_coord
.
GetIndex
());
print_array_v2
(
"hidden idx: "
,
in_gemmk_gemmn_gemmkpack_coord
.
GetHiddenIndex
());
printf
(
"offset: %d
\n
"
,
in_gemmk_gemmn_gemmkpack_coord
.
GetOffset
());
printf
(
"
\n
"
);
move_dynamic_tensor_coordinate
(
in_gemmk_gemmn_gemmkpack_global_desc
,
in_gemmk_gemmn_gemmkpack_coord
,
in_gemmk_gemmn_gemmkpack_coord_step_0_1_0
);
}
}
{
auto
in_gemmk_gemmn_gemmkpack_coord
=
make_dynamic_tensor_coordinate
(
in_gemmk_gemmn_gemmkpack_global_desc
,
make_multi_index
(
0
,
0
,
0
));
const
auto
in_gemmk_gemmn_gemmkpack_coord_step_1_0_0
=
make_dynamic_tensor_coordinate_step
(
in_gemmk_gemmn_gemmkpack_global_desc
,
make_multi_index
(
1
,
0
,
0
));
print_array_v2
(
"do_tansforms 1 0 0: "
,
in_gemmk_gemmn_gemmkpack_coord_step_1_0_0
.
do_transforms_
);
for
(
index_t
iter
=
0
;
iter
<
10
;
++
iter
)
{
printf
(
"iter %d
\n
"
,
iter
);
print_array_v2
(
"idx: "
,
in_gemmk_gemmn_gemmkpack_coord
.
GetIndex
());
print_array_v2
(
"hidden idx: "
,
in_gemmk_gemmn_gemmkpack_coord
.
GetHiddenIndex
());
printf
(
"offset: %d
\n
"
,
in_gemmk_gemmn_gemmkpack_coord
.
GetOffset
());
printf
(
"
\n
"
);
move_dynamic_tensor_coordinate
(
in_gemmk_gemmn_gemmkpack_global_desc
,
in_gemmk_gemmn_gemmkpack_coord
,
in_gemmk_gemmn_gemmkpack_coord_step_1_0_0
);
}
}
std
::
size_t
data_sz
=
sizeof
(
T
);
DeviceMem
in_nchw_device_buf
(
data_sz
*
in_nchw
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_kcyx_device_buf
(
data_sz
*
wei_kcyx
.
mDesc
.
GetElementSpace
());
DeviceMem
out_nkhw_device_buf
(
data_sz
*
out_nkhw
.
mDesc
.
GetElementSpace
());
in_nchw_device_buf
.
ToDevice
(
in_nchw
.
mData
.
data
());
wei_kcyx_device_buf
.
ToDevice
(
wei_kcyx
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
GridSize
=
1
;
printf
(
"%s: BlockSize %u, GridSize %u
\n
"
,
__func__
,
BlockSize
,
GridSize
);
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
#if 0
launch_kernel(run_gridwise_operation<DummyDynamicTransform_v2_1<BlockSize>,
index_t* const,
float* const,
float* const,
const decltype(wei_kcyx_desc),
const decltype(in_nchw_desc),
const decltype(out_nkhw_desc),
const MultiIndex<2>,
const MultiIndex<2>,
const MultiIndex<2>,
const MultiIndex<2>>,
dim3(GridSize),
dim3(BlockSize),
0,
0,
static_cast<index_t*>(wei_kcyx_device_buf.GetDeviceBuffer()),
static_cast<float*>(in_nchw_device_buf.GetDeviceBuffer()),
static_cast<float*>(out_nkhw_device_buf.GetDeviceBuffer()),
wei_kcyx_desc,
in_nchw_desc,
out_nkhw_desc,
conv_strides,
conv_dilations,
in_left_pads,
in_right_pads);
#else
launch_kernel
(
run_gridwise_operation
<
DummyDynamicTransform_v2_fwd_v4r4
<
BlockSize
>
,
index_t
*
const
,
float
*
const
,
float
*
const
,
const
decltype
(
in_gemmk_gemmn_gemmkpack_global_desc
)
>
,
dim3
(
GridSize
),
dim3
(
BlockSize
),
0
,
0
,
static_cast
<
index_t
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
float
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
float
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()),
in_gemmk_gemmn_gemmkpack_global_desc
);
#endif
}
}
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
}
driver/src/col2im_driver.cpp
View file @
cee6c981
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "device_tensor.hpp"
#include "device_tensor.hpp"
#include "host_col2im.hpp"
#include "host_col2im.hpp"
#include "device_col2im_eb_nchw.hpp"
#include "device_col2im_eb_nchw.hpp"
#include "device_dynamic_col2im_eb_nchw.hpp"
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
...
@@ -349,6 +350,7 @@ int main(int argc, char* argv[])
...
@@ -349,6 +350,7 @@ int main(int argc, char* argv[])
#endif
#endif
}
}
#if 0
device_col2im_eb_nchw(col_eb_desc,
device_col2im_eb_nchw(col_eb_desc,
col_eb,
col_eb,
img_nchw_desc,
img_nchw_desc,
...
@@ -360,6 +362,19 @@ int main(int argc, char* argv[])
...
@@ -360,6 +362,19 @@ int main(int argc, char* argv[])
LeftPads{},
LeftPads{},
RightPads{},
RightPads{},
nrepeat);
nrepeat);
#else
device_dynamic_col2im_eb_nchw
(
col_eb_desc
,
col_eb
,
img_nchw_desc
,
img_nchw_device
,
FilterSizes
{},
OutputSizes
{},
ConvStrides
{},
ConvDilations
{},
LeftPads
{},
RightPads
{},
nrepeat
);
#endif
if
(
do_verification
)
if
(
do_verification
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment