Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e090e72a
Commit
e090e72a
authored
May 24, 2023
by
Po-Yen, Chen
Browse files
Add K0 into KernelArgument
parent
66a297cd
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
90 additions
and
87 deletions
+90
-87
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
...pu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
+90
-87
No files found.
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
View file @
e090e72a
...
...
@@ -77,6 +77,7 @@ struct KernelArgument
std
::
array
<
index_t
,
Dim
>
input_left_pads
;
std
::
array
<
index_t
,
Dim
>
input_right_pads
;
std
::
array
<
index_t
,
Dim
>
tildes
;
index_t
K0
;
KernelArgument
(
index_t
N_
,
index_t
K_
,
...
...
@@ -88,8 +89,9 @@ struct KernelArgument
const
std
::
vector
<
index_t
>&
conv_filter_dilations_
,
const
std
::
vector
<
index_t
>&
input_left_pads_
,
const
std
::
vector
<
index_t
>&
input_right_pads_
,
const
std
::
vector
<
index_t
>&
tildes_
)
:
N
{
N_
},
K
{
K_
},
C
{
C_
}
const
std
::
vector
<
index_t
>&
tildes_
,
index_t
K0_
)
:
N
{
N_
},
K
{
K_
},
C
{
C_
},
K0
{
K0_
}
{
#if defined(PP_COPY_MEMBER_VALUES)
#error "PP_COPY_MEMBER_VALUES macro was already defined"
...
...
@@ -213,8 +215,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
index_t
ConvStrideW
=
karg
.
conv_filter_strides
[
0
];
const
index_t
ConvDilationW
=
karg
.
conv_filter_dilations
[
0
];
const
auto
K0
=
karg
.
K
/
K1
;
const
auto
in_n_wi_c_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
karg
.
N
,
Wi
,
karg
.
C
));
...
...
@@ -225,14 +225,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
karg
.
N
*
Wo
,
karg
.
K
)),
make_tuple
(
make_pass_through_transform
(
karg
.
N
*
Wo
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
// B: weight tensor
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
karg
.
K
,
karg
.
C
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
)),
make_pass_through_transform
(
karg
.
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
...
@@ -310,13 +310,13 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_tuple
(
make_pass_through_transform
(
karg
.
N
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
,
4
>
{}));
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_xdotslice_wtildeslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
XDotSlice
,
K0
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
XDotSlice
,
karg
.
K0
)),
make_merge_transform
(
make_tuple
(
karg
.
N
,
WTildeSlice
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
1
,
3
>
{},
Sequence
<
0
,
2
>
{},
Sequence
<
4
>
{}),
...
...
@@ -334,7 +334,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
auto
wei_k0_k1_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
)),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_xtilde
),
make_pass_through_transform
(
karg
.
C
)),
...
...
@@ -343,7 +343,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
XDotSlice
,
K0
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
XDotSlice
,
karg
.
K0
)),
make_pass_through_transform
(
karg
.
C
),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
2
,
0
>
{},
Sequence
<
3
>
{},
Sequence
<
1
>
{}),
...
...
@@ -418,8 +418,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
index_t
ConvDilationH
=
karg
.
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
karg
.
conv_filter_dilations
[
1
];
const
auto
K0
=
karg
.
K
/
K1
;
const
auto
out_n_ho_wo_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
karg
.
N
,
Ho
,
Wo
,
karg
.
K
));
const
auto
wei_k_y_x_c_grid_desc
=
...
...
@@ -434,14 +432,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
karg
.
N
*
Ho
*
Wo
,
karg
.
K
)),
make_tuple
(
make_pass_through_transform
(
karg
.
N
*
Ho
*
Wo
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
// B: weight tensor
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
karg
.
K
,
karg
.
C
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
)),
make_pass_through_transform
(
karg
.
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
...
@@ -533,7 +531,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
...
...
@@ -549,7 +547,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
karg
.
K0
)),
make_merge_transform
(
make_tuple
(
karg
.
N
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
6
>
{}),
...
...
@@ -567,9 +565,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
const
auto
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
)),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_ytilde
),
...
...
@@ -590,7 +588,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K0
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
karg
.
K0
)),
make_pass_through_transform
(
karg
.
C
),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
2
,
3
,
0
>
{},
Sequence
<
4
>
{},
Sequence
<
1
>
{}),
...
...
@@ -688,8 +686,6 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
index_t
ConvDilationH
=
karg
.
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
karg
.
conv_filter_dilations
[
2
];
const
auto
K0
=
karg
.
K
/
K1
;
const
auto
out_n_do_ho_wo_k_grid_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
karg
.
N
,
Do
,
Ho
,
Wo
,
karg
.
K
));
const
auto
wei_k_z_y_x_c_grid_desc
=
...
...
@@ -704,14 +700,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
karg
.
N
*
Do
*
Ho
*
Wo
,
karg
.
K
)),
make_tuple
(
make_pass_through_transform
(
karg
.
N
*
Do
*
Ho
*
Wo
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
// B: weight tensor
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
karg
.
K
,
karg
.
C
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
)),
make_pass_through_transform
(
karg
.
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
...
...
@@ -839,7 +835,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_unmerge_transform
(
make_tuple
(
K0
,
K1
))),
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
...
...
@@ -860,7 +856,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
auto
out_gemmk0_gemmm_gemmk1_grid_desc
=
transform_tensor_descriptor
(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K0
)),
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
karg
.
K0
)),
make_merge_transform
(
make_tuple
(
karg
.
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
8
>
{}),
...
...
@@ -888,8 +884,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
Sequence
<
7
>
{}));
const
auto
wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1
)),
transform_tensor_descriptor
(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
karg
.
K0
,
K1
)),
make_slice_transform
(
ZDot
,
I0
,
ZDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
...
...
@@ -916,7 +913,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
auto
wei_gemmk0_gemmn_gemmk1_grid_desc
=
transform_tensor_descriptor
(
wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K0
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
karg
.
K0
)),
make_pass_through_transform
(
karg
.
C
),
make_pass_through_transform
(
K1
)),
make_tuple
(
Sequence
<
2
,
3
,
4
,
0
>
{},
Sequence
<
5
>
{},
Sequence
<
1
>
{}),
...
...
@@ -1000,14 +998,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
static
auto
GetDummyABCGridDesc
()
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
1
>
(
detail
::
KernelArgument
<
1
>
(
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
0
}));
detail
::
KernelArgument
<
1
>
(
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
0
}
,
1
));
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
GetDummyABCGridDesc
()
{
return
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
2
>
(
detail
::
KernelArgument
<
2
>
(
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
}));
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
}
,
1
));
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
...
...
@@ -1024,7 +1022,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
}));
{
0
,
0
,
0
},
1
));
}
// GridwiseGemm
...
...
@@ -1140,7 +1139,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_xtilde
}));
{
i_xtilde
},
GridwiseGemm
::
CalculateK0
(
Conv_K_
)));
grid_desc_container_
.
push_back
(
descs
);
}
}
...
...
@@ -1176,7 +1176,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
detail
::
KernelArgument
<
NDimSpatial
>
(
Conv_N_
,
detail
::
KernelArgument
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
...
...
@@ -1186,7 +1187,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ytilde
,
i_xtilde
}));
{
i_ytilde
,
i_xtilde
},
GridwiseGemm
::
CalculateK0
(
Conv_K_
)));
grid_desc_container_
.
push_back
(
descs
);
}
}
...
...
@@ -1242,7 +1244,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ztilde
,
i_ytilde
,
i_xtilde
}));
{
i_ztilde
,
i_ytilde
,
i_xtilde
},
GridwiseGemm
::
CalculateK0
(
Conv_K_
)));
grid_desc_container_
.
push_back
(
descs
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment