Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d4efc6a7
Commit
d4efc6a7
authored
May 24, 2023
by
Po-Yen, Chen
Browse files
Create descriptors on device side
parent
e090e72a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
122 additions
and
138 deletions
+122
-138
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
...pu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
+122
-138
No files found.
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
View file @
d4efc6a7
...
...
@@ -21,47 +21,6 @@ namespace tensor_operation {
namespace
device
{
namespace
detail
{
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDesc_M_N
c_grid_desc_m_n
,
index_t
NumKBlockLoop
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
,
NumKBlockLoop
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
c_grid_desc_m_n
;
ignore
=
NumKBlockLoop
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template
<
size_t
Dim
>
struct
KernelArgument
...
...
@@ -115,6 +74,48 @@ struct KernelArgument
};
}
// namespace detail
template
<
index_t
NDim
,
typename
DeviceOp
,
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
detail
::
KernelArgument
<
NDim
>
karg
,
index_t
NumKBlockLoop
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
auto
descs
=
DeviceOp
::
template
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDim
>(
karg
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared
,
descs
[
Number
<
0
>
{}],
descs
[
Number
<
1
>
{}],
descs
[
Number
<
2
>
{}],
NumKBlockLoop
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_c_grid
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
c_grid_desc_m_n
;
ignore
=
NumKBlockLoop
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template
<
ck
::
index_t
NDimSpatial
,
typename
InDataType
,
...
...
@@ -201,7 +202,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
static
constexpr
auto
GemmK1Number
=
K1Number
;
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
detail
::
KernelArgument
<
NDim
>
karg
)
static
__host__
__device__
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
detail
::
KernelArgument
<
NDim
>
karg
)
{
using
namespace
ck
;
...
...
@@ -390,7 +392,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
detail
::
KernelArgument
<
NDim
>
karg
)
static
__host__
__device__
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
detail
::
KernelArgument
<
NDim
>
karg
)
{
using
namespace
ck
;
...
...
@@ -650,7 +653,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
detail
::
KernelArgument
<
NDim
>
karg
)
static
__host__
__device__
auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
detail
::
KernelArgument
<
NDim
>
karg
)
{
using
namespace
ck
;
...
...
@@ -1127,21 +1131,19 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
continue
;
}
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
detail
::
KernelArgument
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_xtilde
},
GridwiseGemm
::
CalculateK0
(
Conv_K_
)));
grid_desc_container_
.
push_back
(
descs
);
karg_container_
.
push_back
(
detail
::
KernelArgument
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_xtilde
},
GridwiseGemm
::
CalculateK0
(
Conv_K_
)));
}
}
...
...
@@ -1174,22 +1176,19 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
continue
;
}
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
detail
::
KernelArgument
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ytilde
,
i_xtilde
},
GridwiseGemm
::
CalculateK0
(
Conv_K_
)));
grid_desc_container_
.
push_back
(
descs
);
karg_container_
.
push_back
(
detail
::
KernelArgument
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ytilde
,
i_xtilde
},
GridwiseGemm
::
CalculateK0
(
Conv_K_
)));
}
}
}
...
...
@@ -1231,22 +1230,19 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
continue
;
}
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
detail
::
KernelArgument
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ztilde
,
i_ytilde
,
i_xtilde
},
GridwiseGemm
::
CalculateK0
(
Conv_K_
)));
grid_desc_container_
.
push_back
(
descs
);
karg_container_
.
push_back
(
detail
::
KernelArgument
<
NDimSpatial
>
(
Conv_N_
,
Conv_K_
,
Conv_C_
,
input_spatial_lengths_
,
filter_spatial_lengths_
,
output_spatial_lengths_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
,
{
i_ztilde
,
i_ytilde
,
i_xtilde
},
GridwiseGemm
::
CalculateK0
(
Conv_K_
)));
}
}
}
...
...
@@ -1255,7 +1251,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
std
::
vector
<
ABCGridDescs
>
grid_desc
_container_
;
std
::
vector
<
detail
::
KernelArgument
<
NDimSpatial
>>
karg
_container_
;
index_t
M01_
;
// for checking IsSupportedArgument()
index_t
Conv_N_
;
...
...
@@ -1279,36 +1275,29 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
ave_time
=
0
;
for
(
size_t
i
=
0
;
i
<
arg
.
grid_desc
_container_
.
size
();
i
++
)
for
(
size_t
i
=
0
;
i
<
arg
.
karg
_container_
.
size
();
i
++
)
{
auto
a_grid_desc_k0_m_k1
=
arg
.
grid_desc_container_
[
i
][
I0
];
auto
b_grid_desc_k0_n_k1
=
arg
.
grid_desc_container_
[
i
][
I1
];
auto
c_grid_desc_m_n
=
arg
.
grid_desc_container_
[
i
][
I2
];
const
auto
K
=
arg
.
Conv_K_
;
if
(
!
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"
);
}
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
arg
.
karg_container_
[
i
]);
const
auto
c_grid_desc_m_n
=
descs
[
I2
];
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
c_grid_desc_m_n
.
GetLength
(
I0
),
c_grid_desc_m_n
.
GetLength
(
I1
));
const
auto
K
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
)
*
a_grid_desc_k0_m_k1
.
GetLength
(
I2
);
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
detail
::
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
C
GridDesc_M_N
,
true
>
;
const
auto
kernel
=
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl
<
NDimSpatial
,
DeviceOp
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B
// datatype
C
DataType
,
true
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
...
...
@@ -1318,22 +1307,19 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
,
arg
.
karg_container_
[
i
],
GridwiseGemm
::
CalculateNumKBlockLoop
(
K
));
}
else
{
const
auto
kernel
=
detail
::
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
false
>
;
const
auto
kernel
=
kernel_convnd_bwd_data_nwc_kxc_nwk_xdl
<
NDimSpatial
,
DeviceOp
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B
// datatype
CDataType
,
false
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
...
...
@@ -1343,9 +1329,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m_n
,
arg
.
karg_container_
[
i
],
GridwiseGemm
::
CalculateNumKBlockLoop
(
K
));
}
}
...
...
@@ -1395,16 +1379,16 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
return
false
;
}
// Gridwise GEMM size
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
grid_desc_container_
.
size
();
i
++
)
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
grid_desc_container_
[
i
][
I0
],
arg
.
grid_desc_container_
[
i
][
I1
],
arg
.
grid_desc_container_
[
i
][
I2
]))
{
return
false
;
}
}
//
//
Gridwise GEMM size
//
for(std::size_t i = 0; i < arg.grid_desc_container_.size(); i++)
//
{
//
if(!GridwiseGemm::CheckValidity(arg.grid_desc_container_[i][I0],
//
arg.grid_desc_container_[i][I1],
//
arg.grid_desc_container_[i][I2]))
//
{
//
return false;
//
}
//
}
return
true
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment