Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ae1b4ee6
Commit
ae1b4ee6
authored
Jul 22, 2022
by
Chao Liu
Browse files
add bias
parent
cf95b944
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
298 additions
and
111 deletions
+298
-111
example/09_convnd_fwd/convnd_fwd_common.hpp
example/09_convnd_fwd/convnd_fwd_common.hpp
+74
-26
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
+8
-3
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
...on/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
+125
-34
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
+54
-27
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+37
-21
No files found.
example/09_convnd_fwd/convnd_fwd_common.hpp
View file @
ae1b4ee6
...
@@ -114,13 +114,23 @@ int run_conv_fwd(bool do_verification,
...
@@ -114,13 +114,23 @@ int run_conv_fwd(bool do_verification,
const
auto
wei_desc
=
ck
::
utils
::
conv
::
get_weight_host_tensor_descriptor
<
WeiLayout
>
(
conv_param
);
const
auto
wei_desc
=
ck
::
utils
::
conv
::
get_weight_host_tensor_descriptor
<
WeiLayout
>
(
conv_param
);
const
auto
out_desc
=
ck
::
utils
::
conv
::
get_output_host_tensor_descriptor
<
OutLayout
>
(
conv_param
);
const
auto
out_desc
=
ck
::
utils
::
conv
::
get_output_host_tensor_descriptor
<
OutLayout
>
(
conv_param
);
// hacky, hardcoded for 2d NHWK
const
auto
bias_desc
=
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
conv_param
.
N_
),
static_cast
<
std
::
size_t
>
(
conv_param
.
output_spatial_lengths_
[
0
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
output_spatial_lengths_
[
1
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
K_
)},
std
::
vector
<
std
::
size_t
>
{
0
,
0
,
0
,
1
});
Tensor
<
InDataType
>
in
(
in_desc
);
Tensor
<
InDataType
>
in
(
in_desc
);
Tensor
<
WeiDataType
>
wei
(
wei_desc
);
Tensor
<
WeiDataType
>
wei
(
wei_desc
);
Tensor
<
OutDataType
>
bias
(
bias_desc
);
Tensor
<
OutDataType
>
out_host
(
out_desc
);
Tensor
<
OutDataType
>
out_host
(
out_desc
);
Tensor
<
OutDataType
>
out_device
(
out_desc
);
Tensor
<
OutDataType
>
out_device
(
out_desc
);
std
::
cout
<<
"in: "
<<
in
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"in: "
<<
in
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei: "
<<
wei
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei: "
<<
wei
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"bias: "
<<
bias
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out: "
<<
out_host
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out: "
<<
out_host
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
switch
(
init_method
)
...
@@ -129,23 +139,28 @@ int run_conv_fwd(bool do_verification,
...
@@ -129,23 +139,28 @@ int run_conv_fwd(bool do_verification,
case
1
:
case
1
:
in
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
in
.
GenerateTensorValue
(
GeneratorTensor_2
<
InDataType
>
{
-
5
,
5
});
wei
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
wei
.
GenerateTensorValue
(
GeneratorTensor_2
<
WeiDataType
>
{
-
5
,
5
});
bias
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
5
,
5
});
break
;
break
;
default:
default:
in
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
0.0
,
1.0
});
in
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
0.0
,
1.0
});
wei
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
wei
.
GenerateTensorValue
(
GeneratorTensor_3
<
WeiDataType
>
{
-
0.5
,
0.5
});
bias
.
GenerateTensorValue
(
GeneratorTensor_3
<
OutDataType
>
{
-
0.5
,
0.5
});
}
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpace
());
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei
.
mDesc
.
GetElementSpace
());
DeviceMem
wei_device_buf
(
sizeof
(
WeiDataType
)
*
wei
.
mDesc
.
GetElementSpace
());
DeviceMem
bias_device_buf
(
sizeof
(
OutDataType
)
*
bias
.
mDesc
.
GetElementSpace
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_device
.
mDesc
.
GetElementSpace
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_device
.
mDesc
.
GetElementSpace
());
in_device_buf
.
ToDevice
(
in
.
mData
.
data
());
in_device_buf
.
ToDevice
(
in
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei
.
mData
.
data
());
bias_device_buf
.
ToDevice
(
bias
.
mData
.
data
());
// tensor descriptor in NCHW/KXYC/NKHW dimensional order
// tensor descriptor in NCHW/KXYC/NKHW dimensional order
HostTensorDescriptor
in_n_c_wis_desc
=
in_desc
;
HostTensorDescriptor
in_n_c_wis_desc
=
in_desc
;
HostTensorDescriptor
wei_k_c_xs_desc
=
wei_desc
;
HostTensorDescriptor
wei_k_c_xs_desc
=
wei_desc
;
HostTensorDescriptor
out_n_k_wos_desc
=
out_desc
;
HostTensorDescriptor
bias_n_k_wos_desc
=
bias_desc
;
HostTensorDescriptor
out_n_k_wos_desc
=
out_desc
;
// input
// input
if
constexpr
(
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NWC
>
)
if
constexpr
(
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NWC
>
)
...
@@ -186,22 +201,33 @@ int run_conv_fwd(bool do_verification,
...
@@ -186,22 +201,33 @@ int run_conv_fwd(bool do_verification,
{
{
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
bias_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
bias_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
}
}
else
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
)
else
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
)
{
{
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
bias_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
bias_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
}
}
else
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
)
else
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
)
{
{
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
bias_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
bias_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
}
}
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
d_n_k_wos_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
d_n_k_wos_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
...
@@ -215,6 +241,8 @@ int run_conv_fwd(bool do_verification,
...
@@ -215,6 +241,8 @@ int run_conv_fwd(bool do_verification,
copy
(
in_n_c_wis_desc
.
GetStrides
(),
a_n_c_wis_strides
);
copy
(
in_n_c_wis_desc
.
GetStrides
(),
a_n_c_wis_strides
);
copy
(
wei_k_c_xs_desc
.
GetLengths
(),
b_k_c_xs_lengths
);
copy
(
wei_k_c_xs_desc
.
GetLengths
(),
b_k_c_xs_lengths
);
copy
(
wei_k_c_xs_desc
.
GetStrides
(),
b_k_c_xs_strides
);
copy
(
wei_k_c_xs_desc
.
GetStrides
(),
b_k_c_xs_strides
);
copy
(
bias_n_k_wos_desc
.
GetLengths
(),
d_n_k_wos_lengths
);
copy
(
bias_n_k_wos_desc
.
GetStrides
(),
d_n_k_wos_strides
);
copy
(
out_n_k_wos_desc
.
GetLengths
(),
e_n_k_wos_lengths
);
copy
(
out_n_k_wos_desc
.
GetLengths
(),
e_n_k_wos_lengths
);
copy
(
out_n_k_wos_desc
.
GetStrides
(),
e_n_k_wos_strides
);
copy
(
out_n_k_wos_desc
.
GetStrides
(),
e_n_k_wos_strides
);
copy
(
conv_param
.
conv_filter_strides_
,
conv_filter_strides
);
copy
(
conv_param
.
conv_filter_strides_
,
conv_filter_strides
);
...
@@ -225,25 +253,26 @@ int run_conv_fwd(bool do_verification,
...
@@ -225,25 +253,26 @@ int run_conv_fwd(bool do_verification,
// do GEMM
// do GEMM
auto
conv
=
DeviceConvNDFwdInstance
{};
auto
conv
=
DeviceConvNDFwdInstance
{};
auto
invoker
=
conv
.
MakeInvoker
();
auto
invoker
=
conv
.
MakeInvoker
();
auto
argument
=
conv
.
MakeArgument
(
in_device_buf
.
GetDeviceBuffer
(),
auto
argument
=
conv
.
MakeArgument
(
wei_device_buf
.
GetDeviceBuffer
(),
in_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
0
>
{},
wei_device_buf
.
GetDeviceBuffer
(),
out_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
bias_device_buf
.
GetDeviceBuffer
()},
a_n_c_wis_lengths
,
out_device_buf
.
GetDeviceBuffer
(),
a_n_c_wis_strides
,
a_n_c_wis_lengths
,
b_k_c_xs_lengths
,
a_n_c_wis_strides
,
b_k_c_xs_strides
,
b_k_c_xs_lengths
,
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
,
0
>
{{}},
b_k_c_xs_strides
,
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
,
0
>
{{}},
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
,
1
>
{{
d_n_k_wos_lengths
}},
e_n_k_wos_lengths
,
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
,
1
>
{{
d_n_k_wos_strides
}},
e_n_k_wos_strides
,
e_n_k_wos_lengths
,
conv_filter_strides
,
e_n_k_wos_strides
,
conv_filter_dilations
,
conv_filter_strides
,
input_left_pads
,
conv_filter_dilations
,
input_right_pads
,
input_left_pads
,
in_element_op
,
input_right_pads
,
wei_element_op
,
in_element_op
,
out_element_op
);
wei_element_op
,
out_element_op
);
if
(
!
conv
.
IsSupportedArgument
(
argument
))
if
(
!
conv
.
IsSupportedArgument
(
argument
))
{
{
...
@@ -264,6 +293,10 @@ int run_conv_fwd(bool do_verification,
...
@@ -264,6 +293,10 @@ int run_conv_fwd(bool do_verification,
if
(
do_verification
)
if
(
do_verification
)
{
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
Tensor
<
OutDataType
>
c_host
(
out_desc
);
auto
ref_conv
=
ck
::
tensor_operation
::
host
::
ReferenceConvFwd
<
NDimSpatial
,
auto
ref_conv
=
ck
::
tensor_operation
::
host
::
ReferenceConvFwd
<
NDimSpatial
,
InLayout
,
InLayout
,
WeiLayout
,
WeiLayout
,
...
@@ -273,26 +306,41 @@ int run_conv_fwd(bool do_verification,
...
@@ -273,26 +306,41 @@ int run_conv_fwd(bool do_verification,
OutDataType
,
OutDataType
,
InElementOp
,
InElementOp
,
WeiElementOp
,
WeiElementOp
,
OutElementOp
>
();
PassThrough
>
();
auto
ref_invoker
=
ref_conv
.
MakeInvoker
();
auto
ref_invoker
=
ref_conv
.
MakeInvoker
();
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in
,
auto
ref_argument
=
ref_conv
.
MakeArgument
(
in
,
wei
,
wei
,
out
_host
,
c
_host
,
conv_param
.
conv_filter_strides_
,
conv_param
.
conv_filter_strides_
,
conv_param
.
conv_filter_dilations_
,
conv_param
.
conv_filter_dilations_
,
conv_param
.
input_left_pads_
,
conv_param
.
input_left_pads_
,
conv_param
.
input_right_pads_
,
conv_param
.
input_right_pads_
,
in_element_op
,
in_element_op
,
wei_element_op
,
wei_element_op
,
out_element_op
);
PassThrough
{}
);
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
for
(
int
n
=
0
;
n
<
out_host
.
mDesc
.
GetLengths
()[
0
];
n
++
)
{
for
(
int
ho
=
0
;
ho
<
out_host
.
mDesc
.
GetLengths
()[
1
];
ho
++
)
{
for
(
int
wo
=
0
;
wo
<
out_host
.
mDesc
.
GetLengths
()[
2
];
wo
++
)
{
for
(
int
k
=
0
;
k
<
out_host
.
mDesc
.
GetLengths
()[
3
];
k
++
)
{
out_element_op
(
out_host
(
n
,
ho
,
wo
,
k
),
c_host
(
n
,
ho
,
wo
,
k
),
bias
(
n
,
ho
,
wo
,
k
));
}
}
}
}
out_device_buf
.
FromDevice
(
out_device
.
mData
.
data
());
out_device_buf
.
FromDevice
(
out_device
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
return
ck
::
utils
::
check_err
(
out_
host
.
mData
,
out_
device
.
mData
,
"Error: incorrect results!"
,
1e-5
f
,
1e-4
f
)
out_
device
.
mData
,
out_
host
.
mData
,
"Error: incorrect results!"
,
1e-5
f
,
1e-4
f
)
?
0
?
0
:
1
;
:
1
;
}
}
...
...
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
View file @
ae1b4ee6
...
@@ -16,7 +16,8 @@ using S = ck::Sequence<Is...>;
...
@@ -16,7 +16,8 @@ using S = ck::Sequence<Is...>;
using
InElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
InElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
WeiElementOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
UnaryConvert
;
// using OutElementOp = ck::tensor_operation::element_wise::UnaryConvert;
using
OutElementOp
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
#if 0
#if 0
static constexpr auto ConvFwdDefault =
static constexpr auto ConvFwdDefault =
...
@@ -60,6 +61,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
...
@@ -60,6 +61,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvNdFwdNwc
1>; // CThreadTransferDstScalarPerVector
1>; // CThreadTransferDstScalarPerVector
#else
#else
using
CShuffleDataType
=
ck
::
half_t
;
using
CShuffleDataType
=
ck
::
half_t
;
using
DDataType
=
ck
::
half_t
;
static
constexpr
auto
ConvSpec
=
static
constexpr
auto
ConvSpec
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
...
@@ -77,7 +79,10 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMulti
...
@@ -77,7 +79,10 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMulti
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
KXC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
KYXC
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>>
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>>
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NW_K
,
ck
::
tensor_layout
::
convolution
::
NHW_K
,
ck
::
tensor_layout
::
convolution
::
NDHW_K
>>>
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
tuple_element_t
<
NDimSpatial
-
1
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWK
,
ck
::
Tuple
<
ck
::
tensor_layout
::
convolution
::
NWK
,
ck
::
tensor_layout
::
convolution
::
NHWK
,
ck
::
tensor_layout
::
convolution
::
NHWK
,
...
@@ -86,7 +91,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMulti
...
@@ -86,7 +91,7 @@ using DeviceConvNDFwdInstance = ck::tensor_operation::device::DeviceConvFwdMulti
WeiDataType
,
WeiDataType
,
AccDataType
,
AccDataType
,
CShuffleDataType
,
CShuffleDataType
,
ck
::
Tuple
<>
,
ck
::
Tuple
<
DDataType
>
,
OutDataType
,
OutDataType
,
InElementOp
,
InElementOp
,
WeiElementOp
,
WeiElementOp
,
...
...
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
View file @
ae1b4ee6
...
@@ -565,10 +565,6 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -565,10 +565,6 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
}
}
// supported layout:
// KXC, K_XC
// KYXC, K_YXC
// KZYXC, K_ZYXC
template
<
typename
BLay
,
template
<
typename
BLay
,
typename
std
::
enable_if
<
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KXC
>
||
typename
std
::
enable_if
<
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KXC
>
||
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KYXC
>
||
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KYXC
>
||
...
@@ -625,10 +621,57 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -625,10 +621,57 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
out_gemmm_gemmn_grid_desc
;
return
out_gemmm_gemmn_grid_desc
;
}
}
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
template
<
typename
ELay
,
typename
std
::
enable_if
<
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NW_K
>
||
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NHW_K
>
||
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NDHW_K
>
,
bool
>::
type
=
false
>
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
)
{
namespace
ctc
=
ck
::
tensor_layout
::
convolution
;
const
index_t
N
=
e_n_k_wos_lengths
[
0
];
const
index_t
K
=
e_n_k_wos_lengths
[
1
];
const
index_t
WoStride
=
e_n_k_wos_strides
[
NDimSpatial
+
1
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_n_k_wos_lengths
.
begin
()
+
2
,
e_n_k_wos_lengths
.
begin
()
+
2
+
NDimSpatial
,
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
const
index_t
GemmNRaw
=
K
;
const
auto
out_gemmmraw_gemmnraw_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
GemmMRaw
,
GemmNRaw
),
make_tuple
(
WoStride
,
I1
));
const
auto
out_gemmm_gemmn_grid_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_grid_desc
);
return
out_gemmm_gemmn_grid_desc
;
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
ds_n_k_wos_lengths
[
i
],
ds_n_k_wos_strides
[
i
]);
},
Number
<
NumDTensor
>
{});
}
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
({},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
MakeAGridDescriptor_M_K
<
ALayout
>
({},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
({},
{}))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
using
GridwiseGemm
=
GridwiseGemmMultipleD_xdl_cshuffle
<
...
@@ -643,7 +686,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -643,7 +686,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_M_K
,
AGridDesc_M_K
,
BGridDesc_N_K
,
BGridDesc_N_K
,
StaticallyIndexedArray
<
EGridDesc_M_N
,
NumDTensor
>
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
EGridDesc_M_N
,
NumGemmKPrefetchStage
,
NumGemmKPrefetchStage
,
BlockSize
,
BlockSize
,
...
@@ -762,6 +805,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -762,6 +805,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
block_2_etile_map_
=
Block2ETileMap
{
e_grid_desc_m_n_
};
block_2_etile_map_
=
Block2ETileMap
{
e_grid_desc_m_n_
};
// populate pointer and desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds
[
i
]);
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
ds_n_k_wos_lengths
[
i
],
ds_n_k_wos_strides
[
i
]);
});
// populate desc for Ds/E
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
ds_grid_desc_m_n_
,
...
@@ -772,22 +827,21 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -772,22 +827,21 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
e_grid_desc_m_n_
);
// populate pointer and desc for Ds
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
ds_grid_desc_m_n_
);
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds
[
i
]);
ds_grid_desc_m_n_
[
i
]
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
ds_n_k_wos_lengths
[
i
],
ds_n_k_wos_strides
[
i
]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
i
)
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
[
i
]);
});
}
}
}
}
void
Print
()
const
{
std
::
cout
<<
"A[M, K]: "
<<
a_grid_desc_m_k_
<<
std
::
endl
;
std
::
cout
<<
"B[N, K]: "
<<
b_grid_desc_n_k_
<<
std
::
endl
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
std
::
cout
<<
"Ds[M, N]: "
<<
ds_grid_desc_m_n_
[
i
]
<<
std
::
endl
;
});
std
::
cout
<<
"E[M, N]: "
<<
e_grid_desc_m_n_
<<
std
::
endl
;
}
// private:
// private:
// pointers
// pointers
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
...
@@ -798,14 +852,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -798,14 +852,12 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// tensor descriptors
// tensor descriptors
AGridDesc_M_K
a_grid_desc_m_k_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
StaticallyIndexedArray
<
EGridDesc_M_N
,
NumDTensor
>
ds_grid_desc_m_n_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
...
@@ -841,11 +893,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -841,11 +893,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
#if 1
#if 1
{
arg
.
Print
();
std
::
cout
<<
"A[M, K]: "
<<
arg
.
a_grid_desc_m_k_
<<
std
::
endl
;
std
::
cout
<<
"B[N, K]: "
<<
arg
.
b_grid_desc_n_k_
<<
std
::
endl
;
std
::
cout
<<
"E[M, N]: "
<<
arg
.
e_grid_desc_m_n_
<<
std
::
endl
;
}
#endif
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
b_grid_desc_n_k_
,
...
@@ -876,9 +924,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -876,9 +924,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
CDEElementwiseOperation
,
CDEElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
StaticallyIndexedArray
<
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
Block2ETileMap
,
Block2ETileMap
,
has_main_loop
>
;
has_main_loop
>
;
...
@@ -921,8 +967,15 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -921,8 +967,15 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
#if 1
arg
.
Print
();
#endif
namespace
ctc
=
tensor_layout
::
convolution
;
namespace
ctc
=
tensor_layout
::
convolution
;
int
itmp
=
0
;
printf
(
"itmp %d
\n
"
,
itmp
++
);
// check device
// check device
if
(
get_device_name
()
==
"gfx908"
)
if
(
get_device_name
()
==
"gfx908"
)
{
{
...
@@ -945,6 +998,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -945,6 +998,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
printf
(
"itmp %d
\n
"
,
itmp
++
);
// check ConvolutionForwardSpecialization
// check ConvolutionForwardSpecialization
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
...
@@ -980,6 +1035,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -980,6 +1035,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
}
}
printf
(
"itmp %d
\n
"
,
itmp
++
);
// check vector access of A
// check vector access of A
if
constexpr
(
is_same_v
<
ALayout
,
ctc
::
NWC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWC
>
||
if
constexpr
(
is_same_v
<
ALayout
,
ctc
::
NWC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWC
>
||
is_same_v
<
ALayout
,
ctc
::
NDHWC
>
)
is_same_v
<
ALayout
,
ctc
::
NDHWC
>
)
...
@@ -996,6 +1053,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -996,6 +1053,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
printf
(
"itmp %d
\n
"
,
itmp
++
);
// check vector access of B
// check vector access of B
if
constexpr
(
is_same_v
<
BLayout
,
ctc
::
KXC
>
||
is_same_v
<
BLayout
,
ctc
::
KYXC
>
||
if
constexpr
(
is_same_v
<
BLayout
,
ctc
::
KXC
>
||
is_same_v
<
BLayout
,
ctc
::
KYXC
>
||
is_same_v
<
BLayout
,
ctc
::
KZYXC
>
)
is_same_v
<
BLayout
,
ctc
::
KZYXC
>
)
...
@@ -1012,7 +1071,37 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1012,7 +1071,37 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
// FIXME: check vector access of Ds
printf
(
"itmp %d
\n
"
,
itmp
++
);
// check vector access of Ds
bool
valid
=
true
;
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
if
constexpr
(
is_same_v
<
DLayout
,
ctc
::
NWK
>
||
is_same_v
<
DLayout
,
ctc
::
NHWK
>
||
is_same_v
<
DLayout
,
ctc
::
NDHWK
>
||
is_same_v
<
DLayout
,
ctc
::
NW_K
>
||
is_same_v
<
DLayout
,
ctc
::
NHW_K
>
||
is_same_v
<
DLayout
,
ctc
::
NDHW_K
>
)
{
const
index_t
K
=
arg
.
ds_n_k_wos_lengths_
[
i
][
1
];
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
valid
=
false
;
}
}
else
{
valid
=
false
;
}
});
if
(
!
valid
)
{
return
false
;
}
printf
(
"itmp %d
\n
"
,
itmp
++
);
// check vector access of E
// check vector access of E
if
constexpr
(
is_same_v
<
ELayout
,
ctc
::
NWK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWK
>
||
if
constexpr
(
is_same_v
<
ELayout
,
ctc
::
NWK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWK
>
||
...
@@ -1030,6 +1119,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1030,6 +1119,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
printf
(
"itmp %d
\n
"
,
itmp
++
);
// check Gridwise GEMM
// check Gridwise GEMM
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
b_grid_desc_n_k_
,
...
...
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
View file @
ae1b4ee6
...
@@ -25,41 +25,45 @@ struct ColumnMajor : public BaseTensorLayout
...
@@ -25,41 +25,45 @@ struct ColumnMajor : public BaseTensorLayout
namespace
convolution
{
namespace
convolution
{
// 1D Conv
// input tensor
// packed NWC/NHWC/NDHWC
struct
NWC
:
public
BaseTensorLayout
struct
NWC
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"NWC"
;
static
constexpr
const
char
*
name
=
"NWC"
;
};
};
struct
KX
C
:
public
BaseTensorLayout
struct
NHW
C
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"
KX
C"
;
static
constexpr
const
char
*
name
=
"
NHW
C"
;
};
};
struct
N
WK
:
public
BaseTensorLayout
struct
N
DHWC
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"N
WK
"
;
static
constexpr
const
char
*
name
=
"N
DHWC
"
;
};
};
// input tensor
// packed NCW/NCHW/NCDHW
struct
NCW
:
public
BaseTensorLayout
struct
NCW
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"NCW"
;
static
constexpr
const
char
*
name
=
"NCW"
;
};
};
struct
KCX
:
public
BaseTensorLayout
struct
NCHW
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"
KCX
"
;
static
constexpr
const
char
*
name
=
"
NCHW
"
;
};
};
struct
N
K
W
:
public
BaseTensorLayout
struct
N
CDH
W
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"N
K
W"
;
static
constexpr
const
char
*
name
=
"N
CDH
W"
;
};
};
// 2D Conv
// weight tensor
struct
NHWC
:
public
BaseTensorLayout
// packed KXC/KYXC/KZYXC
struct
KXC
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"
NHW
C"
;
static
constexpr
const
char
*
name
=
"
KX
C"
;
};
};
struct
KYXC
:
public
BaseTensorLayout
struct
KYXC
:
public
BaseTensorLayout
...
@@ -67,14 +71,16 @@ struct KYXC : public BaseTensorLayout
...
@@ -67,14 +71,16 @@ struct KYXC : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"KYXC"
;
static
constexpr
const
char
*
name
=
"KYXC"
;
};
};
struct
NHWK
:
public
BaseTensorLayout
struct
KZYXC
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"
NHWK
"
;
static
constexpr
const
char
*
name
=
"
KZYXC
"
;
};
};
struct
NCHW
:
public
BaseTensorLayout
// weight tensor
// packed KCX/KCYX/KCZYX
struct
KCX
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"
NCHW
"
;
static
constexpr
const
char
*
name
=
"
KCX
"
;
};
};
struct
KCYX
:
public
BaseTensorLayout
struct
KCYX
:
public
BaseTensorLayout
...
@@ -82,34 +88,38 @@ struct KCYX : public BaseTensorLayout
...
@@ -82,34 +88,38 @@ struct KCYX : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"KCYX"
;
static
constexpr
const
char
*
name
=
"KCYX"
;
};
};
struct
NKHW
:
public
BaseTensorLayout
struct
KCZYX
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"
NKHW
"
;
static
constexpr
const
char
*
name
=
"
KCZYX
"
;
};
};
// 3D Conv
// output tensor
struct
NDHWC
:
public
BaseTensorLayout
// packed NWK/NHWK/NDHWK
struct
NWK
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"N
DHWC
"
;
static
constexpr
const
char
*
name
=
"N
WK
"
;
};
};
struct
KZYXC
:
public
BaseTensorLayout
struct
NHWK
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"
KZYXC
"
;
static
constexpr
const
char
*
name
=
"
NHWK
"
;
};
};
struct
NDHWK
:
public
BaseTensorLayout
struct
NDHWK
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"NDHWK"
;
static
constexpr
const
char
*
name
=
"NDHWK"
;
};
};
struct
NCDHW
:
public
BaseTensorLayout
// output tensor
// packed NKW/NKHW/NKDHW
struct
NKW
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"N
CDH
W"
;
static
constexpr
const
char
*
name
=
"N
K
W"
;
};
};
struct
KCZYX
:
public
BaseTensorLayout
struct
NKHW
:
public
BaseTensorLayout
{
{
static
constexpr
const
char
*
name
=
"
KCZYX
"
;
static
constexpr
const
char
*
name
=
"
NKHW
"
;
};
};
struct
NKDHW
:
public
BaseTensorLayout
struct
NKDHW
:
public
BaseTensorLayout
...
@@ -117,6 +127,23 @@ struct NKDHW : public BaseTensorLayout
...
@@ -117,6 +127,23 @@ struct NKDHW : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"NKDHW"
;
static
constexpr
const
char
*
name
=
"NKDHW"
;
};
};
// output tensor
// strided layout
struct
NW_K
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NW_K"
;
};
struct
NHW_K
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NHW_K"
;
};
struct
NDHW_K
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NDHW_K"
;
};
}
// namespace convolution
}
// namespace convolution
template
<
template
<
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
ae1b4ee6
...
@@ -165,6 +165,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
...
@@ -165,6 +165,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
c_block_size
*
sizeof
(
CShuffleDataType
));
c_block_size
*
sizeof
(
CShuffleDataType
));
}
}
// A desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeDefaultAGridDescriptor_AK0_M_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
MakeDefaultAGridDescriptor_AK0_M_AK1
(
const
AGridDesc_M_K
&
a_grid_desc_m_k
)
{
{
...
@@ -180,6 +181,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
...
@@ -180,6 +181,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
// B desc for source in blockwise copy
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeDefaultBGridDescriptor_BK0_N_BK1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
MakeDefaultBGridDescriptor_BK0_N_BK1
(
const
BGridDesc_N_K
&
b_grid_desc_n_k
)
{
{
...
@@ -195,8 +197,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
...
@@ -195,8 +197,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
__host__
__device__
static
constexpr
auto
// E desc for destination in blockwise copy
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
EGridDesc_M_N
&
e_grid_desc_m_n
)
template
<
typename
EGridDescriptor_M_N
>
__host__
__device__
static
constexpr
auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
EGridDescriptor_M_N
&
e_grid_desc_m_n
)
{
{
const
auto
M
=
e_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
e_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
e_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
N
=
e_grid_desc_m_n
.
GetLength
(
I1
);
...
@@ -214,6 +218,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
...
@@ -214,6 +218,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return
e_grid_desc_mblock_mperblock_nblock_nperblock
;
return
e_grid_desc_mblock_mperblock_nblock_nperblock
;
}
}
// Ds desc for source in blockwise copy
template
<
typename
DsGridDescriptor_M_N
>
__host__
__device__
static
constexpr
auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
DsGridDescriptor_M_N
&
ds_grid_desc_m_n
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
return
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
[
i
]);
},
Number
<
NumDTensor
>
{});
}
// return block_id to E matrix tile idx (m0, n0) mapping
// return block_id to E matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2ETileMap
(
const
EGridDesc_M_N
&
e_grid_desc_m_n
)
MakeDefaultBlock2ETileMap
(
const
EGridDesc_M_N
&
e_grid_desc_m_n
)
...
@@ -301,8 +318,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
...
@@ -301,8 +318,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
remove_cvref_t
<
decltype
(
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
DefaultBGridDesc_BK0_N_BK1
=
using
DefaultBGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
using
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
DefaultBlock2ETileMap
=
using
DefaultBlock2ETileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
...
@@ -313,24 +332,21 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
...
@@ -313,24 +332,21 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
Block2ETileMap
>
typename
Block2ETileMap
>
__device__
static
void
__device__
static
void
Run
(
const
ABDataType
*
__restrict__
p_a_grid
,
Run
(
const
ABDataType
*
__restrict__
p_a_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
DsGridPointer
p_ds_grid
,
DsGridPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
EDataType
*
__restrict__
p_e_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
StaticallyIndexedArray
<
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
NumDTensor
>&
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
// FIXME: Ds desc may be of different
e_grid_desc_mblock_mperblock_nblock_nperblock
,
// type from E
const
Block2ETileMap
&
block_2_etile_map
)
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
&
block_2_etile_map
)
{
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment