Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c7bf4232
Commit
c7bf4232
authored
Nov 08, 2022
by
letaoqin
Browse files
device transfer elementwiseop to gridwise
parent
f8aef548
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
52 additions
and
21 deletions
+52
-21
example/09_convnd_fwd/convnd_fwd_dl_multiple_d_fp16.cpp
example/09_convnd_fwd/convnd_fwd_dl_multiple_d_fp16.cpp
+2
-1
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
.../device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+41
-20
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
...tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
+9
-0
No files found.
example/09_convnd_fwd/convnd_fwd_dl_multiple_d_fp16.cpp
View file @
c7bf4232
...
...
@@ -17,7 +17,8 @@ using S = ck::Sequence<Is...>;
using
InElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
Relu
;
;
static
constexpr
auto
ConvSpec
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
View file @
c7bf4232
...
...
@@ -103,6 +103,9 @@ template <typename GridwiseGemm,
typename
ABDataType
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AGridDesc_K0_M0_M1_K1
,
typename
BGridDesc_K0_N0_N1_K1
,
typename
DsGridDesc_M0_M10_M11_N0_N10_N11
,
...
...
@@ -120,6 +123,9 @@ __global__ void
const
ABDataType
*
__restrict__
p_b_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
index_t
batch_count
,
const
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1
,
const
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1
,
...
...
@@ -160,6 +166,9 @@ __global__ void
p_ds_grid_grp
,
p_e_grid
+
c_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_k0_m0_m1_k1
,
b_grid_desc_k0_n0_n1_k1
,
ds_grid_desc_m0_m10_m11_n0_n10_n11
,
...
...
@@ -172,6 +181,9 @@ __global__ void
ignore
=
p_b_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
batch_count
;
ignore
=
a_grid_desc_k0_m0_m1_k1
;
ignore
=
b_grid_desc_k0_n0_n1_k1
;
...
...
@@ -212,10 +224,10 @@ template <index_t NDimSpatial,
typename
ALayout
,
typename
BLayout
,
typename
DsDataType
,
typename
C
Layout
,
typename
E
Layout
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDElementwiseOperation
,
typename
CDE
E
lementwiseOperation
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
GemmSpecialization
GemmSpec
,
index_t
BlockSize
,
...
...
@@ -250,14 +262,14 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
ALayout
,
BLayout
,
DsLayout
,
C
Layout
,
E
Layout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDElementwiseOperation
>
CDE
E
lementwiseOperation
>
{
using
DeviceOp
=
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
;
...
...
@@ -338,13 +350,13 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
template
<
typename
C
Lay
>
template
<
typename
E
Lay
>
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
C
Lay
>(
e_g_n_k_wos_lengths
,
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
E
Lay
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
const
auto
out_gemmm_gemmn_desc
=
...
...
@@ -373,7 +385,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
({},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
C
Layout
>
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
E
Layout
>
({},
{}))
>
;
// GridwiseGemm
using
GridwiseGemm
=
...
...
@@ -382,6 +394,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
AccDataType
,
DsLayout
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
...
...
@@ -447,7 +462,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDElementwiseOperation
&
cd_element_op
)
const
CDE
E
lementwiseOperation
&
cd
e
_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b
)},
p_ds_grid_
{},
...
...
@@ -466,7 +481,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
input_right_pads
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
C
Layout
>
(
e_g_n_k_wos_lengths
,
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
E
Layout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
a_grid_desc_k0_m0_m1_k1_
{},
b_grid_desc_k0_n0_n1_k1_
{},
...
...
@@ -476,7 +491,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
compute_ptr_offset_of_batch_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cd_element_op_
{
cd_element_op
},
cd
e
_element_op_
{
cd
e
_element_op
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
a_g_n_c_wis_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
...
...
@@ -570,7 +585,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDElementwiseOperation
cd_element_op_
;
CDE
E
lementwiseOperation
cd
e
_element_op_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
...
...
@@ -621,6 +636,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
ADataType
,
// TODO: distiguish A/B datatype
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_K0_M0_M1_K1
,
DeviceOp
::
BGridDesc_K0_N0_N1_K1
,
DeviceOp
::
DsGridDesc_M0_M10_M11_N0_N10_N11
,
...
...
@@ -639,6 +657,9 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_c_wis_lengths_
[
0
],
// Group count
arg
.
a_grid_desc_k0_m0_m1_k1_
,
arg
.
b_grid_desc_k0_n0_n1_k1_
,
...
...
@@ -796,11 +817,11 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
}
// check vector access of E
if
constexpr
(
is_same_v
<
C
Layout
,
ctc
::
G_NW_K
>
||
is_same_v
<
C
Layout
,
ctc
::
G_NHW_K
>
||
is_same_v
<
C
Layout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
C
Layout
,
ctc
::
GNWK
>
||
is_same_v
<
C
Layout
,
ctc
::
GNHWK
>
||
is_same_v
<
C
Layout
,
ctc
::
GNDHWK
>
||
is_same_v
<
C
Layout
,
ctc
::
NWGK
>
||
is_same_v
<
C
Layout
,
ctc
::
NHWGK
>
||
is_same_v
<
C
Layout
,
ctc
::
NDHWGK
>
)
if
constexpr
(
is_same_v
<
E
Layout
,
ctc
::
G_NW_K
>
||
is_same_v
<
E
Layout
,
ctc
::
G_NHW_K
>
||
is_same_v
<
E
Layout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
E
Layout
,
ctc
::
GNWK
>
||
is_same_v
<
E
Layout
,
ctc
::
GNHWK
>
||
is_same_v
<
E
Layout
,
ctc
::
GNDHWK
>
||
is_same_v
<
E
Layout
,
ctc
::
NWGK
>
||
is_same_v
<
E
Layout
,
ctc
::
NHWGK
>
||
is_same_v
<
E
Layout
,
ctc
::
NDHWGK
>
)
{
const
index_t
K
=
arg
.
e_g_n_k_wos_lengths_
[
2
];
...
...
@@ -842,7 +863,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDElementwiseOperation
&
cd_element_op
)
const
CDE
E
lementwiseOperation
&
cd
e
_element_op
)
{
return
Argument
{
p_a
,
p_b
,
...
...
@@ -862,7 +883,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
input_right_pads
,
a_element_op
,
b_element_op
,
cd_element_op
};
cd
e
_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -886,7 +907,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDElementwiseOperation
&
cd_element_op
)
override
const
CDE
E
lementwiseOperation
&
cd
e
_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
p_a
,
p_b
,
...
...
@@ -906,7 +927,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
input_right_pads
,
a_element_op
,
b_element_op
,
cd_element_op
);
cd
e
_element_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp
View file @
c7bf4232
...
...
@@ -22,6 +22,9 @@ template <index_t BlockSize,
typename
FloatAcc
,
typename
DsDataType
,
typename
FloatC
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
...
...
@@ -247,6 +250,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
DsGridPointer
p_ds_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatAB
*
__restrict__
p_shared_block
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
,
const
AGridDesc_K0_M0_M1_K1
&
a_grid_desc_k0_m0_m1_k1
,
const
BGridDesc_K0_N0_N1_K1
&
b_grid_desc_k0_n0_n1_k1
,
const
DsGridDesc_M0_M10_M11_N0_N10_N11
&
ds_grid_desc_m0_m10_m11_n0_n10_n11
,
...
...
@@ -257,6 +263,9 @@ struct GridwiseGemmDlMultipleD_km_kn_mn
{
ignore
=
p_ds_grid
;
ignore
=
ds_grid_desc_m0_m10_m11_n0_n10_n11
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_k0_m0_m1_k1
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment