Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
12585e57
"...composable_kernel_rocm.git" did not exist on "a2edd7d802b46737e886f0f42a4ee61af03243b7"
Commit
12585e57
authored
Jul 21, 2022
by
Chao Liu
Browse files
refactor
parent
90acba1d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
460 additions
and
301 deletions
+460
-301
example/09_convnd_fwd/convnd_fwd_common.hpp
example/09_convnd_fwd/convnd_fwd_common.hpp
+96
-13
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d.hpp
...ensor_operation/gpu/device/device_conv_fwd_multiple_d.hpp
+21
-18
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
...on/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
+343
-270
No files found.
example/09_convnd_fwd/convnd_fwd_common.hpp
View file @
12585e57
...
@@ -142,22 +142,105 @@ int run_conv_fwd(bool do_verification,
...
@@ -142,22 +142,105 @@ int run_conv_fwd(bool do_verification,
in_device_buf
.
ToDevice
(
in
.
mData
.
data
());
in_device_buf
.
ToDevice
(
in
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei
.
mData
.
data
());
// tensor descriptor in NCHW/KXYC/NKHW dimensional order
HostTensorDescriptor
in_n_c_wis_desc
=
in_desc
;
HostTensorDescriptor
wei_k_c_xs_desc
=
wei_desc
;
HostTensorDescriptor
out_n_k_wos_desc
=
out_desc
;
// input
if
constexpr
(
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NWC
>
)
{
in_n_c_wis_desc
=
transpose_host_tensor_descriptor_given_new2old
(
in_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
}
else
if
constexpr
(
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NHWC
>
)
{
in_n_c_wis_desc
=
transpose_host_tensor_descriptor_given_new2old
(
in_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
}
else
if
constexpr
(
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NDHWC
>
)
{
in_n_c_wis_desc
=
transpose_host_tensor_descriptor_given_new2old
(
in_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
}
// weight
if
constexpr
(
ck
::
is_same_v
<
WeiLayout
,
ck
::
tensor_layout
::
convolution
::
KXC
>
)
{
wei_k_c_xs_desc
=
transpose_host_tensor_descriptor_given_new2old
(
wei_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
}
else
if
constexpr
(
ck
::
is_same_v
<
WeiLayout
,
ck
::
tensor_layout
::
convolution
::
KYXC
>
)
{
wei_k_c_xs_desc
=
transpose_host_tensor_descriptor_given_new2old
(
wei_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
}
else
if
constexpr
(
ck
::
is_same_v
<
WeiLayout
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>
)
{
wei_k_c_xs_desc
=
transpose_host_tensor_descriptor_given_new2old
(
wei_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
}
// output
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NWK
>
)
{
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
}
else
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
)
{
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
}
else
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
)
{
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
}
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_right_pads
{};
auto
copy
=
[](
auto
&
x
,
auto
&
y
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
y
.
begin
());
};
copy
(
in_n_c_wis_desc
.
GetLengths
(),
a_n_c_wis_lengths
);
copy
(
in_n_c_wis_desc
.
GetStrides
(),
a_n_c_wis_strides
);
copy
(
wei_k_c_xs_desc
.
GetLengths
(),
b_k_c_xs_lengths
);
copy
(
wei_k_c_xs_desc
.
GetStrides
(),
b_k_c_xs_strides
);
copy
(
out_n_k_wos_desc
.
GetLengths
(),
e_n_k_wos_lengths
);
copy
(
out_n_k_wos_desc
.
GetStrides
(),
e_n_k_wos_strides
);
copy
(
conv_param
.
conv_filter_strides_
,
conv_filter_strides
);
copy
(
conv_param
.
conv_filter_dilations_
,
conv_filter_dilations
);
copy
(
conv_param
.
input_left_pads_
,
input_left_pads
);
copy
(
conv_param
.
input_right_pads_
,
input_right_pads
);
// do GEMM
// do GEMM
auto
conv
=
DeviceConvNDFwdInstance
{};
auto
conv
=
DeviceConvNDFwdInstance
{};
auto
invoker
=
conv
.
MakeInvoker
();
auto
invoker
=
conv
.
MakeInvoker
();
auto
argument
=
conv
.
MakeArgument
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
auto
argument
=
conv
.
MakeArgument
(
in_device_buf
.
GetDeviceBuffer
(),
static_cast
<
WeiDataType
*>
(
wei_device_buf
.
GetDeviceBuffer
()),
wei_device_buf
.
GetDeviceBuffer
(),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
std
::
array
<
const
void
*
,
0
>
{},
conv_param
.
N_
,
out_device_buf
.
GetDeviceBuffer
(),
conv_param
.
K_
,
a_n_c_wis_lengths
,
conv_param
.
C_
,
a_n_c_wis_strides
,
conv_param
.
input_spatial_lengths_
,
b_k_c_xs_lengths
,
conv_param
.
filter_spatial_lengths_
,
b_k_c_xs_strides
,
conv_param
.
GetOutputSpatialLengths
(),
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
,
0
>
{{}},
conv_param
.
conv_filter_strides_
,
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
,
0
>
{{}},
conv_param
.
conv_filter_dilations_
,
e_n_k_wos_lengths
,
conv_param
.
input_left_pads_
,
e_n_k_wos_strides
,
conv_param
.
input_right_pads_
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
in_element_op
,
in_element_op
,
wei_element_op
,
wei_element_op
,
out_element_op
);
out_element_op
);
...
...
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d.hpp
View file @
12585e57
...
@@ -20,7 +20,7 @@ namespace device {
...
@@ -20,7 +20,7 @@ namespace device {
// E = cde_op(C, D0, D1, ...)
// E = cde_op(C, D0, D1, ...)
// Assume:
// Assume:
// D0, D1, ... and E have the same layout
// D0, D1, ... and E have the same layout
template
<
ck
::
index_t
NDimSpatial
,
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
DsLayout
,
typename
DsLayout
,
...
@@ -36,23 +36,26 @@ struct DeviceConvFwdMultipleD : public BaseOperator
...
@@ -36,23 +36,26 @@ struct DeviceConvFwdMultipleD : public BaseOperator
{
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
MakeArgumentPointer
(
const
ADataType
*
p_a
,
const
void
*
p_a
,
const
BDataType
*
p_b
,
const
void
*
p_b
,
EDataType
*
p_e
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
ck
::
index_t
N
,
void
*
p_e
,
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
std
::
vector
<
ck
::
index_t
>
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
std
::
vector
<
ck
::
index_t
>
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
std
::
vector
<
ck
::
index_t
>
output_spatial_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
conv_filter_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_strides
,
std
::
vector
<
ck
::
index_t
>
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
AElementwiseOperation
a_element_op
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
BElementwiseOperation
b_element_op
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
CDEElementwiseOperation
cde_element_op
)
=
0
;
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
...
...
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
View file @
12585e57
...
@@ -181,59 +181,38 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -181,59 +181,38 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
template
<
typename
ALay
,
typename
BLayout_
,
typename
std
::
enable_if
<
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NWC
>,
typename
std
::
enable_if
<
is_same_v
<
BLayout_
,
ck
::
tensor_layout
::
convolution
::
KXC
>
||
bool
>::
type
=
false
>
is_same_v
<
BLayout_
,
ck
::
tensor_layout
::
convolution
::
KYXC
>
||
static
auto
is_same_v
<
BLayout_
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>
,
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
bool
>::
type
=
false
>
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
static
auto
MakeBGridDescriptor_N_K
(
index_t
GemmNRaw
,
index_t
GemmKRaw
)
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
{
const
auto
wei_k_yxc_grid_desc
=
const
index_t
N
=
a_n_c_wis_lengths
[
0
];
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmNRaw
,
GemmKRaw
));
const
index_t
C
=
a_n_c_wis_lengths
[
1
];
const
auto
wei_gemmn_gemmk_grid_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_k_yxc_grid_desc
);
return
wei_gemmn_gemmk_grid_desc
;
}
template
<
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_n_k_wos_lengths
.
begin
()
+
2
,
typename
ELayout_
,
e_n_k_wos_lengths
.
begin
()
+
3
,
typename
std
::
enable_if
<
is_same_v
<
ELayout_
,
ck
::
tensor_layout
::
convolution
::
NWK
>
||
index_t
{
1
},
is_same_v
<
ELayout_
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
||
std
::
multiplies
<
index_t
>
());
is_same_v
<
ELayout_
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
,
bool
>::
type
=
false
>
static
auto
MakeEGridDescriptor_M_N
(
index_t
GemmMRaw
,
index_t
GemmN
)
{
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
const
auto
out_gemmmraw_gemmn_grid_desc
=
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_k_c_xs_lengths
.
begin
()
+
2
,
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmM
,
GemmN
));
b_k_c_xs_lengths
.
begin
()
+
3
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
auto
out_gemmm_gemmn_grid_desc
=
const
index_t
Wi
=
a_n_c_wis_lengths
[
2
];
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmn_grid_desc
);
return
out_gemmm_gemmn_grid_desc
;
const
index_t
Wo
=
e_n_k_wos_lengths
[
2
];
}
template
<
typename
ALayout_
,
typename
std
::
enable_if
<
is_same_v
<
ALayout_
,
ck
::
tensor_layout
::
convolution
::
NWC
>,
bool
>::
type
=
false
>
static
auto
MakeAGridDescriptor_M_K
(
index_t
N
,
index_t
C
,
index_t
GemmMRaw
,
index_t
GemmKRaw
,
const
std
::
vector
<
index_t
>&
input_spatial_lengths
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
,
const
std
::
vector
<
index_t
>&
conv_filter_strides
,
const
std
::
vector
<
index_t
>&
conv_filter_dilations
,
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_right_pads
)
{
const
index_t
Wi
=
input_spatial_lengths
[
0
];
const
index_t
Wo
=
output_spatial_lengths
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
...
@@ -274,7 +253,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -274,7 +253,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
else
else
{
{
const
index_t
X
=
filter_spatial
_lengths
[
0
];
const
index_t
X
=
b_k_c_xs
_lengths
[
2
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
0
];
...
@@ -313,26 +292,39 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -313,26 +292,39 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
}
}
template
<
typename
ALay
out_
,
template
<
typename
ALay
,
typename
std
::
enable_if
<
is_same_v
<
ALay
out_
,
ck
::
tensor_layout
::
convolution
::
NHWC
>,
typename
std
::
enable_if
<
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NHWC
>,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
static
auto
MakeAGridDescriptor_M_K
(
index_t
N
,
static
auto
index_t
C
,
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
index_t
GemmMRaw
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
index_t
GemmKRaw
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
vector
<
index_t
>&
input_spatial_length
s
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_stride
s
,
const
std
::
vector
<
index_t
>&
filter_spatial
_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos
_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_length
s
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_stride
s
,
const
std
::
vector
<
index_t
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
vector
<
index_t
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_right_pads
)
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
{
const
index_t
Hi
=
input_spatial_lengths
[
0
];
const
index_t
N
=
a_n_c_wis_lengths
[
0
];
const
index_t
Wi
=
input_spatial_lengths
[
1
];
const
index_t
C
=
a_n_c_wis_lengths
[
1
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_n_k_wos_lengths
.
begin
()
+
2
,
e_n_k_wos_lengths
.
begin
()
+
4
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
Ho
=
output_spatial_lengths
[
0
];
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_k_c_xs_lengths
.
begin
()
+
2
,
const
index_t
Wo
=
output_spatial_lengths
[
1
];
b_k_c_xs_lengths
.
begin
()
+
4
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
Hi
=
a_n_c_wis_lengths
[
2
];
const
index_t
Wi
=
a_n_c_wis_lengths
[
3
];
const
index_t
Ho
=
e_n_k_wos_lengths
[
2
];
const
index_t
Wo
=
e_n_k_wos_lengths
[
3
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
...
@@ -377,8 +369,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -377,8 +369,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
else
else
{
{
const
index_t
Y
=
filter_spatial
_lengths
[
0
];
const
index_t
Y
=
b_k_c_xs
_lengths
[
2
];
const
index_t
X
=
filter_spatial
_lengths
[
1
];
const
index_t
X
=
b_k_c_xs
_lengths
[
3
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
...
@@ -425,28 +417,41 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -425,28 +417,41 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
}
}
template
<
typename
ALay
out_
,
template
<
typename
ALay
,
typename
std
::
enable_if
<
is_same_v
<
ALay
out_
,
ck
::
tensor_layout
::
convolution
::
NDHWC
>,
typename
std
::
enable_if
<
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NDHWC
>,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
static
auto
MakeAGridDescriptor_M_K
(
index_t
N
,
static
auto
index_t
C
,
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
index_t
GemmMRaw
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
index_t
GemmKRaw
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
vector
<
index_t
>&
input_spatial_length
s
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_stride
s
,
const
std
::
vector
<
index_t
>&
filter_spatial
_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos
_lengths
,
const
std
::
vector
<
index_t
>&
output_spatial_length
s
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_stride
s
,
const
std
::
vector
<
index_t
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
vector
<
index_t
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
vector
<
index_t
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
vector
<
index_t
>&
input_right_pads
)
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
{
const
index_t
Di
=
input_spatial_lengths
[
0
];
const
index_t
N
=
a_n_c_wis_lengths
[
0
];
const
index_t
Hi
=
input_spatial_lengths
[
1
];
const
index_t
C
=
a_n_c_wis_lengths
[
1
];
const
index_t
Wi
=
input_spatial_lengths
[
2
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_n_k_wos_lengths
.
begin
()
+
2
,
e_n_k_wos_lengths
.
begin
()
+
5
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_k_c_xs_lengths
.
begin
()
+
2
,
b_k_c_xs_lengths
.
begin
()
+
5
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
Di
=
a_n_c_wis_lengths
[
2
];
const
index_t
Hi
=
a_n_c_wis_lengths
[
3
];
const
index_t
Wi
=
a_n_c_wis_lengths
[
4
];
const
index_t
Do
=
output_spatial
_lengths
[
0
];
const
index_t
Do
=
e_n_k_wos
_lengths
[
2
];
const
index_t
Ho
=
output_spatial
_lengths
[
1
];
const
index_t
Ho
=
e_n_k_wos
_lengths
[
3
];
const
index_t
Wo
=
output_spatial
_lengths
[
2
];
const
index_t
Wo
=
e_n_k_wos
_lengths
[
4
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
...
@@ -495,9 +500,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -495,9 +500,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
else
else
{
{
const
index_t
Z
=
filter_spatial
_lengths
[
0
];
const
index_t
Z
=
b_k_c_xs
_lengths
[
2
];
const
index_t
Y
=
filter_spatial
_lengths
[
1
];
const
index_t
Y
=
b_k_c_xs
_lengths
[
3
];
const
index_t
X
=
filter_spatial
_lengths
[
2
];
const
index_t
X
=
b_k_c_xs
_lengths
[
4
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
...
@@ -556,49 +561,80 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -556,49 +561,80 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
}
}
static
index_t
GetGemmMRaw
(
index_t
N
,
const
std
::
vector
<
index_t
>&
output_spatial_lengths
)
// supported layout:
// KXC, K_XC
// KYXC, K_YXC
// KZYXC, K_ZYXC
template
<
typename
BLay
,
typename
std
::
enable_if
<
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KXC
>
||
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KYXC
>
||
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KZYXC
>
,
bool
>::
type
=
false
>
static
auto
MakeBGridDescriptor_N_K
(
index_t
GemmNRaw
,
index_t
GemmKRaw
)
{
{
return
N
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths
),
const
auto
wei_k_yxc_grid_desc
=
std
::
end
(
output_spatial_lengths
),
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmNRaw
,
GemmKRaw
));
1
,
std
::
multiplies
<
index_t
>
());
const
auto
wei_gemmn_gemmk_grid_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_k_yxc_grid_desc
);
return
wei_gemmn_gemmk_grid_desc
;
}
}
static
index_t
GetGemmKRaw
(
index_t
C
,
const
std
::
vector
<
index_t
>&
filter_spatial_lengths
)
template
<
typename
ELay
,
typename
std
::
enable_if
<
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NWK
>
||
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NHWK
>
||
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NDHWK
>
,
bool
>::
type
=
false
>
static
auto
MakeEGridDescriptor_M_N
(
index_t
GemmMRaw
,
index_t
GemmN
)
{
{
return
C
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths
),
const
index_t
GemmM
=
math
::
integer_least_multiple
(
GemmMRaw
,
MPerBlock
);
std
::
end
(
filter_spatial_lengths
),
1
,
const
auto
out_gemmmraw_gemmn_grid_desc
=
std
::
multiplies
<
index_t
>
());
make_naive_tensor_descriptor_packed
(
make_tuple
(
GemmM
,
GemmN
));
const
auto
out_gemmm_gemmn_grid_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmn_grid_desc
);
return
out_gemmm_gemmn_grid_desc
;
}
}
static
auto
static
auto
MakeABEGridDescriptor
_A_K0_M_K1_B_K0_N_K1_C_M_N
(
index_t
N
,
MakeABEGridDescriptor
s
(
const
std
::
array
<
index_t
,
N
DimSpatial
+
2
>&
a_n_c_wis_lengths
,
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
std
::
vector
<
index_t
>
input_spatial_length
s
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_stride
s
,
std
::
vector
<
index_t
>
filter_spatial
_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos
_lengths
,
std
::
vector
<
index_t
>
output_spatial_length
s
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_stride
s
,
std
::
vector
<
index_t
>
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
std
::
vector
<
index_t
>
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
std
::
vector
<
index_t
>
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
std
::
vector
<
index_t
>
input_right_pads
)
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
{
using
namespace
ck
;
const
index_t
N
=
a_n_c_wis_lengths
[
0
];
const
index_t
K
=
b_k_c_xs_lengths
[
0
];
const
index_t
C
=
a_n_c_wis_lengths
[
1
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_n_k_wos_lengths
.
begin
()
+
2
,
e_n_k_wos_lengths
.
begin
()
+
2
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
GemmMRaw
=
GetGemmMRaw
(
N
,
output_spatial_lengths
);
const
index_t
GemmNRaw
=
K
;
const
index_t
GemmNRaw
=
K
;
const
index_t
GemmKRaw
=
GetGemmKRaw
(
C
,
filter_spatial_lengths
);
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_k_c_xs_lengths
.
begin
()
+
2
,
b_k_c_xs_lengths
.
begin
()
+
2
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
// A:
// A:
const
auto
in_gemmm_gemmk_grid_desc
=
const
auto
in_gemmm_gemmk_grid_desc
=
MakeAGridDescriptor_M_K
<
ALayout
>
(
N
,
MakeAGridDescriptor_M_K
<
ALayout
>
(
a_n_c_wis_lengths
,
C
,
a_n_c_wis_strides
,
GemmMRaw
,
b_k_c_xs_lengths
,
GemmKRaw
,
b_k_c_xs_strides
,
input_spatial_lengths
,
e_n_k_wos_lengths
,
filter_spatial_lengths
,
e_n_k_wos_strides
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -614,28 +650,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -614,28 +650,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
in_gemmm_gemmk_grid_desc
,
wei_gemmn_gemmk_grid_desc
,
out_gemmm_gemmn_grid_desc
);
in_gemmm_gemmk_grid_desc
,
wei_gemmn_gemmk_grid_desc
,
out_gemmm_gemmn_grid_desc
);
}
}
template
<
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
using
ABEGridDescs
=
decltype
(
MakeABEGridDescriptors
({},
{},
{},
{},
{},
{},
{},
{},
{},
{}));
static
auto
GetABEGridDesc
()
{
return
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
1
,
1
,
1
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
});
}
template
<
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
GetABEGridDesc
()
{
return
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
1
,
1
,
1
,
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
});
}
template
<
index_t
NDim
,
typename
std
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
GetABEGridDesc
()
{
return
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
1
,
1
,
1
,
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
});
}
using
ABEGridDescs
=
decltype
(
GetABEGridDesc
<
NDimSpatial
>
());
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABEGridDescs
{}[
I0
])
>
;
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABEGridDescs
{}[
I0
])
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
ABEGridDescs
{}[
I1
])
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
ABEGridDescs
{}[
I1
])
>
;
...
@@ -698,53 +713,61 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -698,53 +713,61 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
const
ADataType
*
p_in_grid
,
Argument
(
const
BDataType
*
p_wei_grid
,
const
void
*
p_a
,
EDataType
*
p_out_grid
,
const
void
*
p_b
,
index_t
N
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
index_t
K
,
void
*
p_e
,
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
std
::
vector
<
index_t
>
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
std
::
vector
<
index_t
>
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
std
::
vector
<
index_t
>
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
std
::
vector
<
index_t
>
conv_filter_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_lengths
,
std
::
vector
<
index_t
>
conv_filter_dilations
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_strides
,
std
::
vector
<
index_t
>
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
std
::
vector
<
index_t
>
input_right_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
AElementwiseOperation
in_element_op
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
BElementwiseOperation
wei_element_op
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
CDEElementwiseOperation
out_element_op
)
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_in_grid
)},
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_wei_grid
)},
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b
)},
p_ds_grid_
{},
// FIXME
p_ds_grid_
{},
// FIXME
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_
out_grid
)},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_
e
)},
a_grid_desc_ak0_m_ak1_
{},
a_grid_desc_ak0_m_ak1_
{},
b_grid_desc_bk0_n_bk1_
{},
b_grid_desc_bk0_n_bk1_
{},
e_grid_desc_m_n_
{},
e_grid_desc_m_n_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{},
block_2_etile_map_
{},
a_element_op_
{
in_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
wei_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
out_element_op
},
cde_element_op_
{
cde_element_op
},
Conv_N_
{
N
},
a_n_c_wis_lengths_
{
a_n_c_wis_lengths
},
Conv_K_
{
K
},
a_n_c_wis_strides_
{
a_n_c_wis_strides
},
Conv_C_
{
C
},
b_k_c_xs_lengths_
{
b_k_c_xs_lengths
},
filter_spatial_lengths_
{
filter_spatial_lengths
},
b_k_c_xs_strides_
{
b_k_c_xs_strides
},
ds_n_k_wos_lengths_
{
ds_n_k_wos_lengths
},
ds_n_k_wos_strides_
{
ds_n_k_wos_strides
},
e_n_k_wos_lengths_
{
e_n_k_wos_lengths
},
e_n_k_wos_strides_
{
e_n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
input_right_pads_
{
input_right_pads
}
{
{
const
auto
descs
=
const
auto
descs
=
DeviceOp
::
MakeABEGridDescriptors
(
a_n_c_wis_lengths
,
DeviceOp
::
MakeABEGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
(
N
,
a_n_c_wis_strides
,
K
,
b_k_c_xs_lengths
,
C
,
b_k_c_xs_strides
,
input_spatial_lengths
,
e_n_k_wos_lengths
,
filter_spatial_lengths
,
e_n_k_wos_strides
,
output_spatial_lengths
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
);
input_right_pads
);
const
auto
a_grid_desc_m_k
=
descs
[
I0
];
const
auto
a_grid_desc_m_k
=
descs
[
I0
];
const
auto
b_grid_desc_n_k
=
descs
[
I1
];
const
auto
b_grid_desc_n_k
=
descs
[
I1
];
...
@@ -796,13 +819,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -796,13 +819,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
CDEElementwiseOperation
cde_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking IsSupportedArgument()
// for checking IsSupportedArgument()
index_t
Conv_N_
;
std
::
array
<
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_lengths_
;
index_t
Conv_K_
;
std
::
array
<
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_strides_
;
index_t
Conv_C_
;
std
::
array
<
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_lengths_
;
std
::
vector
<
index_t
>
filter_spatial_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_strides_
;
std
::
vector
<
index_t
>
conv_filter_strides_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>
ds_n_k_wos_lengths_
;
std
::
vector
<
index_t
>
input_left_pads_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>
ds_n_k_wos_strides_
;
std
::
vector
<
index_t
>
input_right_pads_
;
std
::
array
<
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
};
};
// Invoker
// Invoker
...
@@ -856,7 +884,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -856,7 +884,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
CDEElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
ck
::
StaticallyIndexedArray
<
StaticallyIndexedArray
<
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
,
NumDTensor
>
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
...
@@ -905,21 +933,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -905,21 +933,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
#if 1
namespace
ctc
=
tensor_layout
::
convolution
;
{
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_{"
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_bk0_n_bk1_{"
<<
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_bk0_n_bk1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.e_grid_desc_m_n_{ "
<<
arg
.
e_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
if
(
get_device_name
()
==
"gfx908"
)
<<
arg
.
e_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
#endif
if
(
ck
::
get_device_name
()
==
"gfx908"
)
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
float
>
||
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
is_same_v
<
AccDataType
,
int32_t
>
))
...
@@ -927,7 +943,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -927,7 +943,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
}
}
else
if
(
ck
::
get_device_name
()
==
"gfx90a"
)
else
if
(
get_device_name
()
==
"gfx90a"
)
{
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
float
>
||
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
||
is_same_v
<
AccDataType
,
double
>
))
is_same_v
<
AccDataType
,
int32_t
>
||
is_same_v
<
AccDataType
,
double
>
))
...
@@ -940,8 +956,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -940,8 +956,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
//
tensors
can't be
big
ger than 2GB each
.
//
check tensor size:
can't be
lar
ger than 2GB each
constexpr
ck
::
long_index_t
GB2
=
(
ck
::
long_index_t
{
1
}
<<
31
);
constexpr
long_index_t
GB2
=
(
long_index_t
{
1
}
<<
31
);
if
(
arg
.
a_grid_desc_ak0_m_ak1_
.
GetElementSpaceSize
()
*
sizeof
(
ADataType
)
>
GB2
||
if
(
arg
.
a_grid_desc_ak0_m_ak1_
.
GetElementSpaceSize
()
*
sizeof
(
ADataType
)
>
GB2
||
arg
.
b_grid_desc_bk0_n_bk1_
.
GetElementSpaceSize
()
*
sizeof
(
BDataType
)
>
GB2
||
arg
.
b_grid_desc_bk0_n_bk1_
.
GetElementSpaceSize
()
*
sizeof
(
BDataType
)
>
GB2
||
...
@@ -950,14 +966,19 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -950,14 +966,19 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
// check ConvolutionForwardSpecialization
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
{
// check if it's 1x1, stride=1 conv
// check if it's 1x1, stride=1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
{
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
conv_filter_strides_
[
i
]
==
1
&&
const
index_t
X
=
arg
.
b_k_c_xs_lengths_
[
i
+
2
];
arg
.
input_left_pads_
[
i
]
==
0
&&
arg
.
input_right_pads_
[
i
]
==
0
))
const
index_t
ConvStride
=
arg
.
conv_filter_strides_
[
i
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
if
(
!
(
X
==
1
&&
ConvStride
==
1
&&
LeftPad
==
0
&&
RightPad
==
0
))
{
{
return
false
;
return
false
;
}
}
...
@@ -969,24 +990,63 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -969,24 +990,63 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// check if it's 1x1 conv
// check if it's 1x1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
{
if
(
!
(
arg
.
filter_spatial_lengths_
[
i
]
==
1
&&
arg
.
input_left_pads_
[
i
]
==
0
&&
const
index_t
X
=
arg
.
b_k_c_xs_lengths_
[
i
+
2
];
arg
.
input_right_pads_
[
i
]
==
0
))
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
if
(
!
(
X
==
1
&&
LeftPad
==
0
&&
RightPad
==
0
))
{
{
return
false
;
return
false
;
}
}
}
}
}
}
// vector load A/B matrix from global memory
// check vector access of A
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
BBlockTransferSrcVectorDim
==
2
&&
if
constexpr
(
is_same_v
<
ALayout
,
ctc
::
NWC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWC
>
||
arg
.
Conv_C_
%
ABlockTransferSrcScalarPerVector
==
0
&&
is_same_v
<
ALayout
,
ctc
::
NDHWC
>
)
arg
.
Conv_C_
%
BBlockTransferSrcScalarPerVector
==
0
))
{
const
index_t
C
=
arg
.
a_n_c_wis_lengths_
[
1
];
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
{
return
false
;
return
false
;
}
}
// vector store D/E matrix into global memory
// check vector access of B
if
(
!
(
arg
.
Conv_K_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
if
constexpr
(
is_same_v
<
BLayout
,
ctc
::
KXC
>
||
is_same_v
<
BLayout
,
ctc
::
KYXC
>
||
is_same_v
<
BLayout
,
ctc
::
KZYXC
>
)
{
const
index_t
C
=
arg
.
b_k_c_xs_lengths_
[
1
];
if
(
!
(
BBlockTransferSrcVectorDim
==
2
&&
C
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
}
}
else
{
return
false
;
}
// FIXME: check vector access of Ds
// check vector access of E
if
constexpr
(
is_same_v
<
ELayout
,
ctc
::
NWK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWK
>
||
is_same_v
<
ELayout
,
ctc
::
NDHWK
>
)
{
const
index_t
K
=
arg
.
e_n_k_wos_lengths_
[
1
];
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
}
}
else
{
{
return
false
;
return
false
;
}
}
...
@@ -1003,77 +1063,90 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1003,77 +1063,90 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
static
auto
MakeArgument
(
const
ADataType
*
p_in_grid
,
static
auto
MakeArgument
(
const
BDataType
*
p_wei_grid
,
const
void
*
p_a
,
EDataType
*
p_out_grid
,
const
void
*
p_b
,
index_t
N
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
index_t
K
,
void
*
p_e
,
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
std
::
vector
<
index_t
>
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
std
::
vector
<
index_t
>
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
std
::
vector
<
index_t
>
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
std
::
vector
<
index_t
>
conv_filter_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_lengths
,
std
::
vector
<
index_t
>
conv_filter_dilations
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_strides
,
std
::
vector
<
index_t
>
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
std
::
vector
<
index_t
>
input_right_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
AElementwiseOperation
in_element_op
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
BElementwiseOperation
wei_element_op
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
CDEElementwiseOperation
out_element_op
)
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
{
return
Argument
{
p_in_grid
,
return
Argument
{
p_a
,
p_wei_grid
,
p_b
,
p_out_grid
,
p_ds
,
N
,
p_e
,
K
,
a_n_c_wis_lengths
,
C
,
a_n_c_wis_strides
,
input_spatial_lengths
,
b_k_c_xs_lengths
,
filter_spatial_lengths
,
b_k_c_xs_strides
,
output_spatial_lengths
,
ds_n_k_wos_lengths
,
ds_n_k_wos_strides
,
e_n_k_wos_lengths
,
e_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
in
_element_op
,
a
_element_op
,
wei
_element_op
,
b
_element_op
,
out
_element_op
};
cde
_element_op
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseArgument
>
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
MakeArgumentPointer
(
const
ADataType
*
p_in_grid
,
const
void
*
p_a
,
const
BDataType
*
p_wei_grid
,
const
void
*
p_b
,
EDataType
*
p_out_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
index_t
N
,
void
*
p_e
,
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
std
::
vector
<
index_t
>
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
std
::
vector
<
index_t
>
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
std
::
vector
<
index_t
>
output_spatial_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_lengths
,
std
::
vector
<
index_t
>
conv_filter_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_strides
,
std
::
vector
<
index_t
>
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
std
::
vector
<
index_t
>
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
std
::
vector
<
index_t
>
input_right_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
AElementwiseOperation
in_element_op
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
BElementwiseOperation
wei_element_op
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
CDEElementwiseOperation
out_element_op
)
override
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_in_grid
),
return
std
::
make_unique
<
Argument
>
(
p_a
,
static_cast
<
const
BDataType
*>
(
p_wei_grid
),
p_b
,
static_cast
<
EDataType
*>
(
p_out_grid
),
p_ds
,
N
,
p_e
,
K
,
a_n_c_wis_lengths
,
C
,
a_n_c_wis_strides
,
input_spatial_lengths
,
b_k_c_xs_lengths
,
filter_spatial_lengths
,
b_k_c_xs_strides
,
output_spatial_lengths
,
ds_n_k_wos_lengths
,
ds_n_k_wos_strides
,
e_n_k_wos_lengths
,
e_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
,
input_right_pads
,
in
_element_op
,
a
_element_op
,
wei
_element_op
,
b
_element_op
,
out
_element_op
);
cde
_element_op
);
}
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment