Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8ce8f734
Commit
8ce8f734
authored
Feb 07, 2021
by
Chao Liu
Browse files
pass tensor descriptor from host to device by reference, pointer and void*
parent
e1eea81a
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
537 additions
and
254 deletions
+537
-254
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+455
-202
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp
...lude/tensor_description/dynamic_multi_index_transform.hpp
+4
-1
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
...kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
+63
-16
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
+10
-30
driver/src/conv_driver.cpp
driver/src/conv_driver.cpp
+5
-5
No files found.
composable_kernel/include/driver/driver_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
8ce8f734
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_description/dynamic_multi_index_transform.hpp
View file @
8ce8f734
...
...
@@ -949,7 +949,10 @@ struct DynamicFreeze
__host__
__device__
constexpr
DynamicFreeze
()
=
default
;
__host__
__device__
constexpr
DynamicFreeze
(
const
index_t
&
low_idx
)
:
low_idx_
{
low_idx
}
{}
__host__
__device__
constexpr
DynamicFreeze
(
const
index_t
&
low_idx
)
:
low_idx_
{
make_multi_index
(
low_idx
)}
{
}
__host__
__device__
static
constexpr
index_t
GetNumOfLowerDimension
()
{
return
1
;
}
...
...
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
View file @
8ce8f734
...
...
@@ -16,6 +16,9 @@ template <index_t BlockSize,
typename
Float
,
typename
AccFloat
,
InMemoryDataOperation
CGlobalMemoryDataOperation
,
typename
AGlobalDesc
,
typename
BGlobalDesc
,
typename
CGlobalDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
...
...
@@ -74,16 +77,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
return
2
*
(
a_block_space_size
+
b_block_space_size
)
*
sizeof
(
Float
);
}
template
<
typename
...
ADesc
,
typename
...
BDesc
,
typename
...
CDesc
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
DynamicTensorDescriptor
<
ADesc
...
>&
a_k_m_global_desc
,
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
&
a_k_m_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
DynamicTensorDescriptor
<
BDesc
...
>
&
b_k_n_global_desc
,
const
BGlobalDesc
&
b_k_n_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
DynamicTensorDescriptor
<
CDesc
...
>
&
c_m0_m1_n0_n1_global_desc
,
const
CGlobalDesc
&
c_m0_m1_n0_n1_global_desc
,
Float
*
__restrict__
p_c_global
,
Float
*
__restrict__
p_shared_block
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
...
...
@@ -466,16 +465,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
}
}
template
<
typename
...
ADesc
,
typename
...
BDesc
,
typename
...
CDesc
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
DynamicTensorDescriptor
<
ADesc
...
>&
a_k_m_global_desc
,
// pass tensor descriptor by reference
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
&
a_k_m_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
DynamicTensorDescriptor
<
BDesc
...
>
&
b_k_n_global_desc
,
const
BGlobalDesc
&
b_k_n_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
DynamicTensorDescriptor
<
CDesc
...
>
&
c_m0_m1_n0_n1_global_desc
,
const
CGlobalDesc
&
c_m0_m1_n0_n1_global_desc
,
Float
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
...
...
@@ -494,6 +490,57 @@ struct GridwiseDynamicGemm_km_kn_mn_v1
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
// pass tensor descriptors by their pointers
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
AGlobalDesc
*
p_a_k_m_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
BGlobalDesc
*
p_b_k_n_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
CGlobalDesc
*
p_c_m0_m1_n0_n1_global_desc
,
Float
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
const
auto
a_k_m_global_desc
=
*
p_a_k_m_global_desc
;
const
auto
b_k_n_global_desc
=
*
p_b_k_n_global_desc
;
const
auto
c_m0_m1_n0_n1_global_desc
=
*
p_c_m0_m1_n0_n1_global_desc
;
Run
(
a_k_m_global_desc
,
p_a_global
,
b_k_n_global_desc
,
p_b_global
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
// pass tensor descriptors by void*
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
void
*
p_a_k_m_global_desc
,
const
Float
*
__restrict__
p_a_global
,
const
void
*
p_b_k_n_global_desc
,
const
Float
*
__restrict__
p_b_global
,
const
void
*
p_c_m0_m1_n0_n1_global_desc
,
Float
*
__restrict__
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
const
auto
a_k_m_global_desc
=
*
reinterpret_cast
<
const
AGlobalDesc
*>
(
p_a_k_m_global_desc
);
const
auto
b_k_n_global_desc
=
*
reinterpret_cast
<
const
BGlobalDesc
*>
(
p_b_k_n_global_desc
);
const
auto
c_m0_m1_n0_n1_global_desc
=
*
reinterpret_cast
<
const
CGlobalDesc
*>
(
p_c_m0_m1_n0_n1_global_desc
);
Run
(
a_k_m_global_desc
,
p_a_global
,
b_k_n_global_desc
,
p_b_global
,
c_m0_m1_n0_n1_global_desc
,
p_c_global
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
};
}
// namespace ck
...
...
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp
View file @
8ce8f734
...
...
@@ -263,36 +263,16 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(InDesc
GemmBBlockTransferDstScalarPerVector_GemmN
,
GemmCThreadTransferDstScalarPerVector_GemmN1
>
{};
for
(
index_t
i
=
0
;
i
<
5
;
++
i
)
{
std
::
cout
<<
"Start running "
<<
nrepeat
<<
" times..."
<<
std
::
endl
;
KernelTimer
timer
;
timer
.
Start
();
for
(
index_t
j
=
0
;
j
<
nrepeat
;
++
j
)
{
conv_driver
.
Run
(
wei_k_c_y_x_desc
,
in_n_c_hi_wi_desc
,
out_n_k_ho_wo_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
,
static_cast
<
TDevice
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TDevice
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TDevice
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
}
timer
.
End
();
float
ave_time
=
timer
.
GetElapsedTime
()
/
nrepeat
;
float
perf
=
(
float
)
calculate_convolution_flops
(
InDesc
{},
WeiDesc
{},
OutDesc
{})
/
(
std
::
size_t
(
1000
)
*
1000
*
1000
)
/
ave_time
;
std
::
cout
<<
"Average time : "
<<
ave_time
<<
" ms, "
<<
perf
<<
" TFlop/s"
<<
std
::
endl
;
}
conv_driver
.
Run
(
wei_k_c_y_x_desc
,
in_n_c_hi_wi_desc
,
out_n_k_ho_wo_desc
,
conv_strides
,
conv_dilations
,
in_left_pads
,
in_right_pads
,
static_cast
<
TDevice
*>
(
wei_kcyx_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TDevice
*>
(
in_nchw_device_buf
.
GetDeviceBuffer
()),
static_cast
<
TDevice
*>
(
out_nkhw_device_buf
.
GetDeviceBuffer
()));
out_nkhw_device_buf
.
FromDevice
(
out_nkhw
.
mData
.
data
());
}
driver/src/conv_driver.cpp
View file @
8ce8f734
...
...
@@ -11,12 +11,12 @@
#include "conv_common.hpp"
#include "host_conv.hpp"
#include "device_tensor.hpp"
//
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
//
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
//
#include "device_dummy_static_transform.hpp"
//
#include "device_dummy_dynamic_transform_v1.hpp"
//
#include "device_dummy_dynamic_transform.hpp"
#include "device_dummy_static_transform.hpp"
#include "device_dummy_dynamic_transform_v1.hpp"
#include "device_dummy_dynamic_transform.hpp"
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment