Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
09d4c3a4
You need to sign in or sign up before continuing.
Commit
09d4c3a4
authored
Oct 01, 2024
by
illsilin
Browse files
merge from public repo
parents
171ed358
8e4c3fb1
Changes
202
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2184 additions
and
469 deletions
+2184
-469
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+42
-185
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...mpl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+354
-70
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
.../device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
+326
-50
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
...r_operation/gpu/device/impl/device_grouped_conv_utils.hpp
+17
-0
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
...operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
+349
-74
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+28
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+4
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
+4
-4
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
+1
-1
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
+9
-9
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
+38
-14
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+6
-6
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+29
-29
include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp
...tion/operator_transform/transform_conv_ngchw_to_nhwgc.hpp
+236
-0
include/ck/utility/amd_smfmac.hpp
include/ck/utility/amd_smfmac.hpp
+14
-12
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+164
-9
include/ck_tile/core/container/array.hpp
include/ck_tile/core/container/array.hpp
+12
-1
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+2
-0
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp
...k_tile/host/convolution_host_tensor_descriptor_helper.hpp
+266
-0
include/ck_tile/host/convolution_parameter.hpp
include/ck_tile/host/convolution_parameter.hpp
+283
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @
09d4c3a4
...
...
@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...
...
@@ -22,7 +23,6 @@
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -257,6 +257,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
KPerBlock
/
K1Number
,
ConvBackwardWeightSpecialization
>
{};
static
constexpr
index_t
ClusterLengthMPerBlock
=
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
1
);
static
constexpr
index_t
ClusterLengthNPerBlock
=
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
3
);
static
constexpr
auto
conv_ngchw_to_nhwgc_transformer
=
TransformConvNGCHWToNHWGC
<
InLayout
,
WeiLayout
,
OutLayout
,
NDimSpatial
,
MPerBlock
/
ClusterLengthMPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
>
{};
static
constexpr
GemmSpecialization
GemmSpec
=
GemmSpecialization
::
Default
;
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
...
...
@@ -359,141 +372,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
batch
)[
I2
];
}
static
constexpr
index_t
ClusterLengthMPerBlock
=
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
1
);
static
constexpr
index_t
ClusterLengthNPerBlock
=
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
3
);
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeInputTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
2
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
3
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
4
];
const
index_t
&
GStride
=
g_n_c_wis_strides
[
0
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
1
];
const
index_t
&
CStride
=
g_n_c_wis_strides
[
2
];
const
index_t
&
HiStride
=
g_n_c_wis_strides
[
3
];
const
index_t
&
WiStride
=
g_n_c_wis_strides
[
4
];
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerBlock
/
ClusterLengthMPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeOutputTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
2
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
3
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
4
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
1
];
const
index_t
HiStride
=
Wi
*
G
*
C
;
const
index_t
WiStride
=
G
*
C
;
const
index_t
GStride
=
C
;
const
index_t
CStride
=
1
;
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerBlock
/
ClusterLengthMPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeInputTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
2
];
const
index_t
&
Di
=
g_n_c_wis_lengths
[
3
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
4
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
5
];
const
index_t
&
GStride
=
g_n_c_wis_strides
[
0
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
1
];
const
index_t
&
CStride
=
g_n_c_wis_strides
[
2
];
const
index_t
&
DiStride
=
g_n_c_wis_strides
[
3
];
const
index_t
&
HiStride
=
g_n_c_wis_strides
[
4
];
const
index_t
&
WiStride
=
g_n_c_wis_strides
[
5
];
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Di
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
DiStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Di
,
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerBlock
/
ClusterLengthMPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeOutputTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
2
];
const
index_t
&
Di
=
g_n_c_wis_lengths
[
3
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
4
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
5
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
1
];
const
index_t
DiStride
=
Hi
*
Wi
*
G
*
C
;
const
index_t
HiStride
=
Wi
*
G
*
C
;
const
index_t
WiStride
=
G
*
C
;
const
index_t
GStride
=
C
;
const
index_t
CStride
=
1
;
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Di
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
DiStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Di
,
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerBlock
/
ClusterLengthMPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
),
Sequence
<
true
,
true
>
{});
}
using
InputTransposeDescType
=
remove_cvref_t
<
decltype
(
MakeInputTransposeDesc
<
NDimSpatial
>
({},
{}))
>
;
using
OutputTransposeDescType
=
remove_cvref_t
<
decltype
(
MakeOutputTransposeDesc
<
NDimSpatial
>
({},
{}))
>
;
using
NGCHWTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
using
NHWGCTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
...
...
@@ -572,8 +456,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
I1
>
;
using
GridwiseElementwiseTranspose
=
GridwiseElementwise
<
Tuple
<
Input
TransposeDescType
>
,
Tuple
<
Output
TransposeDescType
>
,
GridwiseElementwise
<
Tuple
<
NGCHW
TransposeDescType
>
,
Tuple
<
NHWGC
TransposeDescType
>
,
Tuple
<
const
ADataType
*>
,
Tuple
<
ADataType
*>
,
Block2TileMapElementwise
,
...
...
@@ -652,43 +536,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
begin
(
output_spatial_lengths_
));
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_n_c_wis_strides_transposed
=
b_g_n_c_wis_strides
;
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
b_g_n_c_wis_lengths
,
b_g_n_c_wis_strides
);
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_k_wos_strides_transposed
=
a_g_n_k_wos_strides
;
// NGKHW - transpose needed
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
InLayout
,
WeiLayout
,
OutLayout
>
())
{
b_g_n_c_wis_strides_transposed
[
I0
]
=
Conv_C_
;
b_g_n_c_wis_strides_transposed
[
I2
]
=
I1
;
a_g_n_k_wos_strides_transposed
[
I0
]
=
Conv_K_
;
a_g_n_k_wos_strides_transposed
[
I2
]
=
I1
;
if
constexpr
(
NDimSpatial
==
2
)
{
b_g_n_c_wis_strides_transposed
[
I3
]
=
input_spatial_lengths_
[
I1
]
*
Conv_G_
*
Conv_C_
;
b_g_n_c_wis_strides_transposed
[
I4
]
=
Conv_G_
*
Conv_C_
;
a_g_n_k_wos_strides_transposed
[
I3
]
=
output_spatial_lengths_
[
I1
]
*
Conv_G_
*
Conv_K_
;
a_g_n_k_wos_strides_transposed
[
I4
]
=
Conv_G_
*
Conv_K_
;
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
b_g_n_c_wis_strides_transposed
[
I3
]
=
input_spatial_lengths_
[
I1
]
*
input_spatial_lengths_
[
I2
]
*
Conv_G_
*
Conv_C_
;
b_g_n_c_wis_strides_transposed
[
I4
]
=
input_spatial_lengths_
[
I2
]
*
Conv_G_
*
Conv_C_
;
b_g_n_c_wis_strides_transposed
[
I5
]
=
Conv_G_
*
Conv_C_
;
a_g_n_k_wos_strides_transposed
[
I3
]
=
output_spatial_lengths_
[
I1
]
*
input_spatial_lengths_
[
I2
]
*
Conv_G_
*
Conv_K_
;
a_g_n_k_wos_strides_transposed
[
I4
]
=
input_spatial_lengths_
[
I2
]
*
Conv_G_
*
Conv_K_
;
a_g_n_k_wos_strides_transposed
[
I5
]
=
Conv_G_
*
Conv_K_
;
}
}
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
);
const
auto
descs
=
conv_to_gemm_transformer_v2
...
...
@@ -755,14 +607,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
is_NGCDHW_GKZYXC_NGKDHW
<
InLayout
,
WeiLayout
,
OutLayout
>
())
{
a_in_transpose_desc_
=
MakeInputTransposeDesc
<
NDimSpatial
>
(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
);
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
);
a_out_transpose_desc_
=
MakeOutputTransposeDesc
<
NDimSpatial
>
(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
);
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
);
b_in_transpose_desc_
=
MakeInputTransposeDesc
<
NDimSpatial
>
(
b_g_n_c_wis_lengths
,
b_g_n_c_wis_strides
);
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
b_g_n_c_wis_lengths
,
b_g_n_c_wis_strides
);
b_out_transpose_desc_
=
MakeOutputTransposeDesc
<
NDimSpatial
>
(
b_g_n_c_wis_lengths
,
b_g_n_c_wis_strides
);
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
b_g_n_c_wis_lengths
,
b_g_n_c_wis_strides
);
elementwise_block_2_ctile_map_transpose_a_
=
Block2TileMapElementwise
{
a_in_transpose_desc_
.
GetLength
(
I0
),
a_in_transpose_desc_
.
GetLength
(
I1
)};
...
...
@@ -816,8 +672,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
Block2TileMapElementwise
elementwise_block_2_ctile_map_transpose_a_
,
elementwise_block_2_ctile_map_transpose_b_
;
Input
TransposeDescType
a_in_transpose_desc_
,
b_in_transpose_desc_
;
Output
TransposeDescType
a_out_transpose_desc_
,
b_out_transpose_desc_
;
NGCHW
TransposeDescType
a_in_transpose_desc_
,
b_in_transpose_desc_
;
NHWGC
TransposeDescType
a_out_transpose_desc_
,
b_out_transpose_desc_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
compute_ptr_offset_of_batch_
;
...
...
@@ -1569,13 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
(
arg
.
GetWorkspaceETensorSizeBytes
()
+
arg
.
GetWorkspaceATensorSizeBytes
())
/
sizeof
(
BDataType
);
// Different data type for A and B is not supported
auto
kernel_transpose
=
kernel_elementwise_dual
<
GridwiseElementwiseTranspose
,
ck
::
Tuple
<
Input
TransposeDescType
>
,
ck
::
Tuple
<
Input
TransposeDescType
>
,
ck
::
Tuple
<
Output
TransposeDescType
>
,
ck
::
Tuple
<
Output
TransposeDescType
>
,
ck
::
Tuple
<
NGCHW
TransposeDescType
>
,
ck
::
Tuple
<
NGCHW
TransposeDescType
>
,
ck
::
Tuple
<
NHWGC
TransposeDescType
>
,
ck
::
Tuple
<
NHWGC
TransposeDescType
>
,
ck
::
Tuple
<
const
ADataType
*>
,
ck
::
Tuple
<
B
DataType
*>
,
ck
::
Tuple
<
A
DataType
*>
,
Block2TileMapElementwise
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
09d4c3a4
...
...
@@ -15,9 +15,11 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
...
...
@@ -307,6 +309,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static
constexpr
bool
isMultiA
=
is_detected
<
is_tuple
,
ADataType
>::
value
;
static
constexpr
bool
isMultiB
=
is_detected
<
is_tuple
,
BDataType
>::
value
;
// NGCHW is not supported for multiAB
static_assert
(
!
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
||
!
(
isMultiA
||
isMultiB
));
static
constexpr
index_t
NumATensor
=
GetNumABTensors
<
isMultiA
,
ADataType
>
();
static
constexpr
index_t
NumBTensor
=
GetNumABTensors
<
isMultiB
,
BDataType
>
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
...
@@ -315,6 +322,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
using
ConvToGemmFwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
,
...
...
@@ -323,14 +332,33 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
EDataType
,
NumGroupsToMerge
>
;
static
constexpr
index_t
ClusterLengthNPerBlock
=
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
3
);
static
constexpr
auto
conv_ngchw_to_nhwgc_transformer
=
TransformConvNGCHWToNHWGC
<
ALayout
,
BLayout
,
ELayout
,
NDimSpatial
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_M_K
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
using
Layout
=
std
::
conditional_t
<
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NHWGC
,
std
::
conditional_t
<
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NDHWGC
,
ALay
>>
;
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
A
Lay
>();
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
Lay
out
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
@@ -353,8 +381,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
using
Layout
=
std
::
conditional_t
<
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NHWGK
,
std
::
conditional_t
<
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NDHWGK
,
ELay
>>
;
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
E
Lay
>();
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
Lay
out
>();
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
...
...
@@ -442,6 +478,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
using
Block2TileMapElementwise
=
BlockToCTileMap_M00_N0_M01Adapt
<
NPerBlock
,
NPerBlock
>
;
using
NGCHWTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
using
NHWGCTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
static
constexpr
index_t
ElementwiseBlocksize
=
ClusterLengthNPerBlock
*
ClusterLengthNPerBlock
;
using
GridwiseElementwiseInputTranspose
=
GridwiseElementwise
<
Tuple
<
NGCHWTransposeDescType
>
,
Tuple
<
NHWGCTransposeDescType
>
,
Tuple
<
const
ADataType
*>
,
Tuple
<
ADataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
,
ElementwiseBlocksize
,
NPerBlock
,
NPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
Sequence
<
1
,
0
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
I1
,
I0
>
;
using
GridwiseElementwiseOutputTranspose
=
GridwiseElementwise
<
Tuple
<
NHWGCTransposeDescType
>
,
Tuple
<
NGCHWTransposeDescType
>
,
Tuple
<
const
EDataType
*>
,
Tuple
<
EDataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
,
ElementwiseBlocksize
,
NPerBlock
,
NPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
Sequence
<
1
,
0
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
I0
,
I1
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -471,17 +553,31 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_bs_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
)},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
ds_g_n_k_wos_lengths_
{
ds_g_n_k_wos_lengths
},
ds_g_n_k_wos_strides_
{
ds_g_n_k_wos_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
num_group_
{
a_g_n_c_wis_lengths_
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths_
,
a_g_n_c_wis_strides_
,
b_g_k_c_xs_lengths_
,
b_g_k_c_xs_strides_
,
e_g_n_k_wos_lengths_
,
e_g_n_k_wos_strides_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
},
conv_N_per_block_
{
conv_to_gemm_transformer_
.
N_
},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
conv_to_gemm_transformer_
)},
...
...
@@ -501,19 +597,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
a_g_n_c_wis_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
ds_g_n_k_wos_lengths_
{
ds_g_n_k_wos_lengths
},
ds_g_n_k_wos_strides_
{
ds_g_n_k_wos_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
e_g_n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
cde_element_op_
{
cde_element_op
}
{
// A/B/E Batch Stride
if
constexpr
(
isMultiA
||
isMultiB
)
...
...
@@ -521,7 +605,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_
.
BatchStrideA_
(
i
)
=
a_g_n_c_wis_strides
[
0
]
*
NumGroupsToMerge
;
a_g_n_c_wis_strides
_
[
0
]
*
NumGroupsToMerge
;
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
...
...
@@ -537,20 +621,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// in case of MultiA is false but isMultiB is true
// BatchStrideA_ is not tuple.
compute_ptr_offset_of_n_
.
BatchStrideA_
(
i
)
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
a_g_n_c_wis_strides
_
[
1
]
*
conv_N_per_block_
;
}
else
{
// if MultiB and not MultiA then p_as is single pointer
p_as_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_as
);
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
a_g_n_c_wis_strides
_
[
1
]
*
conv_N_per_block_
;
}
});
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_
.
BatchStrideB_
(
i
)
=
b_g_k_c_xs_strides
[
0
]
*
NumGroupsToMerge
;
b_g_k_c_xs_strides
_
[
0
]
*
NumGroupsToMerge
;
using
DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
GemmBDataType
>>
;
// It is possible that one of the AB is a pointer and one is a tuple.
...
...
@@ -571,10 +655,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
else
{
compute_ptr_offset_of_groups_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
]
*
NumGroupsToMerge
;
a_g_n_c_wis_strides
_
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_groups_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
b_g_k_c_xs_strides_
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides_
[
1
]
*
conv_N_per_block_
;
// p_as and p_bs are pointers
p_as_grid_
(
I0
)
=
static_cast
<
const
ADataType
*>
(
p_as
);
...
...
@@ -591,27 +676,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_groups_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
]
*
NumGroupsToMerge
;
ds_g_n_k_wos_strides
_
[
i
][
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
1
]
*
conv_N_per_block_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
[
i
],
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
ds_g_n_k_wos_strides
_
[
i
][
1
]
*
conv_N_per_block_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
_
,
a_g_n_c_wis_strides
_
,
b_g_k_c_xs_lengths
_
,
b_g_k_c_xs_strides
_
,
e_g_n_k_wos_lengths
_
,
ds_g_n_k_wos_strides
_
[
i
],
conv_filter_strides
_
,
conv_filter_dilations
_
,
input_left_pads
_
,
input_right_pads
_
};
// D desc
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
conv_to_gemm_transformer_d
);
});
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
1
]
*
conv_N_per_block_
;
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
e_g_n_k_wos_strides_
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides_
[
1
]
*
conv_N_per_block_
;
// populate desc for Ds/E
if
constexpr
(
isMultiA
||
isMultiB
)
...
...
@@ -653,6 +739,54 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ds_grid_desc_m_n_
);
}
}
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
// Use not modified base strides
a_in_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
);
a_out_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
);
e_in_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
e_out_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
elementwise_block_2_ctile_map_transpose_a_
=
Block2TileMapElementwise
{
a_in_transpose_desc_
.
GetLength
(
I0
),
a_in_transpose_desc_
.
GetLength
(
I1
)};
elementwise_block_2_ctile_map_transpose_e_
=
Block2TileMapElementwise
{
e_in_transpose_desc_
.
GetLength
(
I0
),
e_in_transpose_desc_
.
GetLength
(
I1
)};
}
}
std
::
size_t
GetWorkspaceATensorSizeBytes
()
const
{
return
sizeof
(
ADataType
)
*
a_in_transpose_desc_
.
GetElementSpaceSize
();
}
std
::
size_t
GetWorkspaceETensorSizeBytes
()
const
{
return
sizeof
(
EDataType
)
*
e_out_transpose_desc_
.
GetElementSpaceSize
();
}
std
::
size_t
GetWorkspaceSizeBytes
()
const
{
// Transpose require workspace for A and B
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
return
GetWorkspaceATensorSizeBytes
()
+
GetWorkspaceETensorSizeBytes
();
}
else
{
return
0
;
}
}
void
Print
()
const
...
...
@@ -671,6 +805,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
// tensor descriptors for problem definiton
index_t
num_group_
;
...
...
@@ -692,6 +840,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
Block2TileMapElementwise
elementwise_block_2_ctile_map_transpose_a_
,
elementwise_block_2_ctile_map_transpose_e_
;
NGCHWTransposeDescType
a_in_transpose_desc_
,
e_out_transpose_desc_
;
NHWGCTransposeDescType
a_out_transpose_desc_
,
e_in_transpose_desc_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
>
...
...
@@ -702,20 +855,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
};
// Invoker
...
...
@@ -723,7 +862,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
Gemm
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
...
...
@@ -794,6 +933,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
const
ADataType
*
p_a_grid
=
arg
.
p_as_grid_
.
At
(
I0
);
EDataType
*
p_e_grid
=
arg
.
p_e_grid_
;
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
p_a_grid
=
type_convert
<
const
ADataType
*>
(
arg
.
p_workspace_
);
p_e_grid
=
type_convert
<
EDataType
*>
(
arg
.
p_workspace_
)
+
arg
.
GetWorkspaceATensorSizeBytes
()
/
sizeof
(
EDataType
);
}
const
auto
kernel
=
kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle
<
GridwiseGemm
,
const
ADataType
*
,
...
...
@@ -820,10 +970,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
.
p_a
s
_grid
_
.
At
(
I0
),
// Pass just A descriptor instead of tuple
p_a_grid
,
// Pass just A descriptor instead of tuple
arg
.
p_bs_grid_
.
At
(
I0
),
// Pass just B descriptor instead of tuple
arg
.
p_ds_grid_
,
arg
.
p_e_grid
_
,
p_e_grid
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
...
...
@@ -847,6 +997,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
avg_time
=
0.
f
;
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
const
index_t
grid_size
=
arg
.
elementwise_block_2_ctile_map_transpose_a_
.
CalculateGridSize
(
arg
.
a_in_transpose_desc_
);
ADataType
*
p_a_out_grid
=
type_convert
<
ADataType
*>
(
arg
.
p_workspace_
);
auto
kernel_transpose
=
kernel_elementwise
<
GridwiseElementwiseInputTranspose
,
ck
::
Tuple
<
NGCHWTransposeDescType
>
,
ck
::
Tuple
<
NHWGCTransposeDescType
>
,
ck
::
Tuple
<
const
ADataType
*>
,
ck
::
Tuple
<
ADataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_transpose
,
dim3
(
grid_size
),
dim3
(
ElementwiseBlocksize
),
0
,
make_tuple
(
arg
.
a_in_transpose_desc_
),
make_tuple
(
arg
.
a_out_transpose_desc_
),
make_tuple
(
arg
.
p_as_grid_
.
At
(
I0
)),
make_tuple
(
p_a_out_grid
),
arg
.
elementwise_block_2_ctile_map_transpose_a_
,
element_wise
::
PassThrough
{});
}
avg_time
+=
RunGemm
(
arg
,
stream_config
);
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
const
index_t
grid_size
=
arg
.
elementwise_block_2_ctile_map_transpose_e_
.
CalculateGridSize
(
arg
.
e_in_transpose_desc_
);
const
EDataType
*
p_e_out_grid
=
type_convert
<
EDataType
*>
(
arg
.
p_workspace_
)
+
arg
.
GetWorkspaceATensorSizeBytes
()
/
sizeof
(
EDataType
);
EDataType
*
p_e_in_grid
=
arg
.
p_e_grid_
;
auto
kernel_transpose
=
kernel_elementwise
<
GridwiseElementwiseOutputTranspose
,
ck
::
Tuple
<
NHWGCTransposeDescType
>
,
ck
::
Tuple
<
NGCHWTransposeDescType
>
,
ck
::
Tuple
<
const
EDataType
*>
,
ck
::
Tuple
<
EDataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_transpose
,
dim3
(
grid_size
),
dim3
(
ElementwiseBlocksize
),
0
,
make_tuple
(
arg
.
e_in_transpose_desc_
),
make_tuple
(
arg
.
e_out_transpose_desc_
),
make_tuple
(
p_e_out_grid
),
make_tuple
(
p_e_in_grid
),
arg
.
elementwise_block_2_ctile_map_transpose_e_
,
element_wise
::
PassThrough
{});
}
return
avg_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
...
...
@@ -941,7 +1164,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return
false
;
}
if
constexpr
(
!
is_NSpatialGC_GKSpatial_NSpatialGK
<
ALayout
,
BLayout
,
ELayout
>
())
if
constexpr
(
!
(
is_NSpatialGC_GKSpatial_NSpatialGK
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCSpatial_GKSpatial_NGKSpatial
<
ALayout
,
BLayout
,
ELayout
>
()))
{
return
false
;
}
...
...
@@ -953,14 +1177,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v
<
ALayout
,
ctc
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
GNWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNHWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNDHWC
>
||
is_same_v
<
ALayout
,
ctc
::
NWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
)
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NGCW
>
||
is_same_v
<
ALayout
,
ctc
::
NGCHW
>
||
is_same_v
<
ALayout
,
ctc
::
NGCDHW
>
)
{
// Check access per C
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
{
// If not possible, check access per G
if
(
!
(
ABlockTransferSrcVectorDim
==
1
&&
C
==
1
&&
is_NSpatialGC_GKSpatial_NSpatialGK
<
ALayout
,
BLayout
,
ELayout
>
()
&&
if
(
!
(
ABlockTransferSrcVectorDim
==
1
&&
(
C
==
1
||
NumGroupsToMerge
==
1
)
&&
(
is_NSpatialGC_GKSpatial_NSpatialGK
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCSpatial_GKSpatial_NGKSpatial
<
ALayout
,
BLayout
,
ELayout
>
())
&&
G
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
...
...
@@ -1036,6 +1262,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
});
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
if
((
G
*
C
)
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
if
((
G
*
K
)
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
const
index_t
input_spatial_acum
=
ck
::
accumulate_n
<
index_t
>
(
arg
.
a_g_n_c_wis_lengths_
.
begin
()
+
I3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
output_spatial_acum
=
ck
::
accumulate_n
<
index_t
>
(
arg
.
e_g_n_k_wos_lengths_
.
begin
()
+
I3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
input_spatial_acum
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
if
(
output_spatial_acum
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
if
(
!
valid
)
{
return
false
;
...
...
@@ -1046,7 +1301,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v
<
ELayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
GNWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNHWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNDHWK
>
||
is_same_v
<
ELayout
,
ctc
::
NWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NDHWGK
>
)
is_same_v
<
ELayout
,
ctc
::
NDHWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NGKW
>
||
is_same_v
<
ELayout
,
ctc
::
NGKHW
>
||
is_same_v
<
ELayout
,
ctc
::
NGKDHW
>
)
{
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
...
...
@@ -1352,6 +1608,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return
str
.
str
();
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
auto
arg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
arg
)
{
return
arg
->
GetWorkspaceSizeBytes
();
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"
);
}
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
auto
p_arg_
=
dynamic_cast
<
Argument
*>
(
p_arg
);
if
(
p_arg_
)
{
p_arg_
->
p_workspace_
=
p_workspace
;
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"
);
}
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
View file @
09d4c3a4
...
...
@@ -15,10 +15,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -292,6 +294,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
using
ConvToGemmFwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
,
...
...
@@ -302,13 +306,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
static
constexpr
index_t
ClusterLengthNPerBlock
=
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
3
);
static
constexpr
auto
conv_ngchw_to_nhwgc_transformer
=
TransformConvNGCHWToNHWGC
<
ALayout
,
BLayout
,
ELayout
,
NDimSpatial
,
MPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
>
{};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
using
Layout
=
std
::
conditional_t
<
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NHWGC
,
std
::
conditional_t
<
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NDHWGC
,
ALay
>>
;
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
A
Lay
>();
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
Lay
out
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
@@ -351,8 +374,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
using
Layout
=
std
::
conditional_t
<
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NHWGK
,
std
::
conditional_t
<
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NDHWGK
,
ELay
>>
;
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
E
Lay
>();
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
Lay
out
>();
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
...
...
@@ -385,6 +416,53 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// Use appropriate gridwise gemm
using
GridwiseGemm
=
GridwiseGemm_xdl_cshuffle_v3
<
GridwiseGemmV3TemplateParams
>
;
using
Block2TileMapElementwise
=
BlockToCTileMap_M00_N0_M01Adapt
<
NPerBlock
,
NPerBlock
>
;
using
NGCHWTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
using
NHWGCTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
static
constexpr
index_t
ElementwiseBlocksize
=
ClusterLengthNPerBlock
*
ClusterLengthNPerBlock
;
using
GridwiseElementwiseInputTranspose
=
GridwiseElementwise
<
Tuple
<
NGCHWTransposeDescType
>
,
Tuple
<
NHWGCTransposeDescType
>
,
Tuple
<
const
ADataType
*>
,
Tuple
<
ADataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
,
ElementwiseBlocksize
,
NPerBlock
,
NPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
Sequence
<
1
,
0
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
I1
,
I0
>
;
using
GridwiseElementwiseOutputTranspose
=
GridwiseElementwise
<
Tuple
<
NHWGCTransposeDescType
>
,
Tuple
<
NGCHWTransposeDescType
>
,
Tuple
<
const
EDataType
*>
,
Tuple
<
EDataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
,
ElementwiseBlocksize
,
NPerBlock
,
NPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
Sequence
<
1
,
0
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
I0
,
I1
>
;
static
auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
EGridDesc_M_N
&
e_grid_desc_m_n
)
{
...
...
@@ -428,17 +506,29 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
:
p_a_grid_
{},
p_b_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
)},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
num_group_
{
a_g_n_c_wis_lengths_
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths_
,
a_g_n_c_wis_strides_
,
b_g_k_c_xs_lengths_
,
b_g_k_c_xs_strides_
,
e_g_n_k_wos_lengths_
,
e_g_n_k_wos_strides_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
},
conv_N_per_block_
{
conv_to_gemm_transformer_
.
N_
},
a_grid_desc_ak0_m_ak1_
{
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
conv_to_gemm_transformer_
)},
...
...
@@ -451,32 +541,70 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
compute_ptr_offset_of_n_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
a_g_n_c_wis_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
e_g_n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
cde_element_op_
{
cde_element_op
}
{
// A/B/E Batch/N Stride
compute_ptr_offset_of_groups_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_groups_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
compute_ptr_offset_of_groups_
.
BatchStrideA_
=
a_g_n_c_wis_strides
_
[
0
];
compute_ptr_offset_of_groups_
.
BatchStrideB_
=
b_g_k_c_xs_strides
_
[
0
];
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides
_
[
1
]
*
conv_N_per_block_
;
// p_as and p_bs are pointers
p_a_grid_
=
static_cast
<
const
ADataType
*>
(
p_as
);
p_b_grid_
=
static_cast
<
const
BDataType
*>
(
p_bs
);
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
];
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
1
]
*
conv_N_per_block_
;
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
e_g_n_k_wos_strides
_
[
0
];
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides
_
[
1
]
*
conv_N_per_block_
;
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
// Use not modified base strides
a_in_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
);
a_out_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
);
e_in_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
e_out_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
elementwise_block_2_ctile_map_transpose_a_
=
Block2TileMapElementwise
{
a_in_transpose_desc_
.
GetLength
(
I0
),
a_in_transpose_desc_
.
GetLength
(
I1
)};
elementwise_block_2_ctile_map_transpose_e_
=
Block2TileMapElementwise
{
e_in_transpose_desc_
.
GetLength
(
I0
),
e_in_transpose_desc_
.
GetLength
(
I1
)};
}
}
std
::
size_t
GetWorkspaceATensorSizeBytes
()
const
{
return
sizeof
(
ADataType
)
*
a_in_transpose_desc_
.
GetElementSpaceSize
();
}
std
::
size_t
GetWorkspaceETensorSizeBytes
()
const
{
return
sizeof
(
EDataType
)
*
e_out_transpose_desc_
.
GetElementSpaceSize
();
}
std
::
size_t
GetWorkspaceSizeBytes
()
const
{
// Transpose require workspace for A and B
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
return
GetWorkspaceATensorSizeBytes
()
+
GetWorkspaceETensorSizeBytes
();
}
else
{
return
0
;
}
}
void
Print
()
const
...
...
@@ -492,6 +620,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const
BDataType
*
p_b_grid_
;
EDataType
*
p_e_grid_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
// tensor descriptors for problem definiton
index_t
num_group_
;
...
...
@@ -514,17 +654,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
// block-to-e-tile map
Block2TileMapElementwise
elementwise_block_2_ctile_map_transpose_a_
,
elementwise_block_2_ctile_map_transpose_e_
;
NGCHWTransposeDescType
a_in_transpose_desc_
,
e_out_transpose_desc_
;
NHWGCTransposeDescType
a_out_transpose_desc_
,
e_in_transpose_desc_
;
};
// Invoker
...
...
@@ -532,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
Gemm
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
...
...
@@ -561,8 +696,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
index_t
K_split
=
(
GemmK
+
KPerBlock
-
1
)
/
KPerBlock
*
KPerBlock
;
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
const
ADataType
*
p_a_grid
=
arg
.
p_a_grid_
;
EDataType
*
p_e_grid
=
arg
.
p_e_grid_
;
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
p_a_grid
=
type_convert
<
const
ADataType
*>
(
arg
.
p_workspace_
);
p_e_grid
=
type_convert
<
EDataType
*>
(
arg
.
p_workspace_
)
+
arg
.
GetWorkspaceATensorSizeBytes
()
/
sizeof
(
EDataType
);
}
typename
GridwiseGemm
::
Argument
gemm_arg
{
arg
.
p_a_grid
_
,
arg
.
p_b_grid_
,
arg
.
p_e_grid
_
,
GemmM
,
GemmN
,
GemmK
,
I0
,
I0
,
I0
,
I1
};
p_a_grid
,
arg
.
p_b_grid_
,
p_e_grid
,
GemmM
,
GemmN
,
GemmK
,
I0
,
I0
,
I0
,
I1
};
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
if
(
stream_config
.
flush_cache
)
...
...
@@ -857,6 +1003,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return
ave_time
;
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
avg_time
=
0.
f
;
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
const
index_t
grid_size
=
arg
.
elementwise_block_2_ctile_map_transpose_a_
.
CalculateGridSize
(
arg
.
a_in_transpose_desc_
);
ADataType
*
p_a_out_grid
=
type_convert
<
ADataType
*>
(
arg
.
p_workspace_
);
auto
kernel_transpose
=
kernel_elementwise
<
GridwiseElementwiseInputTranspose
,
ck
::
Tuple
<
NGCHWTransposeDescType
>
,
ck
::
Tuple
<
NHWGCTransposeDescType
>
,
ck
::
Tuple
<
const
ADataType
*>
,
ck
::
Tuple
<
ADataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_transpose
,
dim3
(
grid_size
),
dim3
(
ElementwiseBlocksize
),
0
,
make_tuple
(
arg
.
a_in_transpose_desc_
),
make_tuple
(
arg
.
a_out_transpose_desc_
),
make_tuple
(
arg
.
p_a_grid_
),
make_tuple
(
p_a_out_grid
),
arg
.
elementwise_block_2_ctile_map_transpose_a_
,
element_wise
::
PassThrough
{});
}
avg_time
+=
RunGemm
(
arg
,
stream_config
);
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
const
index_t
grid_size
=
arg
.
elementwise_block_2_ctile_map_transpose_e_
.
CalculateGridSize
(
arg
.
e_in_transpose_desc_
);
const
EDataType
*
p_e_out_grid
=
type_convert
<
EDataType
*>
(
arg
.
p_workspace_
)
+
arg
.
GetWorkspaceATensorSizeBytes
()
/
sizeof
(
EDataType
);
EDataType
*
p_e_in_grid
=
arg
.
p_e_grid_
;
auto
kernel_transpose
=
kernel_elementwise
<
GridwiseElementwiseOutputTranspose
,
ck
::
Tuple
<
NHWGCTransposeDescType
>
,
ck
::
Tuple
<
NGCHWTransposeDescType
>
,
ck
::
Tuple
<
const
EDataType
*>
,
ck
::
Tuple
<
EDataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_transpose
,
dim3
(
grid_size
),
dim3
(
ElementwiseBlocksize
),
0
,
make_tuple
(
arg
.
e_in_transpose_desc_
),
make_tuple
(
arg
.
e_out_transpose_desc_
),
make_tuple
(
p_e_out_grid
),
make_tuple
(
p_e_in_grid
),
arg
.
elementwise_block_2_ctile_map_transpose_e_
,
element_wise
::
PassThrough
{});
}
return
avg_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
...
...
@@ -868,6 +1087,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
namespace
ctc
=
tensor_layout
::
convolution
;
const
index_t
G
=
arg
.
b_g_k_c_xs_lengths_
[
I0
];
const
index_t
K
=
arg
.
b_g_k_c_xs_lengths_
[
I1
];
const
index_t
C
=
arg
.
b_g_k_c_xs_lengths_
[
I2
];
// check device
if
(
get_device_name
()
==
"gfx908"
)
{
...
...
@@ -924,10 +1147,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
is_same_v
<
ALayout
,
ctc
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
GNWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNHWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNDHWC
>
||
is_same_v
<
ALayout
,
ctc
::
NWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
)
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NGCW
>
||
is_same_v
<
ALayout
,
ctc
::
NGCHW
>
||
is_same_v
<
ALayout
,
ctc
::
NGCDHW
>
)
{
const
index_t
C
=
arg
.
a_g_n_c_wis_lengths_
[
2
];
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
...
...
@@ -947,8 +1169,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
is_same_v
<
BLayout
,
ctc
::
KZYXGC
>
)
{
const
index_t
C
=
arg
.
b_g_k_c_xs_lengths_
[
2
];
if
(
!
(
BBlockTransferSrcVectorDim
==
2
&&
C
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
...
...
@@ -959,15 +1179,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return
false
;
}
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
if
((
G
*
C
)
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
if
((
G
*
K
)
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
const
index_t
input_spatial_acum
=
ck
::
accumulate_n
<
index_t
>
(
arg
.
a_g_n_c_wis_lengths_
.
begin
()
+
I3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
output_spatial_acum
=
ck
::
accumulate_n
<
index_t
>
(
arg
.
e_g_n_k_wos_lengths_
.
begin
()
+
I3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
input_spatial_acum
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
if
(
output_spatial_acum
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
// check vector access of E
if
constexpr
(
is_same_v
<
ELayout
,
ctc
::
G_NW_K
>
||
is_same_v
<
ELayout
,
ctc
::
G_NHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
GNWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNHWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNDHWK
>
||
is_same_v
<
ELayout
,
ctc
::
NWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NDHWGK
>
)
is_same_v
<
ELayout
,
ctc
::
NDHWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NGKW
>
||
is_same_v
<
ELayout
,
ctc
::
NGKHW
>
||
is_same_v
<
ELayout
,
ctc
::
NGKDHW
>
)
{
const
index_t
K
=
arg
.
e_g_n_k_wos_lengths_
[
2
];
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
...
...
@@ -1279,6 +1527,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return
str
.
str
();
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
auto
arg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
arg
)
{
return
arg
->
GetWorkspaceSizeBytes
();
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"
);
}
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
auto
p_arg_
=
dynamic_cast
<
Argument
*>
(
p_arg
);
if
(
p_arg_
)
{
p_arg_
->
p_workspace_
=
p_workspace
;
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"
);
}
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @
09d4c3a4
...
...
@@ -26,6 +26,15 @@ constexpr bool is_GNWC_GKXC_GNWK()
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
GNWK
>
;
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NGCW_GKXC_NGKW
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NGCW
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
NGKW
>
;
}
// 2d
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NHWGC_GKYXC_NHWGK
()
...
...
@@ -91,6 +100,14 @@ constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
is_GNDHWC_GKZYXC_GNDHWK
<
InLayout
,
WeiLayout
,
OutLayout
>
();
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NGCSpatial_GKSpatial_NGKSpatial
()
{
return
is_NGCW_GKXC_NGKW
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NGCHW_GKYXC_NGKHW
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
InLayout
,
WeiLayout
,
OutLayout
>
();
}
template
<
index_t
NumATensor
=
1
,
index_t
NumBTensor
=
1
,
index_t
NumDTensor
=
0
,
typename
=
void
>
struct
ComputePtrOffsetOfStridedBatch
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18-2023
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
24
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -16,95 +27,359 @@ template <typename InDataType,
ck
::
ReduceTensorOp
ReduceOpId
,
bool
OutputIndex
,
ck
::
index_t
BlockSize
,
ck
::
index_t
Reduce
MThreadClusterSize
,
ck
::
index_t
Reduce
KThreadClusterSize
,
ck
::
index_t
Reduce
MThreadSliceSize
,
ck
::
index_t
Reduce
KThreadSliceSize
,
ck
::
index_t
MThreadClusterSize
,
ck
::
index_t
KThreadClusterSize
,
ck
::
index_t
MThreadSliceSize
,
ck
::
index_t
KThreadSliceSize
,
ck
::
index_t
InSrcOutDstVectorSize
>
struct
DevicePool2dFwd_NHWC_NHWC
:
public
DevicePool3dFwd_NDHWC_NDHWC
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
BlockSize
,
ReduceMThreadClusterSize
,
ReduceKThreadClusterSize
,
ReduceMThreadSliceSize
,
ReduceKThreadSliceSize
,
InSrcOutDstVectorSize
>
struct
DevicePool2dFwd_NHWC_NHWC
:
public
DevicePoolFwd
<
4
,
2
,
InDataType
,
OutDataType
,
IndexDataType
,
tensor_layout
::
convolution
::
NHWC
,
tensor_layout
::
convolution
::
NHWC
,
ReduceOpId
,
OutputIndex
>
{
using
DevicePool3D
=
DevicePool3dFwd_NDHWC_NDHWC
<
InDataType
,
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
InOutRank
=
4
;
static
constexpr
index_t
WindowRank
=
2
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
static
constexpr
ck
::
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
ck
::
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
MakeABGridDescriptor_A_M_K_B_M
(
std
::
vector
<
ck
::
index_t
>
input_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>
output_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>
input_nchw_stride
,
std
::
vector
<
ck
::
index_t
>
output_nchw_stride
,
std
::
vector
<
ck
::
index_t
>
window_spatial_yx_lengths
,
std
::
vector
<
ck
::
index_t
>
window_yx_strides
,
std
::
vector
<
ck
::
index_t
>
window_yx_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_hw_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_hw_pads
)
{
const
index_t
N
=
input_nchw_lengths
[
0
];
const
index_t
C
=
input_nchw_lengths
[
1
];
const
index_t
Hi
=
input_nchw_lengths
[
2
];
const
index_t
Wi
=
input_nchw_lengths
[
3
];
const
index_t
Ho
=
output_nchw_lengths
[
2
];
const
index_t
Wo
=
output_nchw_lengths
[
3
];
const
index_t
Y
=
window_spatial_yx_lengths
[
0
];
const
index_t
X
=
window_spatial_yx_lengths
[
1
];
const
index_t
WindowStrideH
=
window_yx_strides
[
0
];
const
index_t
WindowStrideW
=
window_yx_strides
[
1
];
const
index_t
WindowDilationH
=
window_yx_dilations
[
0
];
const
index_t
WindowDilationW
=
window_yx_dilations
[
1
];
const
index_t
InLeftPadH
=
input_left_hw_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_hw_pads
[
1
];
const
index_t
InRightPadH
=
input_right_hw_pads
[
0
];
const
index_t
InRightPadW
=
input_right_hw_pads
[
1
];
const
index_t
MRaw
=
N
*
Ho
*
Wo
*
C
;
const
index_t
MPad
=
math
::
integer_least_multiple
(
MRaw
,
M_BlockTileSize
)
-
MRaw
;
const
index_t
KRaw
=
Y
*
X
;
const
index_t
KPad
=
math
::
integer_least_multiple
(
KRaw
,
K_BlockTileSize
)
-
KRaw
;
// A[ReduceM, ReduceK]
const
index_t
Ni_stride
=
input_nchw_stride
[
0
];
const
index_t
Ci_stride
=
input_nchw_stride
[
1
];
const
index_t
Hi_stride
=
input_nchw_stride
[
2
];
const
index_t
Wi_stride
=
input_nchw_stride
[
3
];
const
auto
in_grid_desc_n_hi_wi_c
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
Ni_stride
,
Hi_stride
,
Wi_stride
,
Ci_stride
));
const
auto
in_grid_desc_n_hip_wip_c
=
transform_tensor_descriptor
(
in_grid_desc_n_hi_wi_c
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_grid_desc_n_y_ho_x_wo_c
=
transform_tensor_descriptor
(
in_grid_desc_n_hip_wip_c
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
WindowDilationH
,
WindowStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
WindowDilationW
,
WindowStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_grid_desc_reducemraw_reducekraw
=
transform_tensor_descriptor
(
in_grid_desc_n_y_ho_x_wo_c
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
,
C
)),
make_merge_transform
(
make_tuple
(
Y
,
X
))),
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
>
{},
Sequence
<
1
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_grid_desc_reducem_reducek
=
transform_tensor_descriptor
(
in_grid_desc_reducemraw_reducekraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B[ReduceM]
const
index_t
No_stride
=
output_nchw_stride
[
0
];
const
index_t
Co_stride
=
output_nchw_stride
[
1
];
const
index_t
Ho_stride
=
output_nchw_stride
[
2
];
const
index_t
Wo_stride
=
output_nchw_stride
[
3
];
const
auto
out_grid_desc_n_ho_wo_c
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
No_stride
,
Ho_stride
,
Wo_stride
,
Co_stride
));
const
auto
out_grid_desc_reducemraw
=
transform_tensor_descriptor
(
out_grid_desc_n_ho_wo_c
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
,
C
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
out_grid_desc_reducem
=
transform_tensor_descriptor
(
out_grid_desc_reducemraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
}
using
ABGridDescs
=
decltype
(
MakeABGridDescriptor_A_M_K_B_M
({},
{},
{},
{},
{},
{},
{},
{},
{}));
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InDataType
*
p_in_dev
,
OutDataType
*
p_out_dev
,
IndexDataType
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>&
input_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>&
output_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>&
input_nchw_stride
,
std
::
vector
<
ck
::
index_t
>&
output_nchw_stride
,
std
::
vector
<
ck
::
index_t
>&
,
// indices_nchw_stride
std
::
vector
<
ck
::
index_t
>&
window_spatial_yx_lengths
,
std
::
vector
<
ck
::
index_t
>&
window_yx_strides
,
std
::
vector
<
ck
::
index_t
>&
window_yx_dilations
,
std
::
vector
<
ck
::
index_t
>&
input_left_hw_pads
,
std
::
vector
<
ck
::
index_t
>&
input_right_hw_pads
)
:
p_in_dev_
{
p_in_dev
},
p_out_dev_
{
p_out_dev
},
p_out_indices_dev_
{
p_out_indices_dev
},
a_grid_desc_m_k_
{},
b_grid_desc_m_
{},
input_nchw_lengths_
{
input_nchw_lengths
},
output_nchw_lengths_
{
output_nchw_lengths
},
input_nchw_stride_
{
input_nchw_stride
},
output_nchw_stride_
{
output_nchw_stride
}
{
const
auto
descs
=
MakeABGridDescriptor_A_M_K_B_M
(
input_nchw_lengths
,
output_nchw_lengths
,
input_nchw_stride
,
output_nchw_stride
,
window_spatial_yx_lengths
,
window_yx_strides
,
window_yx_dilations
,
input_left_hw_pads
,
input_right_hw_pads
);
a_grid_desc_m_k_
=
descs
[
I0
];
b_grid_desc_m_
=
descs
[
I1
];
int32_t
reduceLength
=
window_spatial_yx_lengths
[
0
]
*
window_spatial_yx_lengths
[
1
];
std
::
tie
(
in_element_op_
,
acc_element_op_
)
=
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
GetElementwiseOperator
(
reduceLength
);
}
const
InDataType
*
p_in_dev_
;
OutDataType
*
p_out_dev_
;
IndexDataType
*
p_out_indices_dev_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_M
b_grid_desc_m_
;
InElementwiseOperation
in_element_op_
;
AccElementwiseOperation
acc_element_op_
;
// for checking vector load/store
std
::
vector
<
ck
::
index_t
>
input_nchw_lengths_
;
std
::
vector
<
ck
::
index_t
>
output_nchw_lengths_
;
std
::
vector
<
ck
::
index_t
>
input_nchw_stride_
;
std
::
vector
<
ck
::
index_t
>
output_nchw_stride_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
// for NHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in M dimension for reduction kernel.
static
constexpr
index_t
InSrcOutDstVectorDim
=
0
;
// 0: M, 1: K
using
gridwise_reduce
=
GridwiseReduction_mk_to_m_threadwise
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
IndexDataType
,
AGridDesc_M_K
,
BGridDesc_M
,
ReduceOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
false
,
// propagate_nan
BlockSize
,
Reduce
MThread
Cluster
Size
,
Reduce
KThread
Cluster
Size
,
ReduceMThreadSliceSize
,
ReduceKThreadSlice
Size
,
MThread
Slice
Size
,
KThread
Slice
Size
,
InSrcOutDstVectorDim
,
InSrcOutDstVector
Size
,
InSrcOutDstVectorSize
>
;
std
::
unique_ptr
<
BaseArgument
>
const
auto
kernel
=
kernel_reduce_threadwise
<
gridwise_reduce
,
OutputIndex
,
true
,
// pooling need to return global index
false
,
// don't have index input
InDataType
,
OutDataType
,
ComputeDataType
,
IndexDataType
,
AGridDesc_M_K
,
BGridDesc_M
,
InElementwiseOperation
,
AccElementwiseOperation
>
;
ck
::
index_t
M
=
arg
.
a_grid_desc_m_k_
.
GetLength
(
I0
);
const
index_t
grid_size
=
(
M
/
M_BlockTileSize
);
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_m_
,
arg
.
in_element_op_
,
arg
.
acc_element_op_
,
float
(
1
),
arg
.
p_in_dev_
,
nullptr
,
float
(
0
),
arg
.
p_out_dev_
,
arg
.
p_out_indices_dev_
);
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
// C should be fastest dimension
if
(
pArg
->
input_nchw_stride_
[
1
]
!=
1
)
return
false
;
for
(
int
i
=
0
;
i
<
InOutRank
;
++
i
)
{
if
(
pArg
->
input_nchw_stride_
[
i
]
==
1
&&
pArg
->
input_nchw_lengths_
[
i
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
if
(
pArg
->
output_nchw_stride_
[
i
]
==
1
&&
pArg
->
output_nchw_lengths_
[
i
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
}
return
true
;
}
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_dev
,
void
*
p_out_dev
,
void
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
input_stride
,
std
::
vector
<
ck
::
index_t
>
output_stride
,
std
::
vector
<
ck
::
index_t
>
indices_stride
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
input_
nchw_
lengths
,
std
::
vector
<
ck
::
index_t
>
window_
yx_
lengths
,
std
::
vector
<
ck
::
index_t
>
output_
nchw_
lengths
,
std
::
vector
<
ck
::
index_t
>
input_
nchw_
stride
,
std
::
vector
<
ck
::
index_t
>
output_
nchw_
stride
,
std
::
vector
<
ck
::
index_t
>
indices_
nchw_
stride
,
std
::
vector
<
ck
::
index_t
>
window_
yx_
strides
,
std
::
vector
<
ck
::
index_t
>
window_
yx_
dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_
hw_
pads
,
std
::
vector
<
ck
::
index_t
>
input_right_
hw_
pads
,
std
::
vector
<
ck
::
index_t
>
pooling_dims
)
override
{
static
constexpr
index_t
InOutRank
=
4
;
static
constexpr
index_t
WindowRank
=
2
;
if
(
input_lengths
.
size
()
!=
InOutRank
||
window_lengths
.
size
()
!=
WindowRank
||
input_lengths
.
size
()
!=
InOutRank
||
window_strides
.
size
()
!=
WindowRank
||
window_dilations
.
size
()
!=
WindowRank
||
input_left_pads
.
size
()
!=
WindowRank
||
input_right_pads
.
size
()
!=
WindowRank
)
if
(
input_nchw_lengths
.
size
()
!=
InOutRank
||
window_yx_lengths
.
size
()
!=
WindowRank
||
input_nchw_lengths
.
size
()
!=
InOutRank
||
window_yx_strides
.
size
()
!=
WindowRank
||
window_yx_dilations
.
size
()
!=
WindowRank
||
input_left_hw_pads
.
size
()
!=
WindowRank
||
input_right_hw_pads
.
size
()
!=
WindowRank
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
if
(
pooling_dims
!=
std
::
vector
<
ck
::
index_t
>
{
2
,
3
})
throw
std
::
runtime_error
(
"pooling_dims only support {2, 3} in pool2d so far"
);
// NCHW to NCDHW
input_lengths
.
insert
(
input_lengths
.
begin
()
+
2
,
1
);
output_lengths
.
insert
(
output_lengths
.
begin
()
+
2
,
1
);
input_stride
.
insert
(
input_stride
.
begin
()
+
2
,
0
);
output_stride
.
insert
(
output_stride
.
begin
()
+
2
,
0
);
indices_stride
.
insert
(
indices_stride
.
begin
()
+
2
,
0
);
// YX to ZYX
window_lengths
.
insert
(
window_lengths
.
begin
(),
1
);
window_strides
.
insert
(
window_strides
.
begin
(),
0
);
window_dilations
.
insert
(
window_dilations
.
begin
(),
0
);
input_left_pads
.
insert
(
input_left_pads
.
begin
(),
0
);
input_right_pads
.
insert
(
input_right_pads
.
begin
(),
0
);
pooling_dims
=
{
2
,
3
,
4
};
return
DevicePool3D
::
MakeArgumentPointer
(
p_in_dev
,
p_out_dev
,
p_out_indices_dev
,
input_lengths
,
window_lengths
,
output_lengths
,
input_stride
,
output_stride
,
indices_stride
,
window_strides
,
window_dilations
,
input_left_pads
,
input_right_pads
,
pooling_dims
);
if
(
output_nchw_stride
!=
indices_nchw_stride
)
throw
std
::
runtime_error
(
"output_nchw_stride need to be equal to indices_nchw_stride for now"
);
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_dev
),
static_cast
<
OutDataType
*>
(
p_out_dev
),
static_cast
<
IndexDataType
*>
(
p_out_indices_dev
),
input_nchw_lengths
,
output_nchw_lengths
,
input_nchw_stride
,
output_nchw_stride
,
indices_nchw_stride
,
window_yx_lengths
,
window_yx_strides
,
window_yx_dilations
,
input_left_hw_pads
,
input_right_hw_pads
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DevicePool2dFwd_NHWC_NHWC<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"InSrcOutDstVectorSize_"
<<
InSrcOutDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
09d4c3a4
...
...
@@ -355,12 +355,39 @@ struct UnaryDivide
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
x
/
type_convert
<
T
>
(
divider_
);
};
template
<
>
__host__
__device__
void
operator
()
<
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
float
x_
=
type_convert
<
float
>
(
x
);
float
divider_f_
=
type_convert
<
float
>
(
divider_
);
y
=
type_convert
<
half_t
>
(
x_
/
divider_f_
);
};
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
float
x_
=
type_convert
<
float
>
(
x
);
float
divider_f_
=
type_convert
<
float
>
(
divider_
);
y
=
type_convert
<
bhalf_t
>
(
x_
/
divider_f_
);
};
template
<
>
__host__
__device__
void
operator
()
<
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
float
x_
=
type_convert
<
float
>
(
x
);
float
divider_f_
=
type_convert
<
float
>
(
divider_
);
y
=
type_convert
<
f8_t
>
(
x_
/
divider_f_
);
};
int32_t
divider_
=
1
;
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
09d4c3a4
...
...
@@ -221,7 +221,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
__host__
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
KPad
,
index_t
StrideA
,
index_t
AK0
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
...
...
@@ -303,7 +303,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
K
,
index_t
KPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
,
index_t
BK0
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
...
...
@@ -576,12 +576,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
M
;
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideA
;
}
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
N
;
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideB
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
View file @
09d4c3a4
...
...
@@ -255,7 +255,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
__host__
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
KPad
,
index_t
StrideA
,
index_t
AK0
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
...
...
@@ -337,7 +337,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
}
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
K
,
index_t
KPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
,
index_t
BK0
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
...
...
@@ -647,12 +647,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
M
;
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideA
;
}
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
N
;
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideB
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
View file @
09d4c3a4
...
...
@@ -315,7 +315,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_dst_access_idx
[
I0
]
;
index_t
tmp
=
0
;
static_for
<
0
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_dst_access_lengths
[
j
]
+
ordered_dst_access_idx
[
j
];
...
...
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
View file @
09d4c3a4
...
...
@@ -324,55 +324,55 @@ struct DppSelector
static
constexpr
auto
GetDpp
();
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
{
return
DppInstr
::
dpp8_f16_8x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
{
return
DppInstr
::
dpp8_f16_8x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
{
return
DppInstr
::
dpp8_f16_16x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
{
return
DppInstr
::
dpp8_f16_32x8x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
{
return
DppInstr
::
dpp8_f16_1x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
{
return
DppInstr
::
dpp8_f16_2x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
{
return
DppInstr
::
dpp8_f16_2x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
{
return
DppInstr
::
dpp8_f16_4x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
{
return
DppInstr
::
dpp8_f16_4x32x2
;
}
...
...
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
View file @
09d4c3a4
...
...
@@ -35,10 +35,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_16x16x32f16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_16x16x32f16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -57,10 +63,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
static
constexpr
index_t
k_per_blk
=
16
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_32x32x16f16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_32x32x16f16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -79,10 +91,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_16x16x32bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_16x16x32bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -101,10 +119,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
static
constexpr
index_t
k_per_blk
=
16
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_32x32x16bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_32x32x16bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -305,8 +329,8 @@ struct SparseXdlopsGemm
"base base_type must be half or bfloat16!"
);
static_for
<
0
,
KPack
/
smfmac_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
smfmac_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
idx
[
k
],
p_c_thread
);
smfmac_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
,
k
%
4
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
idx
[
k
/
4
],
p_c_thread
);
});
}
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
09d4c3a4
...
...
@@ -415,7 +415,7 @@ struct WmmaSelector
static
constexpr
auto
GetWmma
();
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
;
...
...
@@ -425,7 +425,7 @@ struct WmmaSelector
}
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_bf16_gfx12
;
...
...
@@ -435,19 +435,19 @@ struct WmmaSelector
}
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
half_t
,
16
,
16
>
()
constexpr
auto
GetWmma
<
half_t
,
half_t
,
half_t
,
16
,
16
>
()
{
return
WmmaInstr
::
wmma_f16_16x16x16_f16
;
}
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
bhalf_t
,
16
,
16
>
()
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
bhalf_t
,
16
,
16
>
()
{
return
WmmaInstr
::
wmma_bf16_16x16x16_bf16
;
}
template
<
>
static
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_i32_16x16x16_iu8_gfx12
;
...
...
@@ -458,7 +458,7 @@ struct WmmaSelector
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
static
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
{
return
WmmaInstr
::
wmma_i32_16x16x16_iu4
;
}
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
09d4c3a4
...
...
@@ -651,97 +651,97 @@ struct MfmaSelector
static
constexpr
auto
GetMfma
();
template
<
>
static
constexpr
auto
GetMfma
<
double
,
16
,
16
>
()
constexpr
auto
GetMfma
<
double
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f64_16x16x4f64
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
64
,
64
>
()
constexpr
auto
GetMfma
<
float
,
64
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
32
,
64
>
()
constexpr
auto
GetMfma
<
float
,
32
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
16
,
64
>
()
constexpr
auto
GetMfma
<
float
,
16
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
8
,
64
>
()
constexpr
auto
GetMfma
<
float
,
8
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
4
,
64
>
()
constexpr
auto
GetMfma
<
float
,
4
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
32
,
32
>
()
constexpr
auto
GetMfma
<
float
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x2xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
16
,
16
>
()
constexpr
auto
GetMfma
<
float
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x4xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
64
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
64
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
32
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
32
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
half_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x8f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
half_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x16f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
16
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
16
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
8
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
8
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
4
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
4
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bhalf_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
bhalf_t
,
32
,
32
>
()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return
MfmaInstr
::
mfma_f32_32x32x8bf16_1k
;
...
...
@@ -751,7 +751,7 @@ struct MfmaSelector
}
template
<
>
static
constexpr
auto
GetMfma
<
bhalf_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
bhalf_t
,
16
,
16
>
()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return
MfmaInstr
::
mfma_f32_16x16x16bf16_1k
;
...
...
@@ -762,72 +762,72 @@ struct MfmaSelector
#if defined(CK_USE_AMD_MFMA_GFX940)
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_i32_32x32x16i8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_i32_16x16x32i8
;
}
#else
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_i32_32x32x8i8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_i32_16x16x16i8
;
}
#endif
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16f8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16bf8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16f8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
,
bf8_t
>
()
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
,
bf8_t
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16bf8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
,
f8_t
>
()
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
,
f8_t
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8f8
;
}
...
...
include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp
0 → 100644
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
namespace
ck
{
namespace
tensor_operation
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
index_t
NDimSpatial
,
index_t
MPerThread
,
index_t
NPerThread
>
struct
TransformConvNGCHWToNHWGC
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
MakeNGCHWTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
GStride
=
g_n_c_wis_strides
[
I0
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
&
CStride
=
g_n_c_wis_strides
[
I2
];
const
index_t
&
WiStride
=
g_n_c_wis_strides
[
I3
];
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
MakeNHWGCTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
WiStride
=
G
*
C
;
const
index_t
GStride
=
C
;
const
index_t
CStride
=
1
;
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeNGCHWTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I4
];
const
index_t
&
GStride
=
g_n_c_wis_strides
[
I0
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
&
CStride
=
g_n_c_wis_strides
[
I2
];
const
index_t
&
HiStride
=
g_n_c_wis_strides
[
I3
];
const
index_t
&
WiStride
=
g_n_c_wis_strides
[
I4
];
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeNHWGCTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I4
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
HiStride
=
Wi
*
G
*
C
;
const
index_t
WiStride
=
G
*
C
;
const
index_t
GStride
=
C
;
const
index_t
CStride
=
1
;
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeNGCHWTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Di
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
I4
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I5
];
const
index_t
&
GStride
=
g_n_c_wis_strides
[
I0
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
&
CStride
=
g_n_c_wis_strides
[
I2
];
const
index_t
&
DiStride
=
g_n_c_wis_strides
[
I3
];
const
index_t
&
HiStride
=
g_n_c_wis_strides
[
I4
];
const
index_t
&
WiStride
=
g_n_c_wis_strides
[
I5
];
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Di
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
DiStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Di
,
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeNHWGCTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Di
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
I4
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I5
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
DiStride
=
Hi
*
Wi
*
G
*
C
;
const
index_t
HiStride
=
Wi
*
G
*
C
;
const
index_t
WiStride
=
G
*
C
;
const
index_t
GStride
=
C
;
const
index_t
CStride
=
1
;
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Di
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
DiStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Di
,
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
static
auto
TransposeStrides
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
g_n_c_wis_strides
)
{
if
constexpr
(
device
::
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
device
::
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides_transposed
;
const
auto
G
=
g_n_c_wis_lengths
[
I0
];
const
auto
C
=
g_n_c_wis_lengths
[
I2
];
g_n_c_wis_strides_transposed
[
I0
]
=
C
;
g_n_c_wis_strides_transposed
[
I1
]
=
g_n_c_wis_strides
[
I1
];
g_n_c_wis_strides_transposed
[
I2
]
=
I1
;
if
constexpr
(
NDimSpatial
==
2
)
{
g_n_c_wis_strides_transposed
[
I3
]
=
g_n_c_wis_lengths
[
I4
]
*
G
*
C
;
g_n_c_wis_strides_transposed
[
I4
]
=
G
*
C
;
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
g_n_c_wis_strides_transposed
[
I3
]
=
g_n_c_wis_lengths
[
I4
]
*
g_n_c_wis_lengths
[
I5
]
*
G
*
C
;
g_n_c_wis_strides_transposed
[
I4
]
=
g_n_c_wis_lengths
[
I5
]
*
G
*
C
;
g_n_c_wis_strides_transposed
[
I5
]
=
G
*
C
;
}
return
g_n_c_wis_strides_transposed
;
}
else
{
// transpose not needed
return
g_n_c_wis_strides
;
}
}
};
}
// namespace tensor_operation
}
// namespace ck
include/ck/utility/amd_smfmac.hpp
View file @
09d4c3a4
...
...
@@ -9,16 +9,18 @@ namespace ck {
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_smfmac_f32_16x16x32f16
;
// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse
// indices from reg_idx
template
<
>
struct
intrin_smfmac_f32_16x16x32f16
<
16
,
16
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -34,13 +36,13 @@ struct intrin_smfmac_f32_16x16x32bf16;
template
<
>
struct
intrin_smfmac_f32_16x16x32bf16
<
16
,
16
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -56,13 +58,13 @@ struct intrin_smfmac_f32_32x32x16f16;
template
<
>
struct
intrin_smfmac_f32_32x32x16f16
<
32
,
32
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -78,13 +80,13 @@ struct intrin_smfmac_f32_32x32x16bf16;
template
<
>
struct
intrin_smfmac_f32_32x32x16bf16
<
32
,
32
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
include/ck/utility/reduction_operator.hpp
View file @
09d4c3a4
...
...
@@ -52,12 +52,28 @@ struct Add
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
half
_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8
_t
>::
value
,
"The data type is not supported by the Add accumulator!"
);
a
=
a
+
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
f8_t
>
(
a_
+
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
half_t
>
(
a_
+
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
...
...
@@ -112,12 +128,28 @@ struct Mul
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
half
_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8
_t
>::
value
,
"The data type is not supported by the Mul accumulator!"
);
a
=
a
*
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
f8_t
>
(
a_
*
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
half_t
>
(
a_
*
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
...
...
@@ -137,6 +169,16 @@ struct Max
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
bhalf_t
>
(
val
);
}
if
constexpr
(
is_same_v
<
T
,
f8_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
f8_t
>
(
val
);
}
if
constexpr
(
is_same_v
<
T
,
half_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
half_t
>
(
val
);
}
else
{
return
NumericLimits
<
T
>::
Lowest
();
...
...
@@ -154,8 +196,7 @@ struct Max
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
...
...
@@ -171,12 +212,29 @@ struct Max
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
...
...
@@ -197,6 +255,30 @@ struct Max
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
{
a
=
b
;
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
{
a
=
b
;
changed
=
true
;
}
}
};
struct
Min
...
...
@@ -209,6 +291,16 @@ struct Min
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
bhalf_t
>
(
val
);
}
else
if
constexpr
(
is_same_v
<
T
,
half_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
half_t
>
(
val
);
}
else
if
constexpr
(
is_same_v
<
T
,
f8_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
f8_t
>
(
val
);
}
else
{
return
NumericLimits
<
T
>::
Max
();
...
...
@@ -227,8 +319,7 @@ struct Min
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Min accumulator!"
);
if
(
a
>
b
)
...
...
@@ -244,6 +335,24 @@ struct Min
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
...
...
@@ -270,6 +379,30 @@ struct Min
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
{
a
=
b
;
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
{
a
=
b
;
changed
=
true
;
}
}
};
struct
AMax
...
...
@@ -299,6 +432,15 @@ struct AMax
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
...
...
@@ -313,6 +455,18 @@ struct AMax
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
{
a
=
b
;
changed
=
true
;
}
}
};
template
<
typename
T
>
...
...
@@ -352,7 +506,8 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set,
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
bhalf_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
||
is_same
<
DataType
,
f8_t
>::
value
;
};
template
<
typename
DataType
>
...
...
include/ck_tile/core/container/array.hpp
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <initializer_list>
#include <vector>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
...
...
@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
return
!
(
a
==
b
);
}
template
<
typename
T
,
index_t
N
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
std
::
vector
<
X
>&
x
)
{
array
<
T
,
N
>
arr
;
static_for
<
0
,
N
,
1
>
{}([
&
x
,
&
arr
](
auto
i
)
{
arr
(
i
)
=
x
[
i
];
});
return
arr
;
}
template
<
typename
T
,
index_t
N
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
X
&
x
)
{
...
...
include/ck_tile/host.hpp
View file @
09d4c3a4
...
...
@@ -5,6 +5,8 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp"
...
...
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp
0 → 100644
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace
ck_tile
{
namespace
conv
{
namespace
detail
{
template
<
typename
OldLayout
>
CK_TILE_HOST
std
::
vector
<
std
::
size_t
>
get_layout_transpose_gnchw_to_old
()
{
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKW
>
)
{
return
{
0
,
1
,
2
,
3
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCHW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCYX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKHW
>
)
{
return
{
0
,
1
,
2
,
3
,
4
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCDHW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCZYX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKDHW
>
)
{
return
{
0
,
1
,
2
,
3
,
4
,
5
};
}
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWK
>
)
{
return
{
0
,
1
,
3
,
2
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKYXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWK
>
)
{
return
{
0
,
1
,
4
,
2
,
3
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKZYXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWK
>
)
{
return
{
0
,
1
,
5
,
2
,
3
,
4
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGK
>
)
{
return
{
2
,
0
,
3
,
1
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGK
>
)
{
return
{
3
,
0
,
4
,
1
,
2
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGK
>
)
{
return
{
4
,
0
,
5
,
1
,
2
,
3
};
}
else
{
printf
(
"%s
\n
"
,
__func__
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
}
}
// namespace detail
// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW
// regardless of physical layout
template
<
typename
InLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCW
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCHW
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCDHW
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
InLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
InLayout
>
());
}
// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX
// regardless of physical layout
template
<
typename
WeiLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXC
>
)
{
if
(
param
.
G_
!=
1
)
{
throw
std
::
runtime_error
(
"wrong! G != 1"
);
}
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCX
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCYX
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCZYX
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKYXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKZYXC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXGC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXGC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXGC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
WeiLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
WeiLayout
>
());
}
// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW
// regardless of physical layout
template
<
typename
OutLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKW
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKHW
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKDHW
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
// separate from legacy code above
else
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWK
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGK
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
OutLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
OutLayout
>
());
}
}
// namespace conv
}
// namespace ck_tile
include/ck_tile/host/convolution_parameter.hpp
0 → 100644
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>
namespace
ck_tile
{
namespace
conv
{
struct
ConvParam
{
ConvParam
();
ConvParam
(
ck_tile
::
index_t
n_dim
,
ck_tile
::
index_t
group_count
,
ck_tile
::
index_t
n_batch
,
ck_tile
::
index_t
n_out_channels
,
ck_tile
::
index_t
n_in_channels
,
const
std
::
vector
<
ck_tile
::
index_t
>&
filters_len
,
const
std
::
vector
<
ck_tile
::
index_t
>&
input_len
,
const
std
::
vector
<
ck_tile
::
index_t
>&
strides
,
const
std
::
vector
<
ck_tile
::
index_t
>&
dilations
,
const
std
::
vector
<
ck_tile
::
index_t
>&
left_pads
,
const
std
::
vector
<
ck_tile
::
index_t
>&
right_pads
)
:
num_dim_spatial_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_dim
)),
G_
(
static_cast
<
ck_tile
::
long_index_t
>
(
group_count
)),
N_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_batch
)),
K_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_out_channels
)),
C_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_in_channels
)),
filter_spatial_lengths_
(
num_dim_spatial_
),
input_spatial_lengths_
(
num_dim_spatial_
),
output_spatial_lengths_
(
num_dim_spatial_
),
conv_filter_strides_
(
num_dim_spatial_
),
conv_filter_dilations_
(
num_dim_spatial_
),
input_left_pads_
(
num_dim_spatial_
),
input_right_pads_
(
num_dim_spatial_
)
{
if
(
static_cast
<
ck_tile
::
index_t
>
(
filter_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_strides_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_dilations_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_left_pads_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_right_pads_
.
size
())
!=
num_dim_spatial_
)
{
throw
(
std
::
runtime_error
(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"
));
}
for
(
ck_tile
::
index_t
i
=
0
;
i
<
num_dim_spatial_
;
++
i
)
{
filter_spatial_lengths_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
filters_len
[
i
]);
input_spatial_lengths_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
input_len
[
i
]);
conv_filter_strides_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
strides
[
i
]);
conv_filter_dilations_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
dilations
[
i
]);
input_left_pads_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
left_pads
[
i
]);
input_right_pads_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
right_pads
[
i
]);
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck_tile
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_filter_dilations_
[
i
]
+
1
;
output_spatial_lengths_
[
i
]
=
(
input_spatial_lengths_
[
i
]
+
input_left_pads_
[
i
]
+
input_right_pads_
[
i
]
-
x_eff
)
/
conv_filter_strides_
[
i
]
+
1
;
}
}
ConvParam
(
ck_tile
::
long_index_t
n_dim
,
ck_tile
::
long_index_t
group_count
,
ck_tile
::
long_index_t
n_batch
,
ck_tile
::
long_index_t
n_out_channels
,
ck_tile
::
long_index_t
n_in_channels
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
filters_len
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
input_len
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
strides
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
dilations
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
left_pads
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
right_pads
)
:
num_dim_spatial_
(
n_dim
),
G_
(
group_count
),
N_
(
n_batch
),
K_
(
n_out_channels
),
C_
(
n_in_channels
),
filter_spatial_lengths_
(
filters_len
),
input_spatial_lengths_
(
input_len
),
output_spatial_lengths_
(
num_dim_spatial_
),
conv_filter_strides_
(
strides
),
conv_filter_dilations_
(
dilations
),
input_left_pads_
(
left_pads
),
input_right_pads_
(
right_pads
)
{
if
(
static_cast
<
ck_tile
::
index_t
>
(
filter_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_strides_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_dilations_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_left_pads_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_right_pads_
.
size
())
!=
num_dim_spatial_
)
{
throw
(
std
::
runtime_error
(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"
));
}
for
(
ck_tile
::
index_t
i
=
0
;
i
<
num_dim_spatial_
;
++
i
)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck_tile
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_filter_dilations_
[
i
]
+
1
;
output_spatial_lengths_
[
i
]
=
(
input_spatial_lengths_
[
i
]
+
input_left_pads_
[
i
]
+
input_right_pads_
[
i
]
-
x_eff
)
/
conv_filter_strides_
[
i
]
+
1
;
}
}
ck_tile
::
long_index_t
num_dim_spatial_
;
ck_tile
::
long_index_t
G_
;
ck_tile
::
long_index_t
N_
;
ck_tile
::
long_index_t
K_
;
ck_tile
::
long_index_t
C_
;
std
::
vector
<
ck_tile
::
long_index_t
>
filter_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
output_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_strides_
;
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_dilations_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_left_pads_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_right_pads_
;
std
::
vector
<
ck_tile
::
long_index_t
>
GetOutputSpatialLengths
()
const
{
return
output_spatial_lengths_
;
}
std
::
size_t
GetFlops
()
const
{
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return
static_cast
<
std
::
size_t
>
(
2
)
*
G_
*
N_
*
K_
*
C_
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
next
(
std
::
begin
(
output_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
())
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
std
::
next
(
std
::
begin
(
filter_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
());
}
template
<
typename
InDataType
>
std
::
size_t
GetInputByte
()
const
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return
sizeof
(
InDataType
)
*
(
G_
*
N_
*
C_
*
std
::
accumulate
(
std
::
begin
(
input_spatial_lengths_
),
std
::
next
(
std
::
begin
(
input_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
()));
}
template
<
typename
WeiDataType
>
std
::
size_t
GetWeightByte
()
const
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return
sizeof
(
WeiDataType
)
*
(
G_
*
K_
*
C_
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
std
::
next
(
std
::
begin
(
filter_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
()));
}
template
<
typename
OutDataType
>
std
::
size_t
GetOutputByte
()
const
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return
sizeof
(
OutDataType
)
*
(
G_
*
N_
*
K_
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
end
(
output_spatial_lengths_
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
()));
}
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
>
std
::
size_t
GetByte
()
const
{
return
GetInputByte
<
InDataType
>
()
+
GetWeightByte
<
WeiDataType
>
()
+
GetOutputByte
<
OutDataType
>
();
}
};
ConvParam
::
ConvParam
()
:
ConvParam
::
ConvParam
(
2
,
1
,
128
,
256
,
192
,
{
3
,
3
},
{
71
,
71
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
})
{
}
CK_TILE_HOST
std
::
string
get_conv_param_parser_helper_msg
()
{
std
::
string
msg
;
msg
+=
"Following arguments (depending on number of spatial dims):
\n
"
" Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)
\n
"
" G, N, K, C,
\n
"
" <filter spatial dimensions>, (ie Y, X for 2D)
\n
"
" <input image spatial dimensions>, (ie Hi, Wi for 2D)
\n
"
" <strides>, (ie Sy, Sx for 2D)
\n
"
" <dilations>, (ie Dy, Dx for 2D)
\n
"
" <left padding>, (ie LeftPy, LeftPx for 2D)
\n
"
" <right padding>, (ie RightPy, RightPx for 2D)
\n
"
;
return
msg
;
}
CK_TILE_HOST
ck_tile
::
conv
::
ConvParam
parse_conv_param
(
int
num_dim_spatial
,
int
arg_idx
,
char
*
const
argv
[])
{
const
ck_tile
::
long_index_t
G
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
N
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
K
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
C
=
std
::
stol
(
argv
[
arg_idx
++
]);
std
::
vector
<
ck_tile
::
long_index_t
>
filter_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_strides
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_dilations
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_left_pads
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_right_pads
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
filter_spatial_lengths
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_spatial_lengths
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
conv_filter_strides
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
conv_filter_dilations
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_left_pads
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_right_pads
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
return
ck_tile
::
conv
::
ConvParam
{
num_dim_spatial
,
G
,
N
,
K
,
C
,
filter_spatial_lengths
,
input_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
}
}
// namespace conv
}
// namespace ck_tile
Prev
1
2
3
4
5
6
7
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment