Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
cd01db8b
Commit
cd01db8b
authored
May 24, 2023
by
Po-Yen, Chen
Browse files
Merge descriptors into one object
parent
a434991e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
59 additions
and
106 deletions
+59
-106
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
...pu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
+59
-106
No files found.
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
View file @
cd01db8b
...
@@ -151,17 +151,17 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -151,17 +151,17 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
index_t
N
,
ck
::
index_t
K
,
index_t
K
,
ck
::
index_t
C
,
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
tildes
)
std
::
vector
<
index_t
>
tildes
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -348,21 +348,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -348,21 +348,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
);
in_gemmm_gemmn_grid_desc
);
}
}
}
}
// function end
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
index_t
N
,
ck
::
index_t
K
,
index_t
K
,
ck
::
index_t
C
,
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
tildes
)
std
::
vector
<
index_t
>
tildes
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -621,22 +621,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -621,22 +621,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
);
in_gemmm_gemmn_grid_desc
);
}
}
}
}
// function end
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
ck
::
index_t
N
,
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
index_t
N
,
ck
::
index_t
K
,
index_t
K
,
ck
::
index_t
C
,
index_t
C
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
std
::
vector
<
index_t
>
input_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
std
::
vector
<
index_t
>
filter_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
std
::
vector
<
index_t
>
output_spatial_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
std
::
vector
<
index_t
>
conv_filter_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
std
::
vector
<
index_t
>
conv_filter_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
tildes
)
std
::
vector
<
index_t
>
tildes
)
{
{
using
namespace
ck
;
using
namespace
ck
;
...
@@ -978,8 +977,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -978,8 +977,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
wei_gemmk0_gemmn_gemmk1_grid_desc
,
wei_gemmk0_gemmn_gemmk1_grid_desc
,
in_gemmm_gemmn_grid_desc
);
in_gemmm_gemmn_grid_desc
);
}
}
}
}
// function end
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
GetDummyABCGridDesc
()
static
auto
GetDummyABCGridDesc
()
...
@@ -1125,9 +1123,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -1125,9 +1123,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_left_pads_
,
input_left_pads_
,
input_right_pads_
,
input_right_pads_
,
{
i_xtilde
});
{
i_xtilde
});
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
grid_desc_container_
.
push_back
(
descs
);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
}
}
}
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
...
@@ -1172,9 +1168,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -1172,9 +1168,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_left_pads_
,
input_left_pads_
,
input_right_pads_
,
input_right_pads_
,
{
i_ytilde
,
i_xtilde
});
{
i_ytilde
,
i_xtilde
});
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
grid_desc_container_
.
push_back
(
descs
);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
}
}
}
}
}
}
...
@@ -1228,9 +1222,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -1228,9 +1222,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_left_pads_
,
input_left_pads_
,
input_right_pads_
,
input_right_pads_
,
{
i_ztilde
,
i_ytilde
,
i_xtilde
});
{
i_ztilde
,
i_ytilde
,
i_xtilde
});
a_grid_desc_k0_m_k1_container_
.
push_back
(
descs
[
I0
]);
grid_desc_container_
.
push_back
(
descs
);
b_grid_desc_k0_n_k1_container_
.
push_back
(
descs
[
I1
]);
c_grid_desc_m_n_container_
.
push_back
(
descs
[
I2
]);
}
}
}
}
}
}
...
@@ -1239,9 +1231,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -1239,9 +1231,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
std
::
vector
<
AGridDesc_K0_M_K1
>
a_grid_desc_k0_m_k1_container_
;
std
::
vector
<
ABCGridDescs
>
grid_desc_container_
;
std
::
vector
<
BGridDesc_K0_N_K1
>
b_grid_desc_k0_n_k1_container_
;
std
::
vector
<
CGridDesc_M_N
>
c_grid_desc_m_n_container_
;
index_t
M01_
;
index_t
M01_
;
// for checking IsSupportedArgument()
// for checking IsSupportedArgument()
index_t
Conv_N_
;
index_t
Conv_N_
;
...
@@ -1265,50 +1255,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -1265,50 +1255,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
float
ave_time
=
0
;
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
a_
grid_desc_
k0_m_k1_
container_
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
arg
.
grid_desc_container_
.
size
();
i
++
)
{
{
#if DEBUG_LOG
auto
a_grid_desc_k0_m_k1
=
arg
.
grid_desc_container_
[
i
][
I0
];
{
auto
b_grid_desc_k0_n_k1
=
arg
.
grid_desc_container_
[
i
][
I1
];
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_container_{"
auto
c_grid_desc_m_n
=
arg
.
grid_desc_container_
[
i
][
I2
];
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_container_{"
<<
arg
.
b_grid_desc_k0_n_k1_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_container_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_container_
[
i
].
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_container_{ "
<<
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I1
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I2
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I3
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I4
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I5
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I6
)
<<
", "
<<
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_
[
i
].
GetLength
(
I7
)
<<
" ) "
<<
std
::
endl
;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_container_
[
i
],
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
b_grid_desc_k0_n_k1_container_
[
i
],
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
))
arg
.
c_grid_desc_m_n_container_
[
i
]))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"
);
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"
);
...
@@ -1316,11 +1270,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -1316,11 +1270,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
index_t
gdx
,
gdy
,
gdz
;
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
));
arg
.
c_grid_desc_m_n_container_
[
i
].
GetLength
(
I1
));
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
*
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1
_container_
[
i
]
.
GetLength
(
I2
);
a_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
a_grid_desc_k0_m_k1
.
GetLength
(
I2
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
...
@@ -1341,9 +1294,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -1341,9 +1294,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1
_container_
[
i
]
,
a_grid_desc_k0_m_k1
,
arg
.
b_grid_desc_k0_n_k1
_container_
[
i
]
,
b_grid_desc_k0_n_k1
,
arg
.
c_grid_desc_m_n
_container_
[
i
]
,
c_grid_desc_m_n
,
GridwiseGemm
::
CalculateNumKBlockLoop
(
K
));
GridwiseGemm
::
CalculateNumKBlockLoop
(
K
));
}
}
else
else
...
@@ -1366,9 +1319,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -1366,9 +1319,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1
_container_
[
i
]
,
a_grid_desc_k0_m_k1
,
arg
.
b_grid_desc_k0_n_k1
_container_
[
i
]
,
b_grid_desc_k0_n_k1
,
arg
.
c_grid_desc_m_n
_container_
[
i
]
,
c_grid_desc_m_n
,
GridwiseGemm
::
CalculateNumKBlockLoop
(
K
));
GridwiseGemm
::
CalculateNumKBlockLoop
(
K
));
}
}
}
}
...
@@ -1419,11 +1372,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
...
@@ -1419,11 +1372,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
}
}
// Gridwise GEMM size
// Gridwise GEMM size
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
a_
grid_desc_
k0_m_k1_
container_
.
size
();
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
grid_desc_container_
.
size
();
i
++
)
{
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_
grid_desc_
k0_m_k1_
container_
[
i
],
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
grid_desc_container_
[
i
]
[
I0
]
,
arg
.
b_
grid_desc_
k0_n_k1_
container_
[
i
],
arg
.
grid_desc_container_
[
i
]
[
I1
]
,
arg
.
c_
grid_desc_
m_n_
container_
[
i
]))
arg
.
grid_desc_container_
[
i
]
[
I2
]
))
{
{
return
false
;
return
false
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment