Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
19173ab7
"vscode:/vscode.git/clone" did not exist on "5fd25df859ca61d9ff30b021fa5b628c593748f4"
Commit
19173ab7
authored
Jul 24, 2022
by
Chao Liu
Browse files
add G
parent
c3379310
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
521 additions
and
462 deletions
+521
-462
example/09_convnd_fwd/convnd_fwd_common.hpp
example/09_convnd_fwd/convnd_fwd_common.hpp
+84
-111
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
+1
-1
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d.hpp
...ensor_operation/gpu/device/device_conv_fwd_multiple_d.hpp
+8
-8
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
...on/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
+347
-192
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
...ary/reference_tensor_operation/cpu/reference_conv_fwd.hpp
+48
-128
library/include/ck/library/utility/convolution_parameter.hpp
library/include/ck/library/utility/convolution_parameter.hpp
+18
-14
library/include/ck/library/utility/host_tensor.hpp
library/include/ck/library/utility/host_tensor.hpp
+4
-0
library/src/utility/convolution_parameter.cpp
library/src/utility/convolution_parameter.cpp
+11
-8
No files found.
example/09_convnd_fwd/convnd_fwd_common.hpp
View file @
19173ab7
...
@@ -25,7 +25,7 @@ void print_helper_msg()
...
@@ -25,7 +25,7 @@ void print_helper_msg()
<<
"arg3: time kernel (0=no, 1=yes)
\n
"
<<
"arg3: time kernel (0=no, 1=yes)
\n
"
<<
"arg4: N spatial dimensions (default 2)
\n
"
<<
"arg4: N spatial dimensions (default 2)
\n
"
<<
"Following arguments (depending on number of spatial dims):
\n
"
<<
"Following arguments (depending on number of spatial dims):
\n
"
<<
" N, K, C,
\n
"
<<
"
G,
N, K, C,
\n
"
<<
" <filter spatial dimensions>, (ie Y, X for 2D)
\n
"
<<
" <filter spatial dimensions>, (ie Y, X for 2D)
\n
"
<<
" <input image spatial dimensions>, (ie Hi, Wi for 2D)
\n
"
<<
" <input image spatial dimensions>, (ie Hi, Wi for 2D)
\n
"
<<
" <strides>, (ie Sy, Sx for 2D)
\n
"
<<
" <strides>, (ie Sy, Sx for 2D)
\n
"
...
@@ -37,6 +37,7 @@ void print_helper_msg()
...
@@ -37,6 +37,7 @@ void print_helper_msg()
ck
::
utils
::
conv
::
ConvParam
parse_conv_params
(
int
num_dim_spatial
,
int
arg_idx
,
char
*
const
argv
[])
ck
::
utils
::
conv
::
ConvParam
parse_conv_params
(
int
num_dim_spatial
,
int
arg_idx
,
char
*
const
argv
[])
{
{
const
ck
::
index_t
G
=
std
::
stoi
(
argv
[
arg_idx
++
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
arg_idx
++
]);
const
ck
::
index_t
N
=
std
::
stoi
(
argv
[
arg_idx
++
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
arg_idx
++
]);
const
ck
::
index_t
K
=
std
::
stoi
(
argv
[
arg_idx
++
]);
const
ck
::
index_t
C
=
std
::
stoi
(
argv
[
arg_idx
++
]);
const
ck
::
index_t
C
=
std
::
stoi
(
argv
[
arg_idx
++
]);
...
@@ -79,6 +80,7 @@ ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, c
...
@@ -79,6 +80,7 @@ ck::utils::conv::ConvParam parse_conv_params(int num_dim_spatial, int arg_idx, c
}
}
return
ck
::
utils
::
conv
::
ConvParam
{
num_dim_spatial
,
return
ck
::
utils
::
conv
::
ConvParam
{
num_dim_spatial
,
G
,
N
,
N
,
K
,
K
,
C
,
C
,
...
@@ -110,23 +112,56 @@ int run_conv_fwd(bool do_verification,
...
@@ -110,23 +112,56 @@ int run_conv_fwd(bool do_verification,
const
WeiElementOp
&
wei_element_op
,
const
WeiElementOp
&
wei_element_op
,
const
OutElementOp
&
out_element_op
)
const
OutElementOp
&
out_element_op
)
{
{
const
auto
in_desc
=
ck
::
utils
::
conv
::
get_input_host_tensor_descriptor
<
InLayout
>
(
conv_param
);
#if 0
const
auto
wei_desc
=
ck
::
utils
::
conv
::
get_weight_host_tensor_descriptor
<
WeiLayout
>
(
conv_param
);
const auto in_g_n_c_wis_desc = ck::utils::conv::get_input_host_tensor_descriptor<InLayout>(conv_param);
const
auto
out_desc
=
ck
::
utils
::
conv
::
get_output_host_tensor_descriptor
<
OutLayout
>
(
conv_param
);
const auto wei_g_k_c_xs_desc = ck::utils::conv::get_weight_host_tensor_descriptor<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc = ck::utils::conv::get_output_host_tensor_descriptor<OutLayout>(conv_param);
// hacky, hardcoded for 2d NHWK
#else
const
auto
bias_desc
=
HostTensorDescriptor
(
const
auto
in_g_n_wis_c_desc
=
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
conv_param
.
N_
),
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
conv_param
.
G_
),
static_cast
<
std
::
size_t
>
(
conv_param
.
N_
),
static_cast
<
std
::
size_t
>
(
conv_param
.
input_spatial_lengths_
[
0
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
input_spatial_lengths_
[
1
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
C_
)});
const
auto
wei_g_k_xs_c_desc
=
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
conv_param
.
G_
),
static_cast
<
std
::
size_t
>
(
conv_param
.
K_
),
static_cast
<
std
::
size_t
>
(
conv_param
.
filter_spatial_lengths_
[
0
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
filter_spatial_lengths_
[
1
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
C_
)});
const
auto
bias_g_n_wos_k_desc
=
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
conv_param
.
G_
),
static_cast
<
std
::
size_t
>
(
conv_param
.
N_
),
static_cast
<
std
::
size_t
>
(
conv_param
.
output_spatial_lengths_
[
0
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
output_spatial_lengths_
[
0
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
output_spatial_lengths_
[
1
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
output_spatial_lengths_
[
1
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
K_
)},
static_cast
<
std
::
size_t
>
(
conv_param
.
K_
)},
std
::
vector
<
std
::
size_t
>
{
0
,
0
,
0
,
1
});
std
::
vector
<
std
::
size_t
>
{
0
,
0
,
0
,
0
,
1
});
const
auto
out_g_n_wos_k_desc
=
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
conv_param
.
G_
),
static_cast
<
std
::
size_t
>
(
conv_param
.
N_
),
static_cast
<
std
::
size_t
>
(
conv_param
.
output_spatial_lengths_
[
0
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
output_spatial_lengths_
[
1
]),
static_cast
<
std
::
size_t
>
(
conv_param
.
K_
)});
Tensor
<
InDataType
>
in
(
in_desc
);
// tensor descriptor in NCHW/KXYC/NKHW dimensional order
Tensor
<
WeiDataType
>
wei
(
wei_desc
);
const
auto
in_g_n_c_wis_desc
=
transpose_host_tensor_descriptor_given_new2old
(
Tensor
<
OutDataType
>
bias
(
bias_desc
);
in_g_n_wis_c_desc
,
std
::
vector
<
ck
::
index_t
>
{
0
,
1
,
4
,
2
,
3
});
Tensor
<
OutDataType
>
out_host
(
out_desc
);
const
auto
wei_g_k_c_xs_desc
=
transpose_host_tensor_descriptor_given_new2old
(
Tensor
<
OutDataType
>
out_device
(
out_desc
);
wei_g_k_xs_c_desc
,
std
::
vector
<
ck
::
index_t
>
{
0
,
1
,
4
,
2
,
3
});
const
auto
bias_g_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
bias_g_n_wos_k_desc
,
std
::
vector
<
ck
::
index_t
>
{
0
,
1
,
4
,
2
,
3
});
const
auto
out_g_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_g_n_wos_k_desc
,
std
::
vector
<
ck
::
index_t
>
{
0
,
1
,
4
,
2
,
3
});
#endif
Tensor
<
InDataType
>
in
(
in_g_n_c_wis_desc
);
Tensor
<
WeiDataType
>
wei
(
wei_g_k_c_xs_desc
);
Tensor
<
OutDataType
>
bias
(
bias_g_n_k_wos_desc
);
Tensor
<
OutDataType
>
out_host
(
out_g_n_k_wos_desc
);
Tensor
<
OutDataType
>
out_device
(
out_g_n_k_wos_desc
);
std
::
cout
<<
"in: "
<<
in
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"in: "
<<
in
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei: "
<<
wei
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"wei: "
<<
wei
.
mDesc
<<
std
::
endl
;
...
@@ -156,80 +191,14 @@ int run_conv_fwd(bool do_verification,
...
@@ -156,80 +191,14 @@ int run_conv_fwd(bool do_verification,
wei_device_buf
.
ToDevice
(
wei
.
mData
.
data
());
wei_device_buf
.
ToDevice
(
wei
.
mData
.
data
());
bias_device_buf
.
ToDevice
(
bias
.
mData
.
data
());
bias_device_buf
.
ToDevice
(
bias
.
mData
.
data
());
// tensor descriptor in NCHW/KXYC/NKHW dimensional order
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths
{};
HostTensorDescriptor
in_n_c_wis_desc
=
in_desc
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides
{};
HostTensorDescriptor
wei_k_c_xs_desc
=
wei_desc
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths
{};
HostTensorDescriptor
bias_n_k_wos_desc
=
bias_desc
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides
{};
HostTensorDescriptor
out_n_k_wos_desc
=
out_desc
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
d_g_n_k_wos_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
d_g_n_k_wos_strides
{};
// input
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths
{};
if
constexpr
(
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NWC
>
)
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides
{};
{
in_n_c_wis_desc
=
transpose_host_tensor_descriptor_given_new2old
(
in_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
}
else
if
constexpr
(
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NHWC
>
)
{
in_n_c_wis_desc
=
transpose_host_tensor_descriptor_given_new2old
(
in_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
}
else
if
constexpr
(
ck
::
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NDHWC
>
)
{
in_n_c_wis_desc
=
transpose_host_tensor_descriptor_given_new2old
(
in_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
}
// weight
if
constexpr
(
ck
::
is_same_v
<
WeiLayout
,
ck
::
tensor_layout
::
convolution
::
KXC
>
)
{
wei_k_c_xs_desc
=
transpose_host_tensor_descriptor_given_new2old
(
wei_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
}
else
if
constexpr
(
ck
::
is_same_v
<
WeiLayout
,
ck
::
tensor_layout
::
convolution
::
KYXC
>
)
{
wei_k_c_xs_desc
=
transpose_host_tensor_descriptor_given_new2old
(
wei_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
}
else
if
constexpr
(
ck
::
is_same_v
<
WeiLayout
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>
)
{
wei_k_c_xs_desc
=
transpose_host_tensor_descriptor_given_new2old
(
wei_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
}
// output
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NWK
>
)
{
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
bias_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
bias_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
}
else
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
)
{
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
bias_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
bias_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
}
else
if
constexpr
(
ck
::
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
)
{
out_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
out_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
bias_n_k_wos_desc
=
transpose_host_tensor_descriptor_given_new2old
(
bias_desc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
}
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
d_n_k_wos_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
d_n_k_wos_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_lengths
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_strides
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
conv_filter_dilations
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
{};
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_left_pads
{};
...
@@ -237,14 +206,14 @@ int run_conv_fwd(bool do_verification,
...
@@ -237,14 +206,14 @@ int run_conv_fwd(bool do_verification,
auto
copy
=
[](
auto
&
x
,
auto
&
y
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
y
.
begin
());
};
auto
copy
=
[](
auto
&
x
,
auto
&
y
)
{
std
::
copy
(
x
.
begin
(),
x
.
end
(),
y
.
begin
());
};
copy
(
in_n_c_wis_desc
.
GetLengths
(),
a_n_c_wis_lengths
);
copy
(
in_
g_
n_c_wis_desc
.
GetLengths
(),
a_
g_
n_c_wis_lengths
);
copy
(
in_n_c_wis_desc
.
GetStrides
(),
a_n_c_wis_strides
);
copy
(
in_
g_
n_c_wis_desc
.
GetStrides
(),
a_
g_
n_c_wis_strides
);
copy
(
wei_k_c_xs_desc
.
GetLengths
(),
b_k_c_xs_lengths
);
copy
(
wei_
g_
k_c_xs_desc
.
GetLengths
(),
b_
g_
k_c_xs_lengths
);
copy
(
wei_k_c_xs_desc
.
GetStrides
(),
b_k_c_xs_strides
);
copy
(
wei_
g_
k_c_xs_desc
.
GetStrides
(),
b_
g_
k_c_xs_strides
);
copy
(
bias_n_k_wos_desc
.
GetLengths
(),
d_n_k_wos_lengths
);
copy
(
bias_
g_
n_k_wos_desc
.
GetLengths
(),
d_
g_
n_k_wos_lengths
);
copy
(
bias_n_k_wos_desc
.
GetStrides
(),
d_n_k_wos_strides
);
copy
(
bias_
g_
n_k_wos_desc
.
GetStrides
(),
d_
g_
n_k_wos_strides
);
copy
(
out_n_k_wos_desc
.
GetLengths
(),
e_n_k_wos_lengths
);
copy
(
out_
g_
n_k_wos_desc
.
GetLengths
(),
e_
g_
n_k_wos_lengths
);
copy
(
out_n_k_wos_desc
.
GetStrides
(),
e_n_k_wos_strides
);
copy
(
out_
g_
n_k_wos_desc
.
GetStrides
(),
e_
g_
n_k_wos_strides
);
copy
(
conv_param
.
conv_filter_strides_
,
conv_filter_strides
);
copy
(
conv_param
.
conv_filter_strides_
,
conv_filter_strides
);
copy
(
conv_param
.
conv_filter_dilations_
,
conv_filter_dilations
);
copy
(
conv_param
.
conv_filter_dilations_
,
conv_filter_dilations
);
copy
(
conv_param
.
input_left_pads_
,
input_left_pads
);
copy
(
conv_param
.
input_left_pads_
,
input_left_pads
);
...
@@ -258,14 +227,14 @@ int run_conv_fwd(bool do_verification,
...
@@ -258,14 +227,14 @@ int run_conv_fwd(bool do_verification,
wei_device_buf
.
GetDeviceBuffer
(),
wei_device_buf
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
bias_device_buf
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
1
>
{
bias_device_buf
.
GetDeviceBuffer
()},
out_device_buf
.
GetDeviceBuffer
(),
out_device_buf
.
GetDeviceBuffer
(),
a_n_c_wis_lengths
,
a_
g_
n_c_wis_lengths
,
a_n_c_wis_strides
,
a_
g_
n_c_wis_strides
,
b_k_c_xs_lengths
,
b_
g_
k_c_xs_lengths
,
b_k_c_xs_strides
,
b_
g_
k_c_xs_strides
,
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
,
1
>
{{
d_n_k_wos_lengths
}},
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
,
1
>
{{
d_
g_
n_k_wos_lengths
}},
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
2
>
,
1
>
{{
d_n_k_wos_strides
}},
std
::
array
<
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
,
1
>
{{
d_
g_
n_k_wos_strides
}},
e_n_k_wos_lengths
,
e_
g_
n_k_wos_lengths
,
e_n_k_wos_strides
,
e_
g_
n_k_wos_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -295,7 +264,7 @@ int run_conv_fwd(bool do_verification,
...
@@ -295,7 +264,7 @@ int run_conv_fwd(bool do_verification,
{
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
Tensor
<
OutDataType
>
c_host
(
out_desc
);
Tensor
<
OutDataType
>
c_host
(
out_
g_n_k_wos_
desc
);
auto
ref_conv
=
ck
::
tensor_operation
::
host
::
ReferenceConvFwd
<
NDimSpatial
,
auto
ref_conv
=
ck
::
tensor_operation
::
host
::
ReferenceConvFwd
<
NDimSpatial
,
InLayout
,
InLayout
,
...
@@ -322,16 +291,20 @@ int run_conv_fwd(bool do_verification,
...
@@ -322,16 +291,20 @@ int run_conv_fwd(bool do_verification,
ref_invoker
.
Run
(
ref_argument
);
ref_invoker
.
Run
(
ref_argument
);
for
(
int
n
=
0
;
n
<
out_host
.
mDesc
.
GetLengths
()[
0
];
n
++
)
for
(
int
g
=
0
;
g
<
out_host
.
mDesc
.
GetLengths
()[
0
];
g
++
)
{
{
for
(
int
ho
=
0
;
ho
<
out_host
.
mDesc
.
GetLengths
()[
1
];
ho
++
)
for
(
int
n
=
0
;
n
<
out_host
.
mDesc
.
GetLengths
()[
1
];
n
++
)
{
{
for
(
int
wo
=
0
;
wo
<
out_host
.
mDesc
.
GetLengths
()[
2
];
wo
++
)
for
(
int
k
=
0
;
k
<
out_host
.
mDesc
.
GetLengths
()[
2
];
k
++
)
{
{
for
(
int
k
=
0
;
k
<
out_host
.
mDesc
.
GetLengths
()[
3
];
k
++
)
for
(
int
ho
=
0
;
ho
<
out_host
.
mDesc
.
GetLengths
()[
3
];
ho
++
)
{
{
out_element_op
(
for
(
int
wo
=
0
;
wo
<
out_host
.
mDesc
.
GetLengths
()[
4
];
wo
++
)
out_host
(
n
,
ho
,
wo
,
k
),
c_host
(
n
,
ho
,
wo
,
k
),
bias
(
n
,
ho
,
wo
,
k
));
{
out_element_op
(
out_host
(
g
,
n
,
k
,
ho
,
wo
),
c_host
(
g
,
n
,
k
,
ho
,
wo
),
bias
(
g
,
n
,
k
,
ho
,
wo
));
}
}
}
}
}
}
}
...
...
example/09_convnd_fwd/convnd_fwd_xdl_fp16.cpp
View file @
19173ab7
...
@@ -138,7 +138,7 @@ int main(int argc, char* argv[])
...
@@ -138,7 +138,7 @@ int main(int argc, char* argv[])
int
num_dim_spatial
=
2
;
int
num_dim_spatial
=
2
;
ck
::
utils
::
conv
::
ConvParam
params
{
ck
::
utils
::
conv
::
ConvParam
params
{
2
,
128
,
256
,
192
,
{
3
,
3
},
{
71
,
71
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}};
2
,
1
,
128
,
256
,
192
,
{
3
,
3
},
{
71
,
71
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}};
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
...
...
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d.hpp
View file @
19173ab7
...
@@ -39,14 +39,14 @@ struct DeviceConvFwdMultipleD : public BaseOperator
...
@@ -39,14 +39,14 @@ struct DeviceConvFwdMultipleD : public BaseOperator
const
void
*
p_b
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_
g_
n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_
g_
n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
...
...
include/ck/tensor_operation/gpu/device/device_conv_fwd_multiple_d_xdl_cshuffle.hpp
View file @
19173ab7
...
@@ -28,39 +28,152 @@ namespace device {
...
@@ -28,39 +28,152 @@ namespace device {
namespace
{
namespace
{
template
<
index_t
NumDTensor
>
struct
ComputePtrOffsetOfStridedBatch
{
ComputePtrOffsetOfStridedBatch
()
=
default
;
ComputePtrOffsetOfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs
,
index_t
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
{
}
__host__
__device__
constexpr
long_index_t
GetAPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
ds_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideDs_
[
i
]);
});
return
ds_offset
;
}
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
};
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
ABDataType
,
typename
Float
DsPointer
,
typename
DsPointer
,
typename
FloatE
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDesc
riptor
_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
typename
Block2ETileMap
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
>
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_multiple_d_xdl_cshuffle
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_batch_gemm_multiple_d_xdl_cshuffle
(
const
FloatAB
*
__restrict__
p_b_grid
,
const
ABDataType
*
__restrict__
p_a_grid
,
FloatDsPointer
p_ds_grid
,
const
ABDataType
*
__restrict__
p_b_grid
,
FloatE
*
__restrict__
p_e_grid
,
DsPointer
p_ds_grid
,
const
AElementwiseOperation
a_element_op
,
EDataType
*
__restrict__
p_e_grid
,
const
BElementwiseOperation
b_element_op
,
const
AElementwiseOperation
a_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
BElementwiseOperation
b_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
CDEElementwiseOperation
cde_element_op
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
index_t
batch_count
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
AGridDesc_AK0_M_AK1
a_grid_desc_k0_m_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_k0_n_k1
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
const
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
const
Block2ETileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if 1
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
const
long_index_t
e_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
)));
const
auto
ds_batch_offset
=
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
DsPointer
p_ds_grid_grp
;
static
constexpr
index_t
NumDTensor
=
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
::
Size
();
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_batch_offset
[
i
];
});
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
p_b_grid
+
b_batch_offset
,
p_ds_grid_grp
,
p_e_grid
+
e_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
);
#else
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
@@ -71,26 +184,31 @@ __global__ void
...
@@ -71,26 +184,31 @@ __global__ void
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
cde_element_op
,
cde_element_op
,
a_grid_desc_
a
k0_m_
a
k1
,
a_grid_desc_k0_m_k1
,
b_grid_desc_
b
k0_n_
b
k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_etile_map
);
block_2_ctile_map
);
#endif
#else
#else
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
ignore
=
p_b_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
p_e_grid
;
ignore
=
batch_count
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
ignore
=
a_element_op
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
cde_element_op
;
ignore
=
a_grid_desc_ak0_m_ak1
;
ignore
=
compute_ptr_offset_of_batch
;
ignore
=
b_grid_desc_bk0_n_bk1
;
ignore
=
block_2_ctile_map
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
block_2_etile_map
;
#endif
#endif
}
}
}
// namespace
}
// namespace
//
//
...
@@ -187,33 +305,33 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -187,33 +305,33 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NWC
>,
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NWC
>,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
static
auto
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
{
const
index_t
N
=
a_n_c_wis_lengths
[
0
];
const
index_t
N
=
a_
g_
n_c_wis_lengths
[
1
];
const
index_t
C
=
a_n_c_wis_lengths
[
1
];
const
index_t
C
=
a_
g_
n_c_wis_lengths
[
2
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_n_k_wos_lengths
.
begin
()
+
2
,
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_
g_
n_k_wos_lengths
.
begin
()
+
3
,
e_n_k_wos_lengths
.
begin
()
+
3
,
e_
g_
n_k_wos_lengths
.
begin
()
+
4
,
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
std
::
multiplies
<
index_t
>
());
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_k_c_xs_lengths
.
begin
()
+
2
,
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_
g_
k_c_xs_lengths
.
begin
()
+
3
,
b_k_c_xs_lengths
.
begin
()
+
3
,
b_
g_
k_c_xs_lengths
.
begin
()
+
4
,
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
std
::
multiplies
<
index_t
>
());
const
index_t
Wi
=
a_n_c_wis_lengths
[
2
];
const
index_t
Wi
=
a_
g_
n_c_wis_lengths
[
3
];
const
index_t
Wo
=
e_n_k_wos_lengths
[
2
];
const
index_t
Wo
=
e_
g_
n_k_wos_lengths
[
3
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
...
@@ -255,7 +373,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -255,7 +373,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
else
else
{
{
const
index_t
X
=
b_k_c_xs_lengths
[
2
];
const
index_t
X
=
b_
g_
k_c_xs_lengths
[
3
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
0
];
...
@@ -299,35 +417,35 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -299,35 +417,35 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NHWC
>,
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NHWC
>,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
static
auto
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
{
const
index_t
N
=
a_n_c_wis_lengths
[
0
];
const
index_t
N
=
a_
g_
n_c_wis_lengths
[
1
];
const
index_t
C
=
a_n_c_wis_lengths
[
1
];
const
index_t
C
=
a_
g_
n_c_wis_lengths
[
2
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_n_k_wos_lengths
.
begin
()
+
2
,
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_
g_
n_k_wos_lengths
.
begin
()
+
3
,
e_n_k_wos_lengths
.
begin
()
+
4
,
e_
g_
n_k_wos_lengths
.
begin
()
+
5
,
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
std
::
multiplies
<
index_t
>
());
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_k_c_xs_lengths
.
begin
()
+
2
,
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_
g_
k_c_xs_lengths
.
begin
()
+
3
,
b_k_c_xs_lengths
.
begin
()
+
4
,
b_
g_
k_c_xs_lengths
.
begin
()
+
5
,
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
std
::
multiplies
<
index_t
>
());
const
index_t
Hi
=
a_n_c_wis_lengths
[
2
];
const
index_t
Hi
=
a_
g_
n_c_wis_lengths
[
3
];
const
index_t
Wi
=
a_n_c_wis_lengths
[
3
];
const
index_t
Wi
=
a_
g_
n_c_wis_lengths
[
4
];
const
index_t
Ho
=
e_n_k_wos_lengths
[
2
];
const
index_t
Ho
=
e_
g_
n_k_wos_lengths
[
3
];
const
index_t
Wo
=
e_n_k_wos_lengths
[
3
];
const
index_t
Wo
=
e_
g_
n_k_wos_lengths
[
4
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
...
@@ -372,8 +490,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -372,8 +490,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
else
else
{
{
const
index_t
Y
=
b_k_c_xs_lengths
[
2
];
const
index_t
Y
=
b_
g_
k_c_xs_lengths
[
3
];
const
index_t
X
=
b_k_c_xs_lengths
[
3
];
const
index_t
X
=
b_
g_
k_c_xs_lengths
[
4
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
...
@@ -425,37 +543,37 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -425,37 +543,37 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NDHWC
>,
is_same_v
<
ALay
,
tensor_layout
::
convolution
::
NDHWC
>,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
static
auto
static
auto
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
MakeAGridDescriptor_M_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
{
const
index_t
N
=
a_n_c_wis_lengths
[
0
];
const
index_t
N
=
a_
g_
n_c_wis_lengths
[
1
];
const
index_t
C
=
a_n_c_wis_lengths
[
1
];
const
index_t
C
=
a_
g_
n_c_wis_lengths
[
2
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_n_k_wos_lengths
.
begin
()
+
2
,
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_
g_
n_k_wos_lengths
.
begin
()
+
3
,
e_n_k_wos_lengths
.
begin
()
+
5
,
e_
g_
n_k_wos_lengths
.
begin
()
+
6
,
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
std
::
multiplies
<
index_t
>
());
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_k_c_xs_lengths
.
begin
()
+
2
,
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_
g_
k_c_xs_lengths
.
begin
()
+
3
,
b_k_c_xs_lengths
.
begin
()
+
5
,
b_
g_
k_c_xs_lengths
.
begin
()
+
6
,
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
std
::
multiplies
<
index_t
>
());
const
index_t
Di
=
a_n_c_wis_lengths
[
2
];
const
index_t
Di
=
a_
g_
n_c_wis_lengths
[
3
];
const
index_t
Hi
=
a_n_c_wis_lengths
[
3
];
const
index_t
Hi
=
a_
g_
n_c_wis_lengths
[
4
];
const
index_t
Wi
=
a_n_c_wis_lengths
[
4
];
const
index_t
Wi
=
a_
g_
n_c_wis_lengths
[
5
];
const
index_t
Do
=
e_n_k_wos_lengths
[
2
];
const
index_t
Do
=
e_
g_
n_k_wos_lengths
[
3
];
const
index_t
Ho
=
e_n_k_wos_lengths
[
3
];
const
index_t
Ho
=
e_
g_
n_k_wos_lengths
[
4
];
const
index_t
Wo
=
e_n_k_wos_lengths
[
4
];
const
index_t
Wo
=
e_
g_
n_k_wos_lengths
[
5
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
...
@@ -504,9 +622,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -504,9 +622,9 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
else
else
{
{
const
index_t
Z
=
b_k_c_xs_lengths
[
2
];
const
index_t
Z
=
b_
g_
k_c_xs_lengths
[
3
];
const
index_t
Y
=
b_k_c_xs_lengths
[
3
];
const
index_t
Y
=
b_
g_
k_c_xs_lengths
[
4
];
const
index_t
X
=
b_k_c_xs_lengths
[
4
];
const
index_t
X
=
b_
g_
k_c_xs_lengths
[
5
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
...
@@ -571,16 +689,16 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -571,16 +689,16 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KZYXC
>
,
is_same_v
<
BLay
,
tensor_layout
::
convolution
::
KZYXC
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
static
auto
static
auto
MakeBGridDescriptor_N_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
MakeBGridDescriptor_N_K
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
)
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_strides
)
{
{
const
index_t
K
=
b_k_c_xs_lengths
[
0
];
const
index_t
K
=
b_
g_
k_c_xs_lengths
[
1
];
const
index_t
C
=
b_k_c_xs_lengths
[
1
];
const
index_t
C
=
b_
g_
k_c_xs_lengths
[
2
];
const
index_t
GemmNRaw
=
K
;
const
index_t
GemmNRaw
=
K
;
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_k_c_xs_lengths
.
begin
()
+
2
,
const
index_t
GemmKRaw
=
C
*
std
::
accumulate
(
b_
g_
k_c_xs_lengths
.
begin
()
+
3
,
b_k_c_xs_lengths
.
begin
()
+
2
+
NDimSpatial
,
b_
g_
k_c_xs_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
std
::
multiplies
<
index_t
>
());
...
@@ -599,14 +717,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -599,14 +717,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NDHWK
>
,
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NDHWK
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
static
auto
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
)
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_strides
)
{
{
const
index_t
N
=
e_n_k_wos_lengths
[
0
];
const
index_t
N
=
e_
g_
n_k_wos_lengths
[
1
];
const
index_t
K
=
e_n_k_wos_lengths
[
1
];
const
index_t
K
=
e_
g_
n_k_wos_lengths
[
2
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_n_k_wos_lengths
.
begin
()
+
2
,
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_
g_
n_k_wos_lengths
.
begin
()
+
3
,
e_n_k_wos_lengths
.
begin
()
+
2
+
NDimSpatial
,
e_
g_
n_k_wos_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
std
::
multiplies
<
index_t
>
());
...
@@ -627,18 +745,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -627,18 +745,18 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NDHW_K
>
,
is_same_v
<
ELay
,
tensor_layout
::
convolution
::
NDHW_K
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
static
auto
static
auto
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
MakeEGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
)
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_strides
)
{
{
namespace
ctc
=
ck
::
tensor_layout
::
convolution
;
namespace
ctc
=
ck
::
tensor_layout
::
convolution
;
const
index_t
N
=
e_n_k_wos_lengths
[
0
];
const
index_t
N
=
e_
g_
n_k_wos_lengths
[
1
];
const
index_t
K
=
e_n_k_wos_lengths
[
1
];
const
index_t
K
=
e_
g_
n_k_wos_lengths
[
2
];
const
index_t
WoStride
=
e_n_k_wos_strides
[
NDimSpatial
+
1
];
const
index_t
WoStride
=
e_
g_
n_k_wos_strides
[
NDimSpatial
+
2
];
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_n_k_wos_lengths
.
begin
()
+
2
,
const
index_t
GemmMRaw
=
N
*
std
::
accumulate
(
e_
g_
n_k_wos_lengths
.
begin
()
+
3
,
e_n_k_wos_lengths
.
begin
()
+
2
+
NDimSpatial
,
e_
g_
n_k_wos_lengths
.
begin
()
+
3
+
NDimSpatial
,
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<
index_t
>
());
std
::
multiplies
<
index_t
>
());
...
@@ -654,15 +772,15 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -654,15 +772,15 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
static
auto
MakeDsGridDescriptor_M_N
(
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_
g_
n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_strides
)
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_
g_
n_k_wos_strides
)
{
{
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
ds_n_k_wos_lengths
[
i
],
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
ds_
g_
n_k_wos_lengths
[
i
],
ds_n_k_wos_strides
[
i
]);
ds_
g_
n_k_wos_strides
[
i
]);
},
},
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
}
}
...
@@ -731,26 +849,27 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -731,26 +849,27 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
Argument
(
Argument
(
const
void
*
p_a
,
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_lengths
,
ds_g_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
ds_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
BElementwiseOperation
&
b_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a
)},
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b
)},
p_ds_grid_
{},
// FIXME
p_ds_grid_
{},
// FIXME
...
@@ -764,56 +883,72 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -764,56 +883,72 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{},
block_2_etile_map_
{},
compute_ptr_offset_of_batch_
{},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
cde_element_op_
{
cde_element_op
},
a_n_c_wis_lengths_
{
a_n_c_wis_lengths
},
a_
g_
n_c_wis_lengths_
{
a_
g_
n_c_wis_lengths
},
a_n_c_wis_strides_
{
a_n_c_wis_strides
},
a_
g_
n_c_wis_strides_
{
a_
g_
n_c_wis_strides
},
b_k_c_xs_lengths_
{
b_k_c_xs_lengths
},
b_
g_
k_c_xs_lengths_
{
b_
g_
k_c_xs_lengths
},
b_k_c_xs_strides_
{
b_k_c_xs_strides
},
b_
g_
k_c_xs_strides_
{
b_
g_
k_c_xs_strides
},
ds_n_k_wos_lengths_
{
ds_n_k_wos_lengths
},
ds_
g_
n_k_wos_lengths_
{
ds_
g_
n_k_wos_lengths
},
ds_n_k_wos_strides_
{
ds_n_k_wos_strides
},
ds_
g_
n_k_wos_strides_
{
ds_
g_
n_k_wos_strides
},
e_n_k_wos_lengths_
{
e_n_k_wos_lengths
},
e_
g_
n_k_wos_lengths_
{
e_
g_
n_k_wos_lengths
},
e_n_k_wos_strides_
{
e_n_k_wos_strides
},
e_
g_
n_k_wos_strides_
{
e_
g_
n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
input_right_pads_
{
input_right_pads
}
{
{
a_grid_desc_m_k_
=
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
a_n_c_wis_lengths
,
// A desc
a_n_c_wis_strides
,
a_grid_desc_m_k_
=
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
a_g_n_c_wis_lengths
,
b_k_c_xs_lengths
,
a_g_n_c_wis_strides
,
b_k_c_xs_strides
,
b_g_k_c_xs_lengths
,
e_n_k_wos_lengths
,
b_g_k_c_xs_strides
,
e_n_k_wos_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
input_right_pads
);
input_right_pads
);
// B Desc
b_grid_desc_n_k_
=
b_grid_desc_n_k_
=
DeviceOp
::
MakeBGridDescriptor_N_K
<
BLayout
>
(
b_k_c_xs_lengths
,
b_k_c_xs_strides
);
DeviceOp
::
MakeBGridDescriptor_N_K
<
BLayout
>
(
b_
g_
k_c_xs_lengths
,
b_
g_
k_c_xs_strides
);
e_grid_desc_m_n_
=
// E Desc
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_n_k_wos_lengths
,
e_n_k_wos_strides
);
e_grid_desc_m_n_
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
// A Des
a_grid_desc_ak0_m_ak1_
=
a_grid_desc_ak0_m_ak1_
=
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
);
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
);
// B Desc
b_grid_desc_bk0_n_bk1_
=
b_grid_desc_bk0_n_bk1_
=
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
);
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
);
// Block-to-e-tile
block_2_etile_map_
=
Block2ETileMap
{
e_grid_desc_m_n_
};
block_2_etile_map_
=
Block2ETileMap
{
e_grid_desc_m_n_
};
// populate pointer and desc for Ds
// A/B/E Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
];
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds
[
i
]);
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds
[
i
]);
// D batch stride
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
];
// D desc
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
ds_n_k_wos_lengths
[
i
],
ds_n_k_wos_strides
[
i
]);
ds_
g_
n_k_wos_lengths
[
i
],
ds_
g_
n_k_wos_strides
[
i
]);
});
});
// populate desc for Ds/E
// populate desc for Ds/E
...
@@ -865,20 +1000,22 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -865,20 +1000,22 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// block-to-e-tile map
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
Block2ETileMap
block_2_etile_map_
;
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
// element-wise op
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking IsSupportedArgument()
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_
g_
n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
2
>
a_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_
g_
n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_
g_
k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
2
>
b_k_c_xs_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_
g_
k_c_xs_strides_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>
ds_n_k_wos_lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_
g_
n_k_wos_lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>
ds_n_k_wos_strides_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_
g_
n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_
g_
n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
2
>
e_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_
g_
n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
...
@@ -906,7 +1043,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -906,7 +1043,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
const
index_t
grid_size
=
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
)
*
arg
.
a_g_n_c_wis_lengths_
[
0
];
const
auto
K
=
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
...
@@ -914,7 +1052,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -914,7 +1052,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_gemm_multiple_d_xdl_cshuffle
<
const
auto
kernel
=
kernel_
batch_
gemm_multiple_d_xdl_cshuffle
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
ADataType
,
// TODO: distiguish A/B datatype
typename
GridwiseGemm
::
DsGridPointer
,
typename
GridwiseGemm
::
DsGridPointer
,
...
@@ -927,6 +1065,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -927,6 +1065,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
Block2ETileMap
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
has_main_loop
>
;
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
...
@@ -941,11 +1080,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -941,11 +1080,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_g_n_c_wis_lengths_
[
0
],
// Group count
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_etile_map_
);
arg
.
block_2_etile_map_
,
arg
.
compute_ptr_offset_of_batch_
);
};
};
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
...
@@ -991,6 +1132,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -991,6 +1132,10 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
int
itmp
=
0
;
printf
(
"%d
\n
"
,
itmp
++
);
// check ConvolutionForwardSpecialization
// check ConvolutionForwardSpecialization
if
constexpr
(
ConvForwardSpecialization
==
if
constexpr
(
ConvForwardSpecialization
==
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
...
@@ -998,7 +1143,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -998,7 +1143,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// check if it's 1x1, stride=1 conv
// check if it's 1x1, stride=1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
{
const
index_t
X
=
arg
.
b_k_c_xs_lengths_
[
i
+
2
];
const
index_t
X
=
arg
.
b_
g_
k_c_xs_lengths_
[
i
+
2
];
const
index_t
ConvStride
=
arg
.
conv_filter_strides_
[
i
];
const
index_t
ConvStride
=
arg
.
conv_filter_strides_
[
i
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
...
@@ -1015,7 +1160,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1015,7 +1160,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
// check if it's 1x1 conv
// check if it's 1x1 conv
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
for
(
index_t
i
=
0
;
i
<
NDimSpatial
;
++
i
)
{
{
const
index_t
X
=
arg
.
b_k_c_xs_lengths_
[
i
+
2
];
const
index_t
X
=
arg
.
b_
g_
k_c_xs_lengths_
[
i
+
2
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
LeftPad
=
arg
.
input_left_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
const
index_t
RightPad
=
arg
.
input_right_pads_
[
i
];
...
@@ -1026,11 +1171,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1026,11 +1171,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
}
}
}
}
printf
(
"%d
\n
"
,
itmp
++
);
// check vector access of A
// check vector access of A
if
constexpr
(
is_same_v
<
ALayout
,
ctc
::
NWC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWC
>
||
if
constexpr
(
is_same_v
<
ALayout
,
ctc
::
NWC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWC
>
||
is_same_v
<
ALayout
,
ctc
::
NDHWC
>
)
is_same_v
<
ALayout
,
ctc
::
NDHWC
>
)
{
{
const
index_t
C
=
arg
.
a_n_c_wis_lengths_
[
1
];
const
index_t
C
=
arg
.
a_
g_
n_c_wis_lengths_
[
2
];
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
{
{
...
@@ -1042,11 +1189,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1042,11 +1189,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
printf
(
"%d
\n
"
,
itmp
++
);
// check vector access of B
// check vector access of B
if
constexpr
(
is_same_v
<
BLayout
,
ctc
::
KXC
>
||
is_same_v
<
BLayout
,
ctc
::
KYXC
>
||
if
constexpr
(
is_same_v
<
BLayout
,
ctc
::
KXC
>
||
is_same_v
<
BLayout
,
ctc
::
KYXC
>
||
is_same_v
<
BLayout
,
ctc
::
KZYXC
>
)
is_same_v
<
BLayout
,
ctc
::
KZYXC
>
)
{
{
const
index_t
C
=
arg
.
b_k_c_xs_lengths_
[
1
];
const
index_t
C
=
arg
.
b_
g_
k_c_xs_lengths_
[
2
];
if
(
!
(
BBlockTransferSrcVectorDim
==
2
&&
C
%
BBlockTransferSrcScalarPerVector
==
0
))
if
(
!
(
BBlockTransferSrcVectorDim
==
2
&&
C
%
BBlockTransferSrcScalarPerVector
==
0
))
{
{
...
@@ -1058,6 +1207,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1058,6 +1207,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
printf
(
"%d
\n
"
,
itmp
++
);
// check vector access of Ds
// check vector access of Ds
bool
valid
=
true
;
bool
valid
=
true
;
...
@@ -1068,7 +1219,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1068,7 +1219,7 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
is_same_v
<
DLayout
,
ctc
::
NDHWK
>
||
is_same_v
<
DLayout
,
ctc
::
NW_K
>
||
is_same_v
<
DLayout
,
ctc
::
NDHWK
>
||
is_same_v
<
DLayout
,
ctc
::
NW_K
>
||
is_same_v
<
DLayout
,
ctc
::
NHW_K
>
||
is_same_v
<
DLayout
,
ctc
::
NDHW_K
>
)
is_same_v
<
DLayout
,
ctc
::
NHW_K
>
||
is_same_v
<
DLayout
,
ctc
::
NDHW_K
>
)
{
{
const
index_t
K
=
arg
.
ds_n_k_wos_lengths_
[
i
][
1
];
const
index_t
K
=
arg
.
ds_
g_
n_k_wos_lengths_
[
i
][
2
];
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
{
...
@@ -1086,11 +1237,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1086,11 +1237,13 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
printf
(
"%d
\n
"
,
itmp
++
);
// check vector access of E
// check vector access of E
if
constexpr
(
is_same_v
<
ELayout
,
ctc
::
NWK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWK
>
||
if
constexpr
(
is_same_v
<
ELayout
,
ctc
::
NWK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWK
>
||
is_same_v
<
ELayout
,
ctc
::
NDHWK
>
)
is_same_v
<
ELayout
,
ctc
::
NDHWK
>
)
{
{
const
index_t
K
=
arg
.
e_n_k_wos_lengths_
[
1
];
const
index_t
K
=
arg
.
e_
g_
n_k_wos_lengths_
[
2
];
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
{
...
@@ -1102,6 +1255,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1102,6 +1255,8 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
return
false
;
return
false
;
}
}
printf
(
"%d
\n
"
,
itmp
++
);
// check Gridwise GEMM
// check Gridwise GEMM
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
b_grid_desc_n_k_
,
...
@@ -1120,14 +1275,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1120,14 +1275,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const
void
*
p_b
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_
g_
n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_
g_
n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -1140,14 +1295,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1140,14 +1295,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
p_b
,
p_b
,
p_ds
,
p_ds
,
p_e
,
p_e
,
a_n_c_wis_lengths
,
a_
g_
n_c_wis_lengths
,
a_n_c_wis_strides
,
a_
g_
n_c_wis_strides
,
b_k_c_xs_lengths
,
b_
g_
k_c_xs_lengths
,
b_k_c_xs_strides
,
b_
g_
k_c_xs_strides
,
ds_n_k_wos_lengths
,
ds_
g_
n_k_wos_lengths
,
ds_n_k_wos_strides
,
ds_
g_
n_k_wos_strides
,
e_n_k_wos_lengths
,
e_
g_
n_k_wos_lengths
,
e_n_k_wos_strides
,
e_
g_
n_k_wos_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -1164,14 +1319,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1164,14 +1319,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
const
void
*
p_b
,
const
void
*
p_b
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
const
std
::
array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
void
*
p_e
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
a_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_
g_
n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
b_k_c_xs_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_
g_
k_c_xs_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_
g_
n_k_wos_lengths
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
2
>
,
NumDTensor
>&
ds_n_k_wos_strides
,
const
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_
g_
n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
2
>&
e_n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_
g_
n_k_wos_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -1184,14 +1339,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
...
@@ -1184,14 +1339,14 @@ struct DeviceConvFwdMultipleD_Xdl_CShuffle : public DeviceConvFwdMultipleD<NDimS
p_b
,
p_b
,
p_ds
,
p_ds
,
p_e
,
p_e
,
a_n_c_wis_lengths
,
a_
g_
n_c_wis_lengths
,
a_n_c_wis_strides
,
a_
g_
n_c_wis_strides
,
b_k_c_xs_lengths
,
b_
g_
k_c_xs_lengths
,
b_k_c_xs_strides
,
b_
g_
k_c_xs_strides
,
ds_n_k_wos_lengths
,
ds_
g_
n_k_wos_lengths
,
ds_n_k_wos_strides
,
ds_
g_
n_k_wos_strides
,
e_n_k_wos_lengths
,
e_
g_
n_k_wos_lengths
,
e_n_k_wos_strides
,
e_
g_
n_k_wos_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp
View file @
19173ab7
...
@@ -89,97 +89,33 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -89,97 +89,33 @@ struct ReferenceConvFwd : public device::BaseOperator
{
{
using
Argument
=
ReferenceConvFwd
::
Argument
;
using
Argument
=
ReferenceConvFwd
::
Argument
;
// FIXME: properly implement "TensorView" for doing transpose or refer to dimension by name
float
Run
(
const
Argument
&
arg
)
float
Run
(
const
Argument
&
arg
)
{
{
// tensor descriptor in NCHW/KXYC/NKHW dimensional order
// tensor descriptor in NCHW/KXYC/NKHW dimensional order
HostTensorDescriptor
in_desc
=
arg
.
input_
.
mDesc
;
HostTensorDescriptor
wei_desc
=
arg
.
weight_
.
mDesc
;
HostTensorDescriptor
out_desc
=
arg
.
output_
.
mDesc
;
// input
if
constexpr
(
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NWC
>
)
{
in_desc
=
transpose_host_tensor_descriptor_given_new2old
(
arg
.
input_
.
mDesc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
}
else
if
constexpr
(
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NHWC
>
)
{
in_desc
=
transpose_host_tensor_descriptor_given_new2old
(
arg
.
input_
.
mDesc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
}
else
if
constexpr
(
is_same_v
<
InLayout
,
ck
::
tensor_layout
::
convolution
::
NDHWC
>
)
{
in_desc
=
transpose_host_tensor_descriptor_given_new2old
(
arg
.
input_
.
mDesc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
}
// weight
if
constexpr
(
is_same_v
<
WeiLayout
,
ck
::
tensor_layout
::
convolution
::
KXC
>
)
{
wei_desc
=
transpose_host_tensor_descriptor_given_new2old
(
arg
.
weight_
.
mDesc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
}
else
if
constexpr
(
is_same_v
<
WeiLayout
,
ck
::
tensor_layout
::
convolution
::
KYXC
>
)
{
wei_desc
=
transpose_host_tensor_descriptor_given_new2old
(
arg
.
weight_
.
mDesc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
}
else
if
constexpr
(
is_same_v
<
WeiLayout
,
ck
::
tensor_layout
::
convolution
::
KZYXC
>
)
{
wei_desc
=
transpose_host_tensor_descriptor_given_new2old
(
arg
.
weight_
.
mDesc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
}
// output
if
constexpr
(
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NWK
>
)
{
out_desc
=
transpose_host_tensor_descriptor_given_new2old
(
arg
.
output_
.
mDesc
,
std
::
vector
<
std
::
size_t
>
{
0
,
2
,
1
});
}
else
if
constexpr
(
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NHWK
>
)
{
out_desc
=
transpose_host_tensor_descriptor_given_new2old
(
arg
.
output_
.
mDesc
,
std
::
vector
<
std
::
size_t
>
{
0
,
3
,
1
,
2
});
}
else
if
constexpr
(
is_same_v
<
OutLayout
,
ck
::
tensor_layout
::
convolution
::
NDHWK
>
)
{
out_desc
=
transpose_host_tensor_descriptor_given_new2old
(
arg
.
output_
.
mDesc
,
std
::
vector
<
std
::
size_t
>
{
0
,
4
,
1
,
2
,
3
});
}
if
constexpr
(
NumDimSpatial
==
1
)
if
constexpr
(
NumDimSpatial
==
1
)
{
{
auto
f
_
nc
w
=
[
&
](
auto
n
,
auto
k
,
auto
wo
)
{
auto
f
u
nc
=
[
&
](
auto
g
,
auto
n
,
auto
k
,
auto
wo
)
{
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
std
::
size_t
c
=
0
;
c
<
wei_desc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_t
c
=
0
;
c
<
arg
.
weight_
.
GetLengths
()[
2
];
++
c
)
{
{
for
(
std
::
size_t
x
=
0
;
x
<
wei_desc
.
GetLengths
()[
2
];
++
x
)
for
(
std
::
size_t
x
=
0
;
x
<
arg
.
weight_
.
GetLengths
()[
3
];
++
x
)
{
{
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
0
])
+
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
x
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
if
(
wi
>=
0
&&
if
(
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
in_desc
.
GetLengths
()[
2
])
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
GetLengths
()[
3
])
{
{
float
v_in
;
float
v_in
;
float
v_wei
;
float
v_wei
;
// FIXME hacky
arg
.
in_element_op_
(
arg
.
in_element_op_
(
v_in
,
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
g
,
n
,
c
,
wi
)));
ck
::
type_convert
<
float
>
(
arg
.
input_
.
mData
[
in_desc
.
GetOffsetFromMultiIndex
(
n
,
c
,
wi
)]));
// FIXME hacky
arg
.
wei_element_op_
(
arg
.
wei_element_op_
(
v_wei
,
v_wei
,
ck
::
type_convert
<
float
>
(
arg
.
weight_
(
g
,
k
,
c
,
x
)));
ck
::
type_convert
<
float
>
(
arg
.
weight_
.
mData
[
wei_desc
.
GetOffsetFromMultiIndex
(
k
,
c
,
x
)]));
v_acc
+=
v_in
*
v_wei
;
v_acc
+=
v_in
*
v_wei
;
}
}
...
@@ -190,33 +126,32 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -190,33 +126,32 @@ struct ReferenceConvFwd : public device::BaseOperator
arg
.
out_element_op_
(
v_out
,
v_acc
);
arg
.
out_element_op_
(
v_out
,
v_acc
);
// FIXME hacky
arg
.
output_
(
g
,
n
,
k
,
wo
)
=
ck
::
type_convert
<
OutDataType
>
(
v_out
);
arg
.
output_
.
mData
[
out_desc
.
GetOffsetFromMultiIndex
({
n
,
k
,
wo
})]
=
ck
::
type_convert
<
OutDataType
>
(
v_out
);
};
};
make_ParallelTensorFunctor
(
f_ncw
,
make_ParallelTensorFunctor
(
func
,
out_desc
.
GetLengths
()[
0
],
arg
.
output_
.
GetLengths
()[
0
],
out_desc
.
GetLengths
()[
1
],
arg
.
output_
.
GetLengths
()[
1
],
out_desc
.
GetLengths
()[
2
])(
arg
.
output_
.
GetLengths
()[
2
],
arg
.
output_
.
GetLengths
()[
3
])(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
}
}
else
if
constexpr
(
NumDimSpatial
==
2
)
else
if
constexpr
(
NumDimSpatial
==
2
)
{
{
auto
f
_
nc
hw
=
[
&
](
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
auto
f
u
nc
=
[
&
](
auto
g
,
auto
n
,
auto
k
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
std
::
size_t
c
=
0
;
c
<
wei_desc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_t
c
=
0
;
c
<
arg
.
weight_
.
GetLengths
()[
2
];
++
c
)
{
{
for
(
std
::
size_t
y
=
0
;
y
<
wei_desc
.
GetLengths
()[
2
];
++
y
)
for
(
std
::
size_t
y
=
0
;
y
<
arg
.
weight_
.
GetLengths
()[
3
];
++
y
)
{
{
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
std
::
size_t
x
=
0
;
x
<
wei_desc
.
GetLengths
()[
3
];
++
x
)
for
(
std
::
size_t
x
=
0
;
x
<
arg
.
weight_
.
GetLengths
()[
4
];
++
x
)
{
{
auto
wi
=
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
1
])
+
...
@@ -224,26 +159,18 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -224,26 +159,18 @@ struct ReferenceConvFwd : public device::BaseOperator
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
if
(
hi
>=
0
&&
if
(
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
in_desc
.
GetLengths
()[
2
]
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
arg
.
input_
.
GetLengths
()[
3
]
&&
wi
>=
0
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
in_desc
.
GetLengths
()[
3
])
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
GetLengths
()[
4
])
{
{
float
v_in
;
float
v_in
;
float
v_wei
;
float
v_wei
;
// FIXME hacky
arg
.
in_element_op_
(
arg
.
in_element_op_
(
v_in
,
v_in
,
ck
::
type_convert
<
float
>
(
arg
.
input_
(
g
,
n
,
c
,
hi
,
wi
)));
ck
::
type_convert
<
float
>
(
arg
.
input_
.
mData
[
in_desc
.
GetOffsetFromMultiIndex
(
n
,
c
,
hi
,
wi
)]));
// FIXME hacky
arg
.
wei_element_op_
(
arg
.
wei_element_op_
(
v_wei
,
v_wei
,
ck
::
type_convert
<
float
>
(
arg
.
weight_
(
g
,
k
,
c
,
y
,
x
)));
ck
::
type_convert
<
float
>
(
arg
.
weight_
.
mData
[
wei_desc
.
GetOffsetFromMultiIndex
(
k
,
c
,
y
,
x
)]));
v_acc
+=
v_in
*
v_wei
;
v_acc
+=
v_in
*
v_wei
;
}
}
...
@@ -255,39 +182,38 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -255,39 +182,38 @@ struct ReferenceConvFwd : public device::BaseOperator
arg
.
out_element_op_
(
v_out
,
v_acc
);
arg
.
out_element_op_
(
v_out
,
v_acc
);
// FIXME hacky
arg
.
output_
(
g
,
n
,
k
,
ho
,
wo
)
=
ck
::
type_convert
<
OutDataType
>
(
v_out
);
arg
.
output_
.
mData
[
out_desc
.
GetOffsetFromMultiIndex
({
n
,
k
,
ho
,
wo
})]
=
ck
::
type_convert
<
OutDataType
>
(
v_out
);
};
};
make_ParallelTensorFunctor
(
f_nchw
,
make_ParallelTensorFunctor
(
func
,
out_desc
.
GetLengths
()[
0
],
arg
.
output_
.
GetLengths
()[
0
],
out_desc
.
GetLengths
()[
1
],
arg
.
output_
.
GetLengths
()[
1
],
out_desc
.
GetLengths
()[
2
],
arg
.
output_
.
GetLengths
()[
2
],
out_desc
.
GetLengths
()[
3
])(
arg
.
output_
.
GetLengths
()[
3
],
arg
.
output_
.
GetLengths
()[
4
])(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
}
}
else
if
constexpr
(
NumDimSpatial
==
3
)
else
if
constexpr
(
NumDimSpatial
==
3
)
{
{
auto
f
_
nc
hw
=
[
&
](
auto
n
,
auto
k
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
auto
f
u
nc
=
[
&
](
auto
g
,
auto
n
,
auto
k
,
auto
d_o
,
auto
ho
,
auto
wo
)
{
float
v_acc
=
0
;
float
v_acc
=
0
;
for
(
std
::
size_t
c
=
0
;
c
<
wei_desc
.
GetLengths
()[
1
];
++
c
)
for
(
std
::
size_t
c
=
0
;
c
<
arg
.
weight_
.
GetLengths
()[
2
];
++
c
)
{
{
for
(
std
::
size_t
z
=
0
;
z
<
wei_desc
.
GetLengths
()[
2
];
++
z
)
for
(
std
::
size_t
z
=
0
;
z
<
arg
.
weight_
.
GetLengths
()[
3
];
++
z
)
{
{
auto
di
=
static_cast
<
ck
::
long_index_t
>
(
d_o
*
arg
.
conv_strides_
[
0
])
+
auto
di
=
static_cast
<
ck
::
long_index_t
>
(
d_o
*
arg
.
conv_strides_
[
0
])
+
static_cast
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
z
*
arg
.
conv_dilations_
[
0
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
0
]);
for
(
std
::
size_t
y
=
0
;
y
<
wei_desc
.
GetLengths
()[
3
];
++
y
)
for
(
std
::
size_t
y
=
0
;
y
<
arg
.
weight_
.
GetLengths
()[
4
];
++
y
)
{
{
auto
hi
=
auto
hi
=
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
ho
*
arg
.
conv_strides_
[
1
])
+
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
y
*
arg
.
conv_dilations_
[
1
])
-
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
1
]);
for
(
std
::
size_t
x
=
0
;
x
<
wei_desc
.
GetLengths
()[
4
];
++
x
)
for
(
std
::
size_t
x
=
0
;
x
<
arg
.
weight_
.
GetLengths
()[
5
];
++
x
)
{
{
auto
wi
=
auto
wi
=
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
2
])
+
static_cast
<
ck
::
long_index_t
>
(
wo
*
arg
.
conv_strides_
[
2
])
+
...
@@ -295,29 +221,24 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -295,29 +221,24 @@ struct ReferenceConvFwd : public device::BaseOperator
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
]);
static_cast
<
ck
::
long_index_t
>
(
arg
.
in_left_pads_
[
2
]);
if
(
di
>=
0
&&
if
(
di
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
di
)
<
ck
::
type_convert
<
std
::
size_t
>
(
di
)
<
in_desc
.
GetLengths
()[
2
]
&&
arg
.
input_
.
GetLengths
()[
3
]
&&
hi
>=
0
&&
hi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
ck
::
type_convert
<
std
::
size_t
>
(
hi
)
<
in_desc
.
GetLengths
()[
3
]
&&
arg
.
input_
.
GetLengths
()[
4
]
&&
wi
>=
0
&&
wi
>=
0
&&
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
in_desc
.
GetLengths
()[
4
])
ck
::
type_convert
<
std
::
size_t
>
(
wi
)
<
arg
.
input_
.
GetLengths
()[
5
])
{
{
float
v_in
;
float
v_in
;
float
v_wei
;
float
v_wei
;
// FIXME hacky
arg
.
in_element_op_
(
v_in
,
arg
.
in_element_op_
(
ck
::
type_convert
<
float
>
(
v_in
,
arg
.
input_
(
g
,
n
,
c
,
di
,
hi
,
wi
)));
ck
::
type_convert
<
float
>
(
arg
.
input_
.
mData
[
in_desc
.
GetOffsetFromMultiIndex
(
n
,
c
,
di
,
hi
,
wi
)]));
// FIXME hacky
arg
.
wei_element_op_
(
arg
.
wei_element_op_
(
v_wei
,
v_wei
,
ck
::
type_convert
<
float
>
(
ck
::
type_convert
<
float
>
(
arg
.
weight_
(
g
,
k
,
c
,
z
,
y
,
x
)));
arg
.
weight_
.
mData
[
wei_desc
.
GetOffsetFromMultiIndex
(
k
,
c
,
z
,
y
,
x
)]));
v_acc
+=
v_in
*
v_wei
;
v_acc
+=
v_in
*
v_wei
;
}
}
...
@@ -330,17 +251,16 @@ struct ReferenceConvFwd : public device::BaseOperator
...
@@ -330,17 +251,16 @@ struct ReferenceConvFwd : public device::BaseOperator
arg
.
out_element_op_
(
v_out
,
v_acc
);
arg
.
out_element_op_
(
v_out
,
v_acc
);
// FIXME hacky
arg
.
output_
(
g
,
n
,
k
,
d_o
,
ho
,
wo
)
=
ck
::
type_convert
<
OutDataType
>
(
v_out
);
arg
.
output_
.
mData
[
out_desc
.
GetOffsetFromMultiIndex
({
n
,
k
,
d_o
,
ho
,
wo
})]
=
ck
::
type_convert
<
OutDataType
>
(
v_out
);
};
};
make_ParallelTensorFunctor
(
f_nchw
,
make_ParallelTensorFunctor
(
func
,
out_desc
.
GetLengths
()[
0
],
arg
.
output_
.
GetLengths
()[
0
],
out_desc
.
GetLengths
()[
1
],
arg
.
output_
.
GetLengths
()[
1
],
out_desc
.
GetLengths
()[
2
],
arg
.
output_
.
GetLengths
()[
2
],
out_desc
.
GetLengths
()[
3
],
arg
.
output_
.
GetLengths
()[
3
],
out_desc
.
GetLengths
()[
4
])(
arg
.
output_
.
GetLengths
()[
4
],
arg
.
output_
.
GetLengths
()[
5
])(
std
::
thread
::
hardware_concurrency
());
std
::
thread
::
hardware_concurrency
());
return
0
;
return
0
;
...
...
library/include/ck/library/utility/convolution_parameter.hpp
View file @
19173ab7
...
@@ -18,6 +18,7 @@ struct ConvParam
...
@@ -18,6 +18,7 @@ struct ConvParam
{
{
ConvParam
();
ConvParam
();
ConvParam
(
ck
::
index_t
n_dim
,
ConvParam
(
ck
::
index_t
n_dim
,
ck
::
index_t
group_count
,
ck
::
index_t
n_batch
,
ck
::
index_t
n_batch
,
ck
::
index_t
n_out_channels
,
ck
::
index_t
n_out_channels
,
ck
::
index_t
n_in_channels
,
ck
::
index_t
n_in_channels
,
...
@@ -29,6 +30,7 @@ struct ConvParam
...
@@ -29,6 +30,7 @@ struct ConvParam
const
std
::
vector
<
ck
::
index_t
>&
right_pads
);
const
std
::
vector
<
ck
::
index_t
>&
right_pads
);
ck
::
index_t
num_dim_spatial_
;
ck
::
index_t
num_dim_spatial_
;
ck
::
index_t
G_
;
ck
::
index_t
N_
;
ck
::
index_t
N_
;
ck
::
index_t
K_
;
ck
::
index_t
K_
;
ck
::
index_t
C_
;
ck
::
index_t
C_
;
...
@@ -50,20 +52,22 @@ struct ConvParam
...
@@ -50,20 +52,22 @@ struct ConvParam
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
>
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
>
std
::
size_t
GetByte
()
const
std
::
size_t
GetByte
()
const
{
{
// sizeof(InDataType) * (N * C * <input spatial lengths product>) +
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
// sizeof(WeiDataType) * (K * C * <filter spatial lengths product>) +
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
// sizeof(OutDataType) * (N * K * <output spatial lengths product>);
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return
sizeof
(
InDataType
)
*
(
N_
*
C_
*
return
sizeof
(
InDataType
)
*
std
::
accumulate
(
std
::
begin
(
input_spatial_lengths_
),
(
G_
*
N_
*
C_
*
std
::
end
(
input_spatial_lengths_
),
std
::
accumulate
(
std
::
begin
(
input_spatial_lengths_
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
begin
(
input_spatial_lengths_
)
+
num_dim_spatial_
,
std
::
multiplies
<
std
::
size_t
>
()))
+
static_cast
<
std
::
size_t
>
(
1
),
sizeof
(
WeiDataType
)
*
(
K_
*
C_
*
std
::
multiplies
<
std
::
size_t
>
()))
+
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
sizeof
(
WeiDataType
)
*
std
::
end
(
filter_spatial_lengths_
),
(
G_
*
K_
*
C_
*
static_cast
<
std
::
size_t
>
(
1
),
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
std
::
multiplies
<
std
::
size_t
>
()))
+
std
::
begin
(
filter_spatial_lengths_
)
+
num_dim_spatial_
,
sizeof
(
OutDataType
)
*
(
N_
*
K_
*
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
()))
+
sizeof
(
OutDataType
)
*
(
G_
*
N_
*
K_
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
end
(
output_spatial_lengths_
),
std
::
end
(
output_spatial_lengths_
),
static_cast
<
std
::
size_t
>
(
1
),
static_cast
<
std
::
size_t
>
(
1
),
...
...
library/include/ck/library/utility/host_tensor.hpp
View file @
19173ab7
...
@@ -256,6 +256,10 @@ struct Tensor
...
@@ -256,6 +256,10 @@ struct Tensor
return
*
this
;
return
*
this
;
}
}
const
std
::
vector
<
std
::
size_t
>&
GetLengths
()
const
{
return
mDesc
.
GetLengths
();
}
const
std
::
vector
<
std
::
size_t
>&
GetStrides
()
const
{
return
mDesc
.
GetStrides
();
}
void
SetZero
()
void
SetZero
()
{
{
for
(
auto
&
v
:
mData
)
for
(
auto
&
v
:
mData
)
...
...
library/src/utility/convolution_parameter.cpp
View file @
19173ab7
...
@@ -10,6 +10,7 @@ namespace utils {
...
@@ -10,6 +10,7 @@ namespace utils {
namespace
conv
{
namespace
conv
{
ConvParam
::
ConvParam
(
ck
::
index_t
n_dim
,
ConvParam
::
ConvParam
(
ck
::
index_t
n_dim
,
ck
::
index_t
group_count
,
ck
::
index_t
n_batch
,
ck
::
index_t
n_batch
,
ck
::
index_t
n_out_channels
,
ck
::
index_t
n_out_channels
,
ck
::
index_t
n_in_channels
,
ck
::
index_t
n_in_channels
,
...
@@ -20,6 +21,7 @@ ConvParam::ConvParam(ck::index_t n_dim,
...
@@ -20,6 +21,7 @@ ConvParam::ConvParam(ck::index_t n_dim,
const
std
::
vector
<
ck
::
index_t
>&
left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
left_pads
,
const
std
::
vector
<
ck
::
index_t
>&
right_pads
)
const
std
::
vector
<
ck
::
index_t
>&
right_pads
)
:
num_dim_spatial_
(
n_dim
),
:
num_dim_spatial_
(
n_dim
),
G_
(
group_count
),
N_
(
n_batch
),
N_
(
n_batch
),
K_
(
n_out_channels
),
K_
(
n_out_channels
),
C_
(
n_in_channels
),
C_
(
n_in_channels
),
...
@@ -57,7 +59,7 @@ ConvParam::ConvParam(ck::index_t n_dim,
...
@@ -57,7 +59,7 @@ ConvParam::ConvParam(ck::index_t n_dim,
}
}
ConvParam
::
ConvParam
()
ConvParam
::
ConvParam
()
:
ConvParam
::
ConvParam
(
2
,
128
,
256
,
192
,
{
3
,
3
},
{
71
,
71
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
})
:
ConvParam
::
ConvParam
(
2
,
1
,
128
,
256
,
192
,
{
3
,
3
},
{
71
,
71
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
})
{
{
}
}
...
@@ -68,14 +70,14 @@ std::vector<ck::index_t> ConvParam::GetOutputSpatialLengths() const
...
@@ -68,14 +70,14 @@ std::vector<ck::index_t> ConvParam::GetOutputSpatialLengths() const
std
::
size_t
ConvParam
::
GetFlops
()
const
std
::
size_t
ConvParam
::
GetFlops
()
const
{
{
// 2 * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
// 2 *
G *
N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return
static_cast
<
std
::
size_t
>
(
2
)
*
N_
*
K_
*
C_
*
return
static_cast
<
std
::
size_t
>
(
2
)
*
G_
*
N_
*
K_
*
C_
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
end
(
output_spatial_lengths_
),
std
::
begin
(
output_spatial_lengths_
)
+
num_dim_spatial_
,
static_cast
<
std
::
size_t
>
(
1
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
())
*
std
::
multiplies
<
std
::
size_t
>
())
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
std
::
end
(
filter_spatial_lengths_
),
std
::
begin
(
filter_spatial_lengths_
)
+
num_dim_spatial_
,
static_cast
<
std
::
size_t
>
(
1
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
());
std
::
multiplies
<
std
::
size_t
>
());
}
}
...
@@ -87,13 +89,14 @@ std::size_t ConvParam::GetFlops() const
...
@@ -87,13 +89,14 @@ std::size_t ConvParam::GetFlops() const
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ck
::
utils
::
conv
::
ConvParam
&
p
)
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ck
::
utils
::
conv
::
ConvParam
&
p
)
{
{
os
<<
"ConvParam {"
os
<<
"ConvParam {"
<<
"
\n
num_dim_spatial: "
<<
p
.
num_dim_spatial_
<<
"
\n
N: "
<<
p
.
N_
<<
"
\n
K: "
<<
p
.
K_
<<
"
\n
num_dim_spatial: "
<<
p
.
num_dim_spatial_
<<
"
\n
G: "
<<
p
.
G_
<<
"
\n
N: "
<<
p
.
N_
<<
"
\n
C: "
<<
p
.
C_
<<
"
\n
filter_spatial_lengths: "
<<
p
.
filter_spatial_lengths_
<<
"
\n
K: "
<<
p
.
K_
<<
"
\n
C: "
<<
p
.
C_
<<
"
\n
filter_spatial_lengths: "
<<
p
.
filter_spatial_lengths_
<<
"
\n
input_spatial_lengths: "
<<
p
.
input_spatial_lengths_
<<
"
\n
input_spatial_lengths: "
<<
p
.
input_spatial_lengths_
<<
"
\n
conv_filter_strides: "
<<
p
.
conv_filter_strides_
<<
"
\n
conv_filter_strides: "
<<
p
.
conv_filter_strides_
<<
"
\n
conv_filter_dilations: "
<<
p
.
conv_filter_dilations_
<<
"
\n
conv_filter_dilations: "
<<
p
.
conv_filter_dilations_
<<
"
\n
input_left_pads: "
<<
p
.
input_left_pads_
<<
"
\n
input_left_pads: "
<<
p
.
input_left_pads_
<<
"
\n
input_right_pads: "
<<
p
.
input_right_pads_
;
<<
"
\n
input_right_pads: "
<<
p
.
input_right_pads_
<<
"}
\n
"
;
return
os
;
return
os
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment