Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b9f23971
Commit
b9f23971
authored
Dec 29, 2022
by
Rosty Geyyer
Browse files
Update the argument
parent
cb4b511a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
33 deletions
+37
-33
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
...impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
+37
-33
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
View file @
b9f23971
...
...
@@ -719,12 +719,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
,
InElementwiseOperation
in_element_op
,
WeiElementwiseOperation
wei_element_op
,
OutElementwiseOperation
out_element_op
)
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
:
p_a_grid_
{
p_out_grid
},
p_b_grid_
{
p_in_grid
},
p_c_grid_
{
p_wei_grid
},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
a_grid_desc_
kbatch_
k0_m_k1_
{},
b_grid_desc_
kbatch_
k0_n_k1_
{},
c_grid_desc_m_n_
{},
a_element_op_
{
out_element_op
},
b_element_op_
{
wei_element_op
},
...
...
@@ -739,10 +740,9 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
input_right_pads_
{
input_right_pads
},
k_batch_
{
split_k
}
{
k_batch_
=
1
;
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
N
,
...
...
@@ -757,12 +757,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
input_right_pads
,
k_batch_
);
a_grid_desc_k0_m_k1_
=
descs
[
I0
];
b_grid_desc_k0_n_k1_
=
descs
[
I1
];
a_grid_desc_
kbatch_
k0_m_k1_
=
descs
[
I0
];
b_grid_desc_
kbatch_
k0_n_k1_
=
descs
[
I1
];
c_grid_desc_m_n_
=
descs
[
I2
];
a_grid_desc_k0_m0_m1_k1_
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
a_grid_desc_k0_m_k1_
);
b_grid_desc_k0_n0_n1_k1_
=
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
b_grid_desc_k0_n_k1_
);
a_grid_desc_
kbatch_
k0_m0_m1_k1_
=
GridwiseGemm
::
MakeAGridDescriptor_K0_M0_M1_K1
(
a_grid_desc_
kbatch_
k0_m_k1_
);
b_grid_desc_
kbatch_
k0_n0_n1_k1_
=
GridwiseGemm
::
MakeBGridDescriptor_K0_N0_N1_K1
(
b_grid_desc_
kbatch_
k0_n_k1_
);
c_grid_desc_m0_m10_m11_n0_n10_n11_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_M10_M11_N0_N10_N11
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
);
}
...
...
@@ -771,12 +771,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
AGridDesc_K0_M_K1
a_grid_desc_
kbatch_
k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_
kbatch_
k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
AGridDesc_K0_M0_M1_K1
a_grid_desc_k0_m0_m1_k1_
;
BGridDesc_K0_N0_N1_K1
b_grid_desc_k0_n0_n1_k1_
;
AGridDesc_K0_M0_M1_K1
a_grid_desc_
kbatch_
k0_m0_m1_k1_
;
BGridDesc_K0_N0_N1_K1
b_grid_desc_
kbatch_
k0_n0_n1_k1_
;
CGridDesc_M0_M10_M11_N0_N10_N11
c_grid_desc_m0_m10_m11_n0_n10_n11_
;
DefaultBlock2CTileMap
block_2_ctile_map_
;
...
...
@@ -809,16 +809,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
void
ShowInfo
(
const
Argument
&
arg
)
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
...
...
@@ -832,16 +834,16 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
ShowInfo
(
arg
);
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
kbatch_
k0_m_k1_
,
arg
.
b_grid_desc_
kbatch_
k0_n_k1_
,
arg
.
c_grid_desc_m_n_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
"wrong! GridwiseGemm
GridwiseGemmDl_km_kn_mn_v1r3
has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Conv_G_
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
,
auto
has_double_tail_k_block_loop
)
{
...
...
@@ -867,13 +869,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m0_m1_k1_
,
arg
.
b_grid_desc_k0_n0_n1_k1_
,
arg
.
a_grid_desc_
kbatch_
k0_m0_m1_k1_
,
arg
.
b_grid_desc_
kbatch_
k0_n0_n1_k1_
,
arg
.
c_grid_desc_m0_m10_m11_n0_n10_n11_
,
arg
.
block_2_ctile_map_
);
};
const
auto
K0
=
arg
.
a_grid_desc_k0_m0_m1_k1_
.
GetLength
(
I
0
);
const
auto
K0
=
arg
.
a_grid_desc_
kbatch_
k0_m0_m1_k1_
.
GetLength
(
I
1
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K0
);
const
bool
has_double_tail_k_block_loop
=
GridwiseGemm
::
CalculateHasDoubleTailKBlockLoop
(
K0
);
...
...
@@ -985,8 +987,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
}
// Gridwise GEMM size
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_
kbatch_
k0_m_k1_
,
arg
.
b_grid_desc_
kbatch_
k0_n_k1_
,
arg
.
c_grid_desc_m_n_
);
}
...
...
@@ -1030,7 +1032,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
};
out_element_op
,
split_k
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
@@ -1071,7 +1074,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Dl
input_right_pads
,
in_element_op
,
wei_element_op
,
out_element_op
);
out_element_op
,
split_k
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment