Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b53926e9
Commit
b53926e9
authored
Jul 23, 2021
by
Jing Zhang
Browse files
activ_type argument
parent
fe427fd1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
39 additions
and
30 deletions
+39
-30
composable_kernel/include/driver/driver_static_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
...tion_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
+11
-1
composable_kernel/include/tensor_operation/gridwise_static_gemm_v2.hpp
...rnel/include/tensor_operation/gridwise_static_gemm_v2.hpp
+5
-3
host/driver_offline/conv_fwd_driver_offline.cpp
host/driver_offline/conv_fwd_driver_offline.cpp
+15
-13
host/driver_offline/include/device_static_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
...convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
+2
-0
host/host_tensor/include/host_conv.hpp
host/host_tensor/include/host_conv.hpp
+6
-13
No files found.
composable_kernel/include/driver/driver_static_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp
View file @
b53926e9
...
@@ -35,7 +35,8 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -35,7 +35,8 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
typename
ConvStrides
,
typename
ConvStrides
,
typename
ConvDilations
,
typename
ConvDilations
,
typename
InLeftPads
,
typename
InLeftPads
,
typename
InRightPads
>
typename
InRightPads
,
index_t
activ_type
>
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
__host__
void
Run
(
const
DynamicTensorDescriptor
<
Wei
...
>&
wei_k_c_y_x_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
In
...
>&
in_n_c_hi_wi_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k0_ho_wo_k1_global_desc
,
const
DynamicTensorDescriptor
<
Out
...
>&
out_n_k0_ho_wo_k1_global_desc
,
...
@@ -43,6 +44,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -43,6 +44,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads_
,
const
InRightPads
&
in_right_pads_
,
Number
<
activ_type
>
,
const
FloatAB
*
__restrict__
p_wei_global
,
const
FloatAB
*
__restrict__
p_wei_global
,
const
FloatAB
*
__restrict__
p_in_global
,
const
FloatAB
*
__restrict__
p_in_global
,
FloatC
*
__restrict__
p_out_global
)
const
FloatC
*
__restrict__
p_out_global
)
const
...
@@ -297,6 +299,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -297,6 +299,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
FloatAB
*
,
const
FloatAB
*
,
const
FloatAB
*
,
const
FloatAB
*
,
FloatC
*
,
FloatC
*
,
Number
<
activ_type
>
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>>
;
integral_constant
<
bool
,
true
>>
;
...
@@ -308,6 +311,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -308,6 +311,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
p_in_global
,
p_in_global
,
p_out_global
,
p_out_global
,
Number
<
activ_type
>
{},
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{});
integral_constant
<
bool
,
true
>
{});
}
}
...
@@ -317,6 +321,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -317,6 +321,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
FloatAB
*
,
const
FloatAB
*
,
const
FloatAB
*
,
const
FloatAB
*
,
FloatC
*
,
FloatC
*
,
Number
<
activ_type
>
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
true
>
,
integral_constant
<
bool
,
false
>>
;
integral_constant
<
bool
,
false
>>
;
...
@@ -328,6 +333,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -328,6 +333,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
p_in_global
,
p_in_global
,
p_out_global
,
p_out_global
,
Number
<
activ_type
>
{},
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
true
>
{},
integral_constant
<
bool
,
false
>
{});
integral_constant
<
bool
,
false
>
{});
}
}
...
@@ -337,6 +343,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -337,6 +343,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
FloatAB
*
,
const
FloatAB
*
,
const
FloatAB
*
,
const
FloatAB
*
,
FloatC
*
,
FloatC
*
,
Number
<
activ_type
>
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
true
>>
;
integral_constant
<
bool
,
true
>>
;
...
@@ -348,6 +355,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -348,6 +355,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
p_in_global
,
p_in_global
,
p_out_global
,
p_out_global
,
Number
<
activ_type
>
{},
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
true
>
{});
integral_constant
<
bool
,
true
>
{});
}
}
...
@@ -357,6 +365,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -357,6 +365,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
const
FloatAB
*
,
const
FloatAB
*
,
const
FloatAB
*
,
const
FloatAB
*
,
FloatC
*
,
FloatC
*
,
Number
<
activ_type
>
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>
,
integral_constant
<
bool
,
false
>>
;
integral_constant
<
bool
,
false
>>
;
...
@@ -368,6 +377,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
...
@@ -368,6 +377,7 @@ struct DriverStaticConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
p_wei_global
,
p_wei_global
,
p_in_global
,
p_in_global
,
p_out_global
,
p_out_global
,
Number
<
activ_type
>
{},
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{},
integral_constant
<
bool
,
false
>
{});
integral_constant
<
bool
,
false
>
{});
}
}
...
...
composable_kernel/include/tensor_operation/gridwise_static_gemm_v2.hpp
View file @
b53926e9
...
@@ -72,11 +72,12 @@ struct GridwiseStaticGemm_km_kn_mn_v3
...
@@ -72,11 +72,12 @@ struct GridwiseStaticGemm_km_kn_mn_v3
return
a_block_space_size
*
sizeof
(
FloatAB
);
return
a_block_space_size
*
sizeof
(
FloatAB
);
}
}
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
index_t
activ_type
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
FloatAB
*
__restrict__
p_a_global
,
__device__
void
Run
(
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatAB
*
__restrict__
p_b_global
,
FloatC
*
__restrict__
p_c_global
,
FloatC
*
__restrict__
p_c_global
,
FloatAB
*
__restrict__
p_shared_block
,
FloatAB
*
__restrict__
p_shared_block
,
Number
<
activ_type
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
{
...
@@ -348,7 +349,6 @@ struct GridwiseStaticGemm_km_kn_mn_v3
...
@@ -348,7 +349,6 @@ struct GridwiseStaticGemm_km_kn_mn_v3
// activ
// activ
{
{
constexpr
index_t
activ_type
=
2
;
static_for
<
0
,
c_k_n_ho_wo_thread_desc
.
GetElementSpaceSize
(),
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
c_k_n_ho_wo_thread_desc
.
GetElementSpaceSize
(),
1
>
{}([
&
](
auto
i
)
{
if
constexpr
(
activ_type
==
1
)
if
constexpr
(
activ_type
==
1
)
c_thread_buf
(
i
)
=
c_thread_buf
[
i
]
>=
0
?
c_thread_buf
[
i
]
:
0.0
;
c_thread_buf
(
i
)
=
c_thread_buf
[
i
]
>=
0
?
c_thread_buf
[
i
]
:
0.0
;
...
@@ -392,10 +392,11 @@ struct GridwiseStaticGemm_km_kn_mn_v3
...
@@ -392,10 +392,11 @@ struct GridwiseStaticGemm_km_kn_mn_v3
}
}
// pass tensor descriptor by reference
// pass tensor descriptor by reference
template
<
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
template
<
index_t
activ_type
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
__device__
void
Run
(
const
FloatAB
*
__restrict__
p_a_global
,
__device__
void
Run
(
const
FloatAB
*
__restrict__
p_a_global
,
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatAB
*
__restrict__
p_b_global
,
FloatC
*
__restrict__
p_c_global
,
FloatC
*
__restrict__
p_c_global
,
Number
<
activ_type
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
const
{
{
...
@@ -407,6 +408,7 @@ struct GridwiseStaticGemm_km_kn_mn_v3
...
@@ -407,6 +408,7 @@ struct GridwiseStaticGemm_km_kn_mn_v3
p_b_global
,
p_b_global
,
p_c_global
,
p_c_global
,
p_shared_block
,
p_shared_block
,
Number
<
activ_type
>
{},
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
}
}
...
...
host/driver_offline/conv_fwd_driver_offline.cpp
View file @
b53926e9
...
@@ -437,6 +437,8 @@ int main(int argc, char* argv[])
...
@@ -437,6 +437,8 @@ int main(int argc, char* argv[])
}
}
#endif
#endif
constexpr
ck
::
index_t
activ_type
=
2
;
#if USE_CONV_FWD_V5R1_NCHW
#if USE_CONV_FWD_V5R1_NCHW
if
(
algo
==
ConvForwardAlgo
::
V5R1NCHW
)
if
(
algo
==
ConvForwardAlgo
::
V5R1NCHW
)
{
{
...
@@ -452,17 +454,17 @@ int main(int argc, char* argv[])
...
@@ -452,17 +454,17 @@ int main(int argc, char* argv[])
#else
#else
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw
device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw
#endif
#endif
<
in_data_t
,
8
,
8
,
acc_data_t
,
out_data_t
>
(
tmp
[
I0
],
<
in_data_t
,
8
,
8
,
activ_type
,
acc_data_t
,
out_data_t
>
(
tmp
[
I0
],
tmp
[
I1
],
tmp
[
I1
],
tmp
[
I2
],
tmp
[
I2
],
tmp
[
I3
],
tmp
[
I3
],
tmp
[
I4
],
tmp
[
I4
],
tmp
[
I5
],
tmp
[
I5
],
tmp
[
I6
],
tmp
[
I6
],
in
,
in
,
wei
,
wei
,
out_device
,
out_device
,
nrepeat
);
nrepeat
);
}
}
#endif
#endif
...
@@ -529,8 +531,8 @@ int main(int argc, char* argv[])
...
@@ -529,8 +531,8 @@ int main(int argc, char* argv[])
make_tuple
(
conv_dilation_h
,
conv_dilation_w
),
make_tuple
(
conv_dilation_h
,
conv_dilation_w
),
make_tuple
(
in_left_pad_h
,
in_left_pad_w
),
make_tuple
(
in_left_pad_h
,
in_left_pad_w
),
make_tuple
(
in_right_pad_h
,
in_right_pad_w
),
make_tuple
(
in_right_pad_h
,
in_right_pad_w
),
layout
,
activ_type
,
ActivType_t
::
sigmoid
);
layout
);
check_error
(
out_host
,
out_device
);
check_error
(
out_host
,
out_device
);
...
...
host/driver_offline/include/device_static_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp
View file @
b53926e9
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
template
<
typename
TInWei
,
template
<
typename
TInWei
,
ck
::
index_t
InWeiVectorSize
,
ck
::
index_t
InWeiVectorSize
,
ck
::
index_t
OutVectorSize
,
ck
::
index_t
OutVectorSize
,
ck
::
index_t
activ_type
,
typename
TAcc
,
typename
TAcc
,
typename
TOut
,
typename
TOut
,
typename
InLengths
,
typename
InLengths
,
...
@@ -152,6 +153,7 @@ void device_static_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
...
@@ -152,6 +153,7 @@ void device_static_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
conv_dilations
,
conv_dilations
,
in_left_pads
,
in_left_pads
,
in_right_pads
,
in_right_pads
,
Number
<
activ_type
>
{},
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
wei_k_c0_y_x_c1_device_buf
.
GetDeviceBuffer
()),
wei_k_c0_y_x_c1_device_buf
.
GetDeviceBuffer
()),
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
static_cast
<
typename
vector_type
<
TInWei
,
InWeiVectorSize
>::
type
*>
(
...
...
host/host_tensor/include/host_conv.hpp
View file @
b53926e9
#pragma once
#pragma once
#include "host_tensor.hpp"
#include "host_tensor.hpp"
typedef
enum
{
passthrough
=
0
,
relu
,
sigmoid
}
ActivType_t
;
template
<
typename
TIn
,
template
<
typename
TIn
,
typename
TWei
,
typename
TWei
,
typename
TOut
,
typename
TOut
,
...
@@ -96,13 +89,13 @@ void host_direct_convolution(const Tensor<TIn>& in,
...
@@ -96,13 +89,13 @@ void host_direct_convolution(const Tensor<TIn>& in,
}
}
template
<
typename
T
>
template
<
typename
T
>
inline
auto
activ
(
T
v
,
const
ActivType
_t
activ_type
)
inline
auto
activ
(
T
v
,
const
ck
::
index
_t
activ_type
)
{
{
switch
(
activ_type
)
switch
(
activ_type
)
{
{
case
passthrough
:
return
v
;
case
0
:
return
v
;
case
relu
:
return
(
v
>=
0
?
v
:
0
);
case
1
:
return
(
v
>=
0
?
v
:
0
);
case
sigmoid
:
return
(
1
/
(
1
+
exp
(
-
v
)));
case
2
:
return
(
1
/
(
1
+
exp
(
-
v
)));
default:
throw
std
::
runtime_error
(
"unsupported activ type"
);
break
;
default:
throw
std
::
runtime_error
(
"unsupported activ type"
);
break
;
}
}
}
}
...
@@ -121,8 +114,8 @@ void host_direct_convolution_activ(const Tensor<TIn>& in,
...
@@ -121,8 +114,8 @@ void host_direct_convolution_activ(const Tensor<TIn>& in,
const
ConvDilations
&
conv_dilations
,
const
ConvDilations
&
conv_dilations
,
const
InLeftPads
&
in_left_pads
,
const
InLeftPads
&
in_left_pads
,
const
InRightPads
&
in_right_pads
,
const
InRightPads
&
in_right_pads
,
const
ConvTensorLayout
layout
=
ConvTensorLayout
::
NCHW
,
const
ck
::
index_t
activ_type
,
const
ActivType_t
activ_type
=
ActivType_t
::
passthrough
)
const
ConvTensorLayout
layout
=
ConvTensorLayout
::
NCHW
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment