Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c29dc4c5
Commit
c29dc4c5
authored
Dec 09, 2021
by
ltqin
Browse files
Merge branch 'develop' into conv_splitk_f32
parents
134af43b
fd3d907a
Changes
44
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
54 additions
and
20 deletions
+54
-20
host/driver_offline/src/conv_fwd_driver_offline.cpp
host/driver_offline/src/conv_fwd_driver_offline.cpp
+1
-1
host/host_tensor/include/host_gemm.hpp
host/host_tensor/include/host_gemm.hpp
+13
-4
profiler/include/profile_conv.hpp
profiler/include/profile_conv.hpp
+17
-4
profiler/include/profile_gemm.hpp
profiler/include/profile_gemm.hpp
+23
-11
No files found.
host/driver_offline/src/conv_fwd_driver_offline.cpp
View file @
c29dc4c5
...
...
@@ -18,7 +18,7 @@
#include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
#define USE_DYNAMIC_MODE 0
#define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V6R1_NCHW 0
...
...
host/host_tensor/include/host_gemm.hpp
View file @
c29dc4c5
#pragma once
#include "host_tensor.hpp"
template
<
typename
AType
,
typename
BType
,
typename
CType
>
template
<
typename
AType
,
typename
BType
,
typename
CType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
void
host_gemm_mk_kn_mn
(
const
Tensor
<
AType
>&
a_m_k
,
const
Tensor
<
BType
>&
b_k_n
,
Tensor
<
CType
>&
c_m_n
)
Tensor
<
CType
>&
c_m_n
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
)
{
auto
f_mk_kn_mn
=
[
&
](
auto
m
,
auto
n
)
{
const
int
K
=
a_m_k
.
mDesc
.
GetLengths
()[
1
];
...
...
@@ -13,10 +21,11 @@ void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v
+=
static_cast
<
const
double
>
(
a_m_k
(
m
,
k
))
*
static_cast
<
const
double
>
(
b_k_n
(
k
,
n
));
v
+=
static_cast
<
const
double
>
(
a_element_op
(
a_m_k
(
m
,
k
)))
*
static_cast
<
const
double
>
(
b_element_op
(
b_k_n
(
k
,
n
)));
}
c_m_n
(
m
,
n
)
=
v
;
c_m_n
(
m
,
n
)
=
c_element_op
(
v
)
;
};
make_ParallelTensorFunctor
(
f_mk_kn_mn
,
...
...
profiler/include/profile_conv.hpp
View file @
c29dc4c5
...
...
@@ -8,12 +8,17 @@
#include "device_tensor.hpp"
#include "device_conv.hpp"
#include "device_conv_instance.hpp"
#include "element_wise_operation.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv_instance
{
using
DeviceConvFwdNoOpPtr
=
DeviceConvFwdPtr
<
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
template
<
>
void
add_device_conv_fwd_instance
<
2
,
float
,
...
...
@@ -22,7 +27,7 @@ void add_device_conv_fwd_instance<2,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
>&
);
std
::
vector
<
DeviceConvFwdNoOpPtr
>&
);
template
<
>
void
add_device_conv_fwd_instance
<
2
,
...
...
@@ -32,7 +37,7 @@ void add_device_conv_fwd_instance<2,
ck
::
tensor_layout
::
convolution
::
NHWC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
(
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
>&
);
std
::
vector
<
DeviceConvFwdNoOpPtr
>&
);
}
// namespace device_conv_instance
}
// namespace device
...
...
@@ -133,8 +138,13 @@ void profile_conv(int do_verification,
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei_k_c_y_x
.
mData
.
data
());
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DeviceConvFwdNoOpPtr
=
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
<
PassThrough
,
PassThrough
,
PassThrough
>
;
// add device Conv instances
std
::
vector
<
ck
::
tensor_operation
::
device
::
DeviceConvFwdPtr
>
conv_ptrs
;
std
::
vector
<
DeviceConvFwdNoOpPtr
>
conv_ptrs
;
ck
::
tensor_operation
::
device
::
device_conv_instance
::
add_device_conv_fwd_instance
<
2
,
InDataType
,
...
...
@@ -170,7 +180,10 @@ void profile_conv(int do_verification,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
input_right_pads
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
auto
invoker_ptr
=
conv_ptr
->
MakeInvokerPointer
();
...
...
profiler/include/profile_gemm.hpp
View file @
c29dc4c5
This diff is collapsed.
Click to expand it.
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment