Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
09d4c3a4
Commit
09d4c3a4
authored
Oct 01, 2024
by
illsilin
Browse files
merge from public repo
parents
171ed358
8e4c3fb1
Changes
202
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2184 additions
and
469 deletions
+2184
-469
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+42
-185
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...mpl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+354
-70
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
.../device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
+326
-50
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
...r_operation/gpu/device/impl/device_grouped_conv_utils.hpp
+17
-0
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
...operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
+349
-74
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+28
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+4
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
+4
-4
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
+1
-1
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
+9
-9
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
+38
-14
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+6
-6
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+29
-29
include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp
...tion/operator_transform/transform_conv_ngchw_to_nhwgc.hpp
+236
-0
include/ck/utility/amd_smfmac.hpp
include/ck/utility/amd_smfmac.hpp
+14
-12
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+164
-9
include/ck_tile/core/container/array.hpp
include/ck_tile/core/container/array.hpp
+12
-1
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+2
-0
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp
...k_tile/host/convolution_host_tensor_descriptor_helper.hpp
+266
-0
include/ck_tile/host/convolution_parameter.hpp
include/ck_tile/host/convolution_parameter.hpp
+283
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @
09d4c3a4
...
...
@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...
...
@@ -22,7 +23,6 @@
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -257,6 +257,19 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
KPerBlock
/
K1Number
,
ConvBackwardWeightSpecialization
>
{};
static
constexpr
index_t
ClusterLengthMPerBlock
=
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
1
);
static
constexpr
index_t
ClusterLengthNPerBlock
=
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
3
);
static
constexpr
auto
conv_ngchw_to_nhwgc_transformer
=
TransformConvNGCHWToNHWGC
<
InLayout
,
WeiLayout
,
OutLayout
,
NDimSpatial
,
MPerBlock
/
ClusterLengthMPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
>
{};
static
constexpr
GemmSpecialization
GemmSpec
=
GemmSpecialization
::
Default
;
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
...
...
@@ -359,141 +372,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
batch
)[
I2
];
}
static
constexpr
index_t
ClusterLengthMPerBlock
=
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
1
);
static
constexpr
index_t
ClusterLengthNPerBlock
=
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
3
);
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeInputTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
2
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
3
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
4
];
const
index_t
&
GStride
=
g_n_c_wis_strides
[
0
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
1
];
const
index_t
&
CStride
=
g_n_c_wis_strides
[
2
];
const
index_t
&
HiStride
=
g_n_c_wis_strides
[
3
];
const
index_t
&
WiStride
=
g_n_c_wis_strides
[
4
];
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerBlock
/
ClusterLengthMPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeOutputTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
2
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
3
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
4
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
1
];
const
index_t
HiStride
=
Wi
*
G
*
C
;
const
index_t
WiStride
=
G
*
C
;
const
index_t
GStride
=
C
;
const
index_t
CStride
=
1
;
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerBlock
/
ClusterLengthMPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeInputTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
2
];
const
index_t
&
Di
=
g_n_c_wis_lengths
[
3
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
4
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
5
];
const
index_t
&
GStride
=
g_n_c_wis_strides
[
0
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
1
];
const
index_t
&
CStride
=
g_n_c_wis_strides
[
2
];
const
index_t
&
DiStride
=
g_n_c_wis_strides
[
3
];
const
index_t
&
HiStride
=
g_n_c_wis_strides
[
4
];
const
index_t
&
WiStride
=
g_n_c_wis_strides
[
5
];
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Di
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
DiStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Di
,
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerBlock
/
ClusterLengthMPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeOutputTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
2
];
const
index_t
&
Di
=
g_n_c_wis_lengths
[
3
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
4
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
5
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
1
];
const
index_t
DiStride
=
Hi
*
Wi
*
G
*
C
;
const
index_t
HiStride
=
Wi
*
G
*
C
;
const
index_t
WiStride
=
G
*
C
;
const
index_t
GStride
=
C
;
const
index_t
CStride
=
1
;
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Di
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
DiStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Di
,
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerBlock
/
ClusterLengthMPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
),
Sequence
<
true
,
true
>
{});
}
using
InputTransposeDescType
=
remove_cvref_t
<
decltype
(
MakeInputTransposeDesc
<
NDimSpatial
>
({},
{}))
>
;
using
OutputTransposeDescType
=
remove_cvref_t
<
decltype
(
MakeOutputTransposeDesc
<
NDimSpatial
>
({},
{}))
>
;
using
NGCHWTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
using
NHWGCTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
using
ABCGridDescs
=
decltype
(
GetABCGridDesc
<
NDimSpatial
>
());
...
...
@@ -572,8 +456,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
I1
>
;
using
GridwiseElementwiseTranspose
=
GridwiseElementwise
<
Tuple
<
Input
TransposeDescType
>
,
Tuple
<
Output
TransposeDescType
>
,
GridwiseElementwise
<
Tuple
<
NGCHW
TransposeDescType
>
,
Tuple
<
NHWGC
TransposeDescType
>
,
Tuple
<
const
ADataType
*>
,
Tuple
<
ADataType
*>
,
Block2TileMapElementwise
,
...
...
@@ -652,43 +536,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
begin
(
output_spatial_lengths_
));
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_n_c_wis_strides_transposed
=
b_g_n_c_wis_strides
;
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
b_g_n_c_wis_lengths
,
b_g_n_c_wis_strides
);
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_k_wos_strides_transposed
=
a_g_n_k_wos_strides
;
// NGKHW - transpose needed
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
InLayout
,
WeiLayout
,
OutLayout
>
())
{
b_g_n_c_wis_strides_transposed
[
I0
]
=
Conv_C_
;
b_g_n_c_wis_strides_transposed
[
I2
]
=
I1
;
a_g_n_k_wos_strides_transposed
[
I0
]
=
Conv_K_
;
a_g_n_k_wos_strides_transposed
[
I2
]
=
I1
;
if
constexpr
(
NDimSpatial
==
2
)
{
b_g_n_c_wis_strides_transposed
[
I3
]
=
input_spatial_lengths_
[
I1
]
*
Conv_G_
*
Conv_C_
;
b_g_n_c_wis_strides_transposed
[
I4
]
=
Conv_G_
*
Conv_C_
;
a_g_n_k_wos_strides_transposed
[
I3
]
=
output_spatial_lengths_
[
I1
]
*
Conv_G_
*
Conv_K_
;
a_g_n_k_wos_strides_transposed
[
I4
]
=
Conv_G_
*
Conv_K_
;
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
b_g_n_c_wis_strides_transposed
[
I3
]
=
input_spatial_lengths_
[
I1
]
*
input_spatial_lengths_
[
I2
]
*
Conv_G_
*
Conv_C_
;
b_g_n_c_wis_strides_transposed
[
I4
]
=
input_spatial_lengths_
[
I2
]
*
Conv_G_
*
Conv_C_
;
b_g_n_c_wis_strides_transposed
[
I5
]
=
Conv_G_
*
Conv_C_
;
a_g_n_k_wos_strides_transposed
[
I3
]
=
output_spatial_lengths_
[
I1
]
*
input_spatial_lengths_
[
I2
]
*
Conv_G_
*
Conv_K_
;
a_g_n_k_wos_strides_transposed
[
I4
]
=
input_spatial_lengths_
[
I2
]
*
Conv_G_
*
Conv_K_
;
a_g_n_k_wos_strides_transposed
[
I5
]
=
Conv_G_
*
Conv_K_
;
}
}
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
);
const
auto
descs
=
conv_to_gemm_transformer_v2
...
...
@@ -755,14 +607,18 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
is_NGCDHW_GKZYXC_NGKDHW
<
InLayout
,
WeiLayout
,
OutLayout
>
())
{
a_in_transpose_desc_
=
MakeInputTransposeDesc
<
NDimSpatial
>
(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
);
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
);
a_out_transpose_desc_
=
MakeOutputTransposeDesc
<
NDimSpatial
>
(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
);
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
);
b_in_transpose_desc_
=
MakeInputTransposeDesc
<
NDimSpatial
>
(
b_g_n_c_wis_lengths
,
b_g_n_c_wis_strides
);
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
b_g_n_c_wis_lengths
,
b_g_n_c_wis_strides
);
b_out_transpose_desc_
=
MakeOutputTransposeDesc
<
NDimSpatial
>
(
b_g_n_c_wis_lengths
,
b_g_n_c_wis_strides
);
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
b_g_n_c_wis_lengths
,
b_g_n_c_wis_strides
);
elementwise_block_2_ctile_map_transpose_a_
=
Block2TileMapElementwise
{
a_in_transpose_desc_
.
GetLength
(
I0
),
a_in_transpose_desc_
.
GetLength
(
I1
)};
...
...
@@ -816,8 +672,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
Block2TileMapElementwise
elementwise_block_2_ctile_map_transpose_a_
,
elementwise_block_2_ctile_map_transpose_b_
;
Input
TransposeDescType
a_in_transpose_desc_
,
b_in_transpose_desc_
;
Output
TransposeDescType
a_out_transpose_desc_
,
b_out_transpose_desc_
;
NGCHW
TransposeDescType
a_in_transpose_desc_
,
b_in_transpose_desc_
;
NHWGC
TransposeDescType
a_out_transpose_desc_
,
b_out_transpose_desc_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
I0
>
compute_ptr_offset_of_batch_
;
...
...
@@ -1569,13 +1425,14 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
(
arg
.
GetWorkspaceETensorSizeBytes
()
+
arg
.
GetWorkspaceATensorSizeBytes
())
/
sizeof
(
BDataType
);
// Different data type for A and B is not supported
auto
kernel_transpose
=
kernel_elementwise_dual
<
GridwiseElementwiseTranspose
,
ck
::
Tuple
<
Input
TransposeDescType
>
,
ck
::
Tuple
<
Input
TransposeDescType
>
,
ck
::
Tuple
<
Output
TransposeDescType
>
,
ck
::
Tuple
<
Output
TransposeDescType
>
,
ck
::
Tuple
<
NGCHW
TransposeDescType
>
,
ck
::
Tuple
<
NGCHW
TransposeDescType
>
,
ck
::
Tuple
<
NHWGC
TransposeDescType
>
,
ck
::
Tuple
<
NHWGC
TransposeDescType
>
,
ck
::
Tuple
<
const
ADataType
*>
,
ck
::
Tuple
<
B
DataType
*>
,
ck
::
Tuple
<
A
DataType
*>
,
Block2TileMapElementwise
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
09d4c3a4
...
...
@@ -15,9 +15,11 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
...
...
@@ -307,6 +309,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static
constexpr
bool
isMultiA
=
is_detected
<
is_tuple
,
ADataType
>::
value
;
static
constexpr
bool
isMultiB
=
is_detected
<
is_tuple
,
BDataType
>::
value
;
// NGCHW is not supported for multiAB
static_assert
(
!
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
||
!
(
isMultiA
||
isMultiB
));
static
constexpr
index_t
NumATensor
=
GetNumABTensors
<
isMultiA
,
ADataType
>
();
static
constexpr
index_t
NumBTensor
=
GetNumABTensors
<
isMultiB
,
BDataType
>
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
...
@@ -315,6 +322,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
using
ConvToGemmFwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
,
...
...
@@ -323,14 +332,33 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
EDataType
,
NumGroupsToMerge
>
;
static
constexpr
index_t
ClusterLengthNPerBlock
=
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
3
);
static
constexpr
auto
conv_ngchw_to_nhwgc_transformer
=
TransformConvNGCHWToNHWGC
<
ALayout
,
BLayout
,
ELayout
,
NDimSpatial
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_M_K
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
using
Layout
=
std
::
conditional_t
<
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NHWGC
,
std
::
conditional_t
<
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NDHWGC
,
ALay
>>
;
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
A
Lay
>();
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
Lay
out
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
@@ -353,8 +381,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template
<
typename
ELay
>
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
using
Layout
=
std
::
conditional_t
<
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NHWGK
,
std
::
conditional_t
<
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NDHWGK
,
ELay
>>
;
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
E
Lay
>();
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
Lay
out
>();
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
...
...
@@ -442,6 +478,52 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
using
Block2TileMapElementwise
=
BlockToCTileMap_M00_N0_M01Adapt
<
NPerBlock
,
NPerBlock
>
;
using
NGCHWTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
using
NHWGCTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
static
constexpr
index_t
ElementwiseBlocksize
=
ClusterLengthNPerBlock
*
ClusterLengthNPerBlock
;
using
GridwiseElementwiseInputTranspose
=
GridwiseElementwise
<
Tuple
<
NGCHWTransposeDescType
>
,
Tuple
<
NHWGCTransposeDescType
>
,
Tuple
<
const
ADataType
*>
,
Tuple
<
ADataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
,
ElementwiseBlocksize
,
NPerBlock
,
NPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
Sequence
<
1
,
0
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
I1
,
I0
>
;
using
GridwiseElementwiseOutputTranspose
=
GridwiseElementwise
<
Tuple
<
NHWGCTransposeDescType
>
,
Tuple
<
NGCHWTransposeDescType
>
,
Tuple
<
const
EDataType
*>
,
Tuple
<
EDataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
,
ElementwiseBlocksize
,
NPerBlock
,
NPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
Sequence
<
1
,
0
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
I0
,
I1
>
;
// Argument
struct
Argument
:
public
BaseArgument
...
...
@@ -471,17 +553,31 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_bs_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
)},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
ds_g_n_k_wos_lengths_
{
ds_g_n_k_wos_lengths
},
ds_g_n_k_wos_strides_
{
ds_g_n_k_wos_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
num_group_
{
a_g_n_c_wis_lengths_
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths_
,
a_g_n_c_wis_strides_
,
b_g_k_c_xs_lengths_
,
b_g_k_c_xs_strides_
,
e_g_n_k_wos_lengths_
,
e_g_n_k_wos_strides_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
},
conv_N_per_block_
{
conv_to_gemm_transformer_
.
N_
},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
conv_to_gemm_transformer_
)},
...
...
@@ -501,19 +597,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
compute_ptr_offset_of_n_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
a_g_n_c_wis_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
ds_g_n_k_wos_lengths_
{
ds_g_n_k_wos_lengths
},
ds_g_n_k_wos_strides_
{
ds_g_n_k_wos_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
e_g_n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
cde_element_op_
{
cde_element_op
}
{
// A/B/E Batch Stride
if
constexpr
(
isMultiA
||
isMultiB
)
...
...
@@ -521,7 +605,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_
.
BatchStrideA_
(
i
)
=
a_g_n_c_wis_strides
[
0
]
*
NumGroupsToMerge
;
a_g_n_c_wis_strides
_
[
0
]
*
NumGroupsToMerge
;
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
...
...
@@ -537,20 +621,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// in case of MultiA is false but isMultiB is true
// BatchStrideA_ is not tuple.
compute_ptr_offset_of_n_
.
BatchStrideA_
(
i
)
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
a_g_n_c_wis_strides
_
[
1
]
*
conv_N_per_block_
;
}
else
{
// if MultiB and not MultiA then p_as is single pointer
p_as_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_as
);
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
a_g_n_c_wis_strides
_
[
1
]
*
conv_N_per_block_
;
}
});
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_
.
BatchStrideB_
(
i
)
=
b_g_k_c_xs_strides
[
0
]
*
NumGroupsToMerge
;
b_g_k_c_xs_strides
_
[
0
]
*
NumGroupsToMerge
;
using
DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
GemmBDataType
>>
;
// It is possible that one of the AB is a pointer and one is a tuple.
...
...
@@ -571,10 +655,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
else
{
compute_ptr_offset_of_groups_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
]
*
NumGroupsToMerge
;
a_g_n_c_wis_strides
_
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_groups_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
b_g_k_c_xs_strides_
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides_
[
1
]
*
conv_N_per_block_
;
// p_as and p_bs are pointers
p_as_grid_
(
I0
)
=
static_cast
<
const
ADataType
*>
(
p_as
);
...
...
@@ -591,27 +676,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D batch stride
compute_ptr_offset_of_groups_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
]
*
NumGroupsToMerge
;
ds_g_n_k_wos_strides
_
[
i
][
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
1
]
*
conv_N_per_block_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
[
i
],
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
ds_g_n_k_wos_strides
_
[
i
][
1
]
*
conv_N_per_block_
;
ConvToGemmFwdTransformer
conv_to_gemm_transformer_d
{
a_g_n_c_wis_lengths
_
,
a_g_n_c_wis_strides
_
,
b_g_k_c_xs_lengths
_
,
b_g_k_c_xs_strides
_
,
e_g_n_k_wos_lengths
_
,
ds_g_n_k_wos_strides
_
[
i
],
conv_filter_strides
_
,
conv_filter_dilations
_
,
input_left_pads
_
,
input_right_pads
_
};
// D desc
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
conv_to_gemm_transformer_d
);
});
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
1
]
*
conv_N_per_block_
;
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
e_g_n_k_wos_strides_
[
0
]
*
NumGroupsToMerge
;
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides_
[
1
]
*
conv_N_per_block_
;
// populate desc for Ds/E
if
constexpr
(
isMultiA
||
isMultiB
)
...
...
@@ -653,6 +739,54 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ds_grid_desc_m_n_
);
}
}
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
// Use not modified base strides
a_in_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
);
a_out_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
);
e_in_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
e_out_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
elementwise_block_2_ctile_map_transpose_a_
=
Block2TileMapElementwise
{
a_in_transpose_desc_
.
GetLength
(
I0
),
a_in_transpose_desc_
.
GetLength
(
I1
)};
elementwise_block_2_ctile_map_transpose_e_
=
Block2TileMapElementwise
{
e_in_transpose_desc_
.
GetLength
(
I0
),
e_in_transpose_desc_
.
GetLength
(
I1
)};
}
}
std
::
size_t
GetWorkspaceATensorSizeBytes
()
const
{
return
sizeof
(
ADataType
)
*
a_in_transpose_desc_
.
GetElementSpaceSize
();
}
std
::
size_t
GetWorkspaceETensorSizeBytes
()
const
{
return
sizeof
(
EDataType
)
*
e_out_transpose_desc_
.
GetElementSpaceSize
();
}
std
::
size_t
GetWorkspaceSizeBytes
()
const
{
// Transpose require workspace for A and B
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
return
GetWorkspaceATensorSizeBytes
()
+
GetWorkspaceETensorSizeBytes
();
}
else
{
return
0
;
}
}
void
Print
()
const
...
...
@@ -671,6 +805,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
// tensor descriptors for problem definiton
index_t
num_group_
;
...
...
@@ -692,6 +840,11 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
Block2TileMapElementwise
elementwise_block_2_ctile_map_transpose_a_
,
elementwise_block_2_ctile_map_transpose_e_
;
NGCHWTransposeDescType
a_in_transpose_desc_
,
e_out_transpose_desc_
;
NHWGCTransposeDescType
a_out_transpose_desc_
,
e_in_transpose_desc_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
>
...
...
@@ -702,20 +855,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_
;
std
::
array
<
std
::
array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
};
// Invoker
...
...
@@ -723,7 +862,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
Gemm
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
...
...
@@ -794,6 +933,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
const
ADataType
*
p_a_grid
=
arg
.
p_as_grid_
.
At
(
I0
);
EDataType
*
p_e_grid
=
arg
.
p_e_grid_
;
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
p_a_grid
=
type_convert
<
const
ADataType
*>
(
arg
.
p_workspace_
);
p_e_grid
=
type_convert
<
EDataType
*>
(
arg
.
p_workspace_
)
+
arg
.
GetWorkspaceATensorSizeBytes
()
/
sizeof
(
EDataType
);
}
const
auto
kernel
=
kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle
<
GridwiseGemm
,
const
ADataType
*
,
...
...
@@ -820,10 +970,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
arg
.
p_a
s
_grid
_
.
At
(
I0
),
// Pass just A descriptor instead of tuple
p_a_grid
,
// Pass just A descriptor instead of tuple
arg
.
p_bs_grid_
.
At
(
I0
),
// Pass just B descriptor instead of tuple
arg
.
p_ds_grid_
,
arg
.
p_e_grid
_
,
p_e_grid
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
...
...
@@ -847,6 +997,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
avg_time
=
0.
f
;
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
const
index_t
grid_size
=
arg
.
elementwise_block_2_ctile_map_transpose_a_
.
CalculateGridSize
(
arg
.
a_in_transpose_desc_
);
ADataType
*
p_a_out_grid
=
type_convert
<
ADataType
*>
(
arg
.
p_workspace_
);
auto
kernel_transpose
=
kernel_elementwise
<
GridwiseElementwiseInputTranspose
,
ck
::
Tuple
<
NGCHWTransposeDescType
>
,
ck
::
Tuple
<
NHWGCTransposeDescType
>
,
ck
::
Tuple
<
const
ADataType
*>
,
ck
::
Tuple
<
ADataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_transpose
,
dim3
(
grid_size
),
dim3
(
ElementwiseBlocksize
),
0
,
make_tuple
(
arg
.
a_in_transpose_desc_
),
make_tuple
(
arg
.
a_out_transpose_desc_
),
make_tuple
(
arg
.
p_as_grid_
.
At
(
I0
)),
make_tuple
(
p_a_out_grid
),
arg
.
elementwise_block_2_ctile_map_transpose_a_
,
element_wise
::
PassThrough
{});
}
avg_time
+=
RunGemm
(
arg
,
stream_config
);
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
const
index_t
grid_size
=
arg
.
elementwise_block_2_ctile_map_transpose_e_
.
CalculateGridSize
(
arg
.
e_in_transpose_desc_
);
const
EDataType
*
p_e_out_grid
=
type_convert
<
EDataType
*>
(
arg
.
p_workspace_
)
+
arg
.
GetWorkspaceATensorSizeBytes
()
/
sizeof
(
EDataType
);
EDataType
*
p_e_in_grid
=
arg
.
p_e_grid_
;
auto
kernel_transpose
=
kernel_elementwise
<
GridwiseElementwiseOutputTranspose
,
ck
::
Tuple
<
NHWGCTransposeDescType
>
,
ck
::
Tuple
<
NGCHWTransposeDescType
>
,
ck
::
Tuple
<
const
EDataType
*>
,
ck
::
Tuple
<
EDataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_transpose
,
dim3
(
grid_size
),
dim3
(
ElementwiseBlocksize
),
0
,
make_tuple
(
arg
.
e_in_transpose_desc_
),
make_tuple
(
arg
.
e_out_transpose_desc_
),
make_tuple
(
p_e_out_grid
),
make_tuple
(
p_e_in_grid
),
arg
.
elementwise_block_2_ctile_map_transpose_e_
,
element_wise
::
PassThrough
{});
}
return
avg_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
...
...
@@ -941,7 +1164,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
return
false
;
}
if
constexpr
(
!
is_NSpatialGC_GKSpatial_NSpatialGK
<
ALayout
,
BLayout
,
ELayout
>
())
if
constexpr
(
!
(
is_NSpatialGC_GKSpatial_NSpatialGK
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCSpatial_GKSpatial_NGKSpatial
<
ALayout
,
BLayout
,
ELayout
>
()))
{
return
false
;
}
...
...
@@ -953,14 +1177,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v
<
ALayout
,
ctc
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
GNWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNHWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNDHWC
>
||
is_same_v
<
ALayout
,
ctc
::
NWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
)
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NGCW
>
||
is_same_v
<
ALayout
,
ctc
::
NGCHW
>
||
is_same_v
<
ALayout
,
ctc
::
NGCDHW
>
)
{
// Check access per C
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
{
// If not possible, check access per G
if
(
!
(
ABlockTransferSrcVectorDim
==
1
&&
C
==
1
&&
is_NSpatialGC_GKSpatial_NSpatialGK
<
ALayout
,
BLayout
,
ELayout
>
()
&&
if
(
!
(
ABlockTransferSrcVectorDim
==
1
&&
(
C
==
1
||
NumGroupsToMerge
==
1
)
&&
(
is_NSpatialGC_GKSpatial_NSpatialGK
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCSpatial_GKSpatial_NGKSpatial
<
ALayout
,
BLayout
,
ELayout
>
())
&&
G
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
...
...
@@ -1036,6 +1262,35 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
});
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
if
((
G
*
C
)
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
if
((
G
*
K
)
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
const
index_t
input_spatial_acum
=
ck
::
accumulate_n
<
index_t
>
(
arg
.
a_g_n_c_wis_lengths_
.
begin
()
+
I3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
output_spatial_acum
=
ck
::
accumulate_n
<
index_t
>
(
arg
.
e_g_n_k_wos_lengths_
.
begin
()
+
I3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
input_spatial_acum
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
if
(
output_spatial_acum
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
if
(
!
valid
)
{
return
false
;
...
...
@@ -1046,7 +1301,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v
<
ELayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
GNWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNHWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNDHWK
>
||
is_same_v
<
ELayout
,
ctc
::
NWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NDHWGK
>
)
is_same_v
<
ELayout
,
ctc
::
NDHWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NGKW
>
||
is_same_v
<
ELayout
,
ctc
::
NGKHW
>
||
is_same_v
<
ELayout
,
ctc
::
NGKDHW
>
)
{
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
...
...
@@ -1352,6 +1608,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return
str
.
str
();
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
auto
arg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
arg
)
{
return
arg
->
GetWorkspaceSizeBytes
();
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"
);
}
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
auto
p_arg_
=
dynamic_cast
<
Argument
*>
(
p_arg
);
if
(
p_arg_
)
{
p_arg_
->
p_workspace_
=
p_workspace
;
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"
);
}
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp
View file @
09d4c3a4
...
...
@@ -15,10 +15,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -292,6 +294,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
using
ConvToGemmFwdTransformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
,
...
...
@@ -302,13 +306,32 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
static
constexpr
index_t
ClusterLengthNPerBlock
=
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
::
At
(
3
);
static
constexpr
auto
conv_ngchw_to_nhwgc_transformer
=
TransformConvNGCHWToNHWGC
<
ALayout
,
BLayout
,
ELayout
,
NDimSpatial
,
MPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
>
{};
template
<
typename
ALay
>
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
using
Layout
=
std
::
conditional_t
<
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NHWGC
,
std
::
conditional_t
<
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NDHWGC
,
ALay
>>
;
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
A
Lay
>();
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
Lay
out
>();
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
@@ -351,8 +374,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static
auto
MakeEGridDescriptor_M_N
(
const
ConvToGemmFwdTransformer
&
conv_to_gemm_transformer
)
{
namespace
ctc
=
tensor_layout
::
convolution
;
using
Layout
=
std
::
conditional_t
<
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NHWGK
,
std
::
conditional_t
<
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
(),
ctc
::
NDHWGK
,
ELay
>>
;
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
E
Lay
>();
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
Lay
out
>();
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
...
...
@@ -385,6 +416,53 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// Use appropriate gridwise gemm
using
GridwiseGemm
=
GridwiseGemm_xdl_cshuffle_v3
<
GridwiseGemmV3TemplateParams
>
;
using
Block2TileMapElementwise
=
BlockToCTileMap_M00_N0_M01Adapt
<
NPerBlock
,
NPerBlock
>
;
using
NGCHWTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
using
NHWGCTransposeDescType
=
remove_cvref_t
<
decltype
(
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>({},
{}))
>
;
static
constexpr
index_t
ElementwiseBlocksize
=
ClusterLengthNPerBlock
*
ClusterLengthNPerBlock
;
using
GridwiseElementwiseInputTranspose
=
GridwiseElementwise
<
Tuple
<
NGCHWTransposeDescType
>
,
Tuple
<
NHWGCTransposeDescType
>
,
Tuple
<
const
ADataType
*>
,
Tuple
<
ADataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
,
ElementwiseBlocksize
,
NPerBlock
,
NPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
Sequence
<
1
,
0
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
I1
,
I0
>
;
using
GridwiseElementwiseOutputTranspose
=
GridwiseElementwise
<
Tuple
<
NHWGCTransposeDescType
>
,
Tuple
<
NGCHWTransposeDescType
>
,
Tuple
<
const
EDataType
*>
,
Tuple
<
EDataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
,
ElementwiseBlocksize
,
NPerBlock
,
NPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
NPerBlock
/
ClusterLengthNPerBlock
,
Sequence
<
1
,
0
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
Sequence
<
CDEBlockTransferScalarPerVector_NPerBlock
>
,
I0
,
I1
>
;
static
auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
EGridDesc_M_N
&
e_grid_desc_m_n
)
{
...
...
@@ -428,17 +506,29 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
:
p_a_grid_
{},
p_b_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
)},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
conv_ngchw_to_nhwgc_transformer
.
TransposeStrides
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
num_group_
{
a_g_n_c_wis_lengths_
[
0
]},
conv_to_gemm_transformer_
{
a_g_n_c_wis_lengths_
,
a_g_n_c_wis_strides_
,
b_g_k_c_xs_lengths_
,
b_g_k_c_xs_strides_
,
e_g_n_k_wos_lengths_
,
e_g_n_k_wos_strides_
,
conv_filter_strides_
,
conv_filter_dilations_
,
input_left_pads_
,
input_right_pads_
},
conv_N_per_block_
{
conv_to_gemm_transformer_
.
N_
},
a_grid_desc_ak0_m_ak1_
{
MakeAGridDescriptor_AK0_M_AK1
<
ALayout
>
(
conv_to_gemm_transformer_
)},
...
...
@@ -451,32 +541,70 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
compute_ptr_offset_of_n_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
a_g_n_c_wis_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
e_g_n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
cde_element_op_
{
cde_element_op
}
{
// A/B/E Batch/N Stride
compute_ptr_offset_of_groups_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_groups_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
1
]
*
conv_N_per_block_
;
compute_ptr_offset_of_groups_
.
BatchStrideA_
=
a_g_n_c_wis_strides
_
[
0
];
compute_ptr_offset_of_groups_
.
BatchStrideB_
=
b_g_k_c_xs_strides
_
[
0
];
compute_ptr_offset_of_n_
.
BatchStrideA_
=
a_g_n_c_wis_strides
_
[
1
]
*
conv_N_per_block_
;
// p_as and p_bs are pointers
p_a_grid_
=
static_cast
<
const
ADataType
*>
(
p_as
);
p_b_grid_
=
static_cast
<
const
BDataType
*>
(
p_bs
);
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
];
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
1
]
*
conv_N_per_block_
;
compute_ptr_offset_of_groups_
.
BatchStrideE_
=
e_g_n_k_wos_strides
_
[
0
];
compute_ptr_offset_of_n_
.
BatchStrideE_
=
e_g_n_k_wos_strides
_
[
1
]
*
conv_N_per_block_
;
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
// Use not modified base strides
a_in_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
);
a_out_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
);
e_in_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNHWGCTransposeDesc
<
NDimSpatial
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
e_out_transpose_desc_
=
conv_ngchw_to_nhwgc_transformer
.
template
MakeNGCHWTransposeDesc
<
NDimSpatial
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
elementwise_block_2_ctile_map_transpose_a_
=
Block2TileMapElementwise
{
a_in_transpose_desc_
.
GetLength
(
I0
),
a_in_transpose_desc_
.
GetLength
(
I1
)};
elementwise_block_2_ctile_map_transpose_e_
=
Block2TileMapElementwise
{
e_in_transpose_desc_
.
GetLength
(
I0
),
e_in_transpose_desc_
.
GetLength
(
I1
)};
}
}
std
::
size_t
GetWorkspaceATensorSizeBytes
()
const
{
return
sizeof
(
ADataType
)
*
a_in_transpose_desc_
.
GetElementSpaceSize
();
}
std
::
size_t
GetWorkspaceETensorSizeBytes
()
const
{
return
sizeof
(
EDataType
)
*
e_out_transpose_desc_
.
GetElementSpaceSize
();
}
std
::
size_t
GetWorkspaceSizeBytes
()
const
{
// Transpose require workspace for A and B
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
return
GetWorkspaceATensorSizeBytes
()
+
GetWorkspaceETensorSizeBytes
();
}
else
{
return
0
;
}
}
void
Print
()
const
...
...
@@ -492,6 +620,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const
BDataType
*
p_b_grid_
;
EDataType
*
p_e_grid_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
// tensor descriptors for problem definiton
index_t
num_group_
;
...
...
@@ -514,17 +654,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking IsSupportedArgument()
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
std
::
array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
std
::
array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
std
::
array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
// block-to-e-tile map
Block2TileMapElementwise
elementwise_block_2_ctile_map_transpose_a_
,
elementwise_block_2_ctile_map_transpose_e_
;
NGCHWTransposeDescType
a_in_transpose_desc_
,
e_out_transpose_desc_
;
NHWGCTransposeDescType
a_out_transpose_desc_
,
e_in_transpose_desc_
;
};
// Invoker
...
...
@@ -532,7 +667,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
Gemm
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
...
...
@@ -561,8 +696,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
index_t
K_split
=
(
GemmK
+
KPerBlock
-
1
)
/
KPerBlock
*
KPerBlock
;
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
const
ADataType
*
p_a_grid
=
arg
.
p_a_grid_
;
EDataType
*
p_e_grid
=
arg
.
p_e_grid_
;
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
p_a_grid
=
type_convert
<
const
ADataType
*>
(
arg
.
p_workspace_
);
p_e_grid
=
type_convert
<
EDataType
*>
(
arg
.
p_workspace_
)
+
arg
.
GetWorkspaceATensorSizeBytes
()
/
sizeof
(
EDataType
);
}
typename
GridwiseGemm
::
Argument
gemm_arg
{
arg
.
p_a_grid
_
,
arg
.
p_b_grid_
,
arg
.
p_e_grid
_
,
GemmM
,
GemmN
,
GemmK
,
I0
,
I0
,
I0
,
I1
};
p_a_grid
,
arg
.
p_b_grid_
,
p_e_grid
,
GemmM
,
GemmN
,
GemmK
,
I0
,
I0
,
I0
,
I1
};
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
if
(
stream_config
.
flush_cache
)
...
...
@@ -857,6 +1003,79 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return
ave_time
;
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
float
avg_time
=
0.
f
;
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
const
index_t
grid_size
=
arg
.
elementwise_block_2_ctile_map_transpose_a_
.
CalculateGridSize
(
arg
.
a_in_transpose_desc_
);
ADataType
*
p_a_out_grid
=
type_convert
<
ADataType
*>
(
arg
.
p_workspace_
);
auto
kernel_transpose
=
kernel_elementwise
<
GridwiseElementwiseInputTranspose
,
ck
::
Tuple
<
NGCHWTransposeDescType
>
,
ck
::
Tuple
<
NHWGCTransposeDescType
>
,
ck
::
Tuple
<
const
ADataType
*>
,
ck
::
Tuple
<
ADataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_transpose
,
dim3
(
grid_size
),
dim3
(
ElementwiseBlocksize
),
0
,
make_tuple
(
arg
.
a_in_transpose_desc_
),
make_tuple
(
arg
.
a_out_transpose_desc_
),
make_tuple
(
arg
.
p_a_grid_
),
make_tuple
(
p_a_out_grid
),
arg
.
elementwise_block_2_ctile_map_transpose_a_
,
element_wise
::
PassThrough
{});
}
avg_time
+=
RunGemm
(
arg
,
stream_config
);
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
const
index_t
grid_size
=
arg
.
elementwise_block_2_ctile_map_transpose_e_
.
CalculateGridSize
(
arg
.
e_in_transpose_desc_
);
const
EDataType
*
p_e_out_grid
=
type_convert
<
EDataType
*>
(
arg
.
p_workspace_
)
+
arg
.
GetWorkspaceATensorSizeBytes
()
/
sizeof
(
EDataType
);
EDataType
*
p_e_in_grid
=
arg
.
p_e_grid_
;
auto
kernel_transpose
=
kernel_elementwise
<
GridwiseElementwiseOutputTranspose
,
ck
::
Tuple
<
NHWGCTransposeDescType
>
,
ck
::
Tuple
<
NGCHWTransposeDescType
>
,
ck
::
Tuple
<
const
EDataType
*>
,
ck
::
Tuple
<
EDataType
*>
,
Block2TileMapElementwise
,
element_wise
::
PassThrough
>
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_transpose
,
dim3
(
grid_size
),
dim3
(
ElementwiseBlocksize
),
0
,
make_tuple
(
arg
.
e_in_transpose_desc_
),
make_tuple
(
arg
.
e_out_transpose_desc_
),
make_tuple
(
p_e_out_grid
),
make_tuple
(
p_e_in_grid
),
arg
.
elementwise_block_2_ctile_map_transpose_e_
,
element_wise
::
PassThrough
{});
}
return
avg_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
...
...
@@ -868,6 +1087,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
namespace
ctc
=
tensor_layout
::
convolution
;
const
index_t
G
=
arg
.
b_g_k_c_xs_lengths_
[
I0
];
const
index_t
K
=
arg
.
b_g_k_c_xs_lengths_
[
I1
];
const
index_t
C
=
arg
.
b_g_k_c_xs_lengths_
[
I2
];
// check device
if
(
get_device_name
()
==
"gfx908"
)
{
...
...
@@ -924,10 +1147,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
is_same_v
<
ALayout
,
ctc
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
ctc
::
GNWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNHWC
>
||
is_same_v
<
ALayout
,
ctc
::
GNDHWC
>
||
is_same_v
<
ALayout
,
ctc
::
NWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
)
is_same_v
<
ALayout
,
ctc
::
NDHWGC
>
||
is_same_v
<
ALayout
,
ctc
::
NGCW
>
||
is_same_v
<
ALayout
,
ctc
::
NGCHW
>
||
is_same_v
<
ALayout
,
ctc
::
NGCDHW
>
)
{
const
index_t
C
=
arg
.
a_g_n_c_wis_lengths_
[
2
];
if
(
!
(
ABlockTransferSrcVectorDim
==
2
&&
C
%
ABlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
...
...
@@ -947,8 +1169,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
is_same_v
<
BLayout
,
ctc
::
KZYXGC
>
)
{
const
index_t
C
=
arg
.
b_g_k_c_xs_lengths_
[
2
];
if
(
!
(
BBlockTransferSrcVectorDim
==
2
&&
C
%
BBlockTransferSrcScalarPerVector
==
0
))
{
return
false
;
...
...
@@ -959,15 +1179,43 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return
false
;
}
if
constexpr
(
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
if
((
G
*
C
)
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
if
((
G
*
K
)
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
const
index_t
input_spatial_acum
=
ck
::
accumulate_n
<
index_t
>
(
arg
.
a_g_n_c_wis_lengths_
.
begin
()
+
I3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
const
index_t
output_spatial_acum
=
ck
::
accumulate_n
<
index_t
>
(
arg
.
e_g_n_k_wos_lengths_
.
begin
()
+
I3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
if
(
input_spatial_acum
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
if
(
output_spatial_acum
%
CDEBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
// check vector access of E
if
constexpr
(
is_same_v
<
ELayout
,
ctc
::
G_NW_K
>
||
is_same_v
<
ELayout
,
ctc
::
G_NHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
ELayout
,
ctc
::
GNWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNHWK
>
||
is_same_v
<
ELayout
,
ctc
::
GNDHWK
>
||
is_same_v
<
ELayout
,
ctc
::
NWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NHWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NDHWGK
>
)
is_same_v
<
ELayout
,
ctc
::
NDHWGK
>
||
is_same_v
<
ELayout
,
ctc
::
NGKW
>
||
is_same_v
<
ELayout
,
ctc
::
NGKHW
>
||
is_same_v
<
ELayout
,
ctc
::
NGKDHW
>
)
{
const
index_t
K
=
arg
.
e_g_n_k_wos_lengths_
[
2
];
if
(
!
(
K
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
))
{
return
false
;
...
...
@@ -1279,6 +1527,34 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
return
str
.
str
();
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
auto
arg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
arg
)
{
return
arg
->
GetWorkspaceSizeBytes
();
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"
);
}
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
auto
p_arg_
=
dynamic_cast
<
Argument
*>
(
p_arg
);
if
(
p_arg_
)
{
p_arg_
->
p_workspace_
=
p_workspace
;
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"
);
}
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @
09d4c3a4
...
...
@@ -26,6 +26,15 @@ constexpr bool is_GNWC_GKXC_GNWK()
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
GNWK
>
;
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NGCW_GKXC_NGKW
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NGCW
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
NGKW
>
;
}
// 2d
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NHWGC_GKYXC_NHWGK
()
...
...
@@ -91,6 +100,14 @@ constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
is_GNDHWC_GKZYXC_GNDHWK
<
InLayout
,
WeiLayout
,
OutLayout
>
();
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NGCSpatial_GKSpatial_NGKSpatial
()
{
return
is_NGCW_GKXC_NGKW
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NGCHW_GKYXC_NGKHW
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NGCDHW_GKZYXC_NGKDHW
<
InLayout
,
WeiLayout
,
OutLayout
>
();
}
template
<
index_t
NumATensor
=
1
,
index_t
NumBTensor
=
1
,
index_t
NumDTensor
=
0
,
typename
=
void
>
struct
ComputePtrOffsetOfStridedBatch
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18-2023
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
24
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp"
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -16,95 +27,359 @@ template <typename InDataType,
ck
::
ReduceTensorOp
ReduceOpId
,
bool
OutputIndex
,
ck
::
index_t
BlockSize
,
ck
::
index_t
Reduce
MThreadClusterSize
,
ck
::
index_t
Reduce
KThreadClusterSize
,
ck
::
index_t
Reduce
MThreadSliceSize
,
ck
::
index_t
Reduce
KThreadSliceSize
,
ck
::
index_t
MThreadClusterSize
,
ck
::
index_t
KThreadClusterSize
,
ck
::
index_t
MThreadSliceSize
,
ck
::
index_t
KThreadSliceSize
,
ck
::
index_t
InSrcOutDstVectorSize
>
struct
DevicePool2dFwd_NHWC_NHWC
:
public
DevicePool3dFwd_NDHWC_NDHWC
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
BlockSize
,
ReduceMThreadClusterSize
,
ReduceKThreadClusterSize
,
ReduceMThreadSliceSize
,
ReduceKThreadSliceSize
,
InSrcOutDstVectorSize
>
struct
DevicePool2dFwd_NHWC_NHWC
:
public
DevicePoolFwd
<
4
,
2
,
InDataType
,
OutDataType
,
IndexDataType
,
tensor_layout
::
convolution
::
NHWC
,
tensor_layout
::
convolution
::
NHWC
,
ReduceOpId
,
OutputIndex
>
{
using
DevicePool3D
=
DevicePool3dFwd_NDHWC_NDHWC
<
InDataType
,
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
InOutRank
=
4
;
static
constexpr
index_t
WindowRank
=
2
;
using
ReduceOperation
=
typename
reduce_binary_operator
<
ReduceOpId
>::
opType
;
using
InElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
InElementwiseOperation
;
using
AccElementwiseOperation
=
typename
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
AccElementwiseOperation
;
static
constexpr
ck
::
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
ck
::
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
static
auto
MakeABGridDescriptor_A_M_K_B_M
(
std
::
vector
<
ck
::
index_t
>
input_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>
output_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>
input_nchw_stride
,
std
::
vector
<
ck
::
index_t
>
output_nchw_stride
,
std
::
vector
<
ck
::
index_t
>
window_spatial_yx_lengths
,
std
::
vector
<
ck
::
index_t
>
window_yx_strides
,
std
::
vector
<
ck
::
index_t
>
window_yx_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_hw_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_hw_pads
)
{
const
index_t
N
=
input_nchw_lengths
[
0
];
const
index_t
C
=
input_nchw_lengths
[
1
];
const
index_t
Hi
=
input_nchw_lengths
[
2
];
const
index_t
Wi
=
input_nchw_lengths
[
3
];
const
index_t
Ho
=
output_nchw_lengths
[
2
];
const
index_t
Wo
=
output_nchw_lengths
[
3
];
const
index_t
Y
=
window_spatial_yx_lengths
[
0
];
const
index_t
X
=
window_spatial_yx_lengths
[
1
];
const
index_t
WindowStrideH
=
window_yx_strides
[
0
];
const
index_t
WindowStrideW
=
window_yx_strides
[
1
];
const
index_t
WindowDilationH
=
window_yx_dilations
[
0
];
const
index_t
WindowDilationW
=
window_yx_dilations
[
1
];
const
index_t
InLeftPadH
=
input_left_hw_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_hw_pads
[
1
];
const
index_t
InRightPadH
=
input_right_hw_pads
[
0
];
const
index_t
InRightPadW
=
input_right_hw_pads
[
1
];
const
index_t
MRaw
=
N
*
Ho
*
Wo
*
C
;
const
index_t
MPad
=
math
::
integer_least_multiple
(
MRaw
,
M_BlockTileSize
)
-
MRaw
;
const
index_t
KRaw
=
Y
*
X
;
const
index_t
KPad
=
math
::
integer_least_multiple
(
KRaw
,
K_BlockTileSize
)
-
KRaw
;
// A[ReduceM, ReduceK]
const
index_t
Ni_stride
=
input_nchw_stride
[
0
];
const
index_t
Ci_stride
=
input_nchw_stride
[
1
];
const
index_t
Hi_stride
=
input_nchw_stride
[
2
];
const
index_t
Wi_stride
=
input_nchw_stride
[
3
];
const
auto
in_grid_desc_n_hi_wi_c
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
Ni_stride
,
Hi_stride
,
Wi_stride
,
Ci_stride
));
const
auto
in_grid_desc_n_hip_wip_c
=
transform_tensor_descriptor
(
in_grid_desc_n_hi_wi_c
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_grid_desc_n_y_ho_x_wo_c
=
transform_tensor_descriptor
(
in_grid_desc_n_hip_wip_c
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
WindowDilationH
,
WindowStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
WindowDilationW
,
WindowStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_grid_desc_reducemraw_reducekraw
=
transform_tensor_descriptor
(
in_grid_desc_n_y_ho_x_wo_c
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
,
C
)),
make_merge_transform
(
make_tuple
(
Y
,
X
))),
make_tuple
(
Sequence
<
0
,
2
,
4
,
5
>
{},
Sequence
<
1
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
in_grid_desc_reducem_reducek
=
transform_tensor_descriptor
(
in_grid_desc_reducemraw_reducekraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
),
make_right_pad_transform
(
KRaw
,
KPad
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
// B[ReduceM]
const
index_t
No_stride
=
output_nchw_stride
[
0
];
const
index_t
Co_stride
=
output_nchw_stride
[
1
];
const
index_t
Ho_stride
=
output_nchw_stride
[
2
];
const
index_t
Wo_stride
=
output_nchw_stride
[
3
];
const
auto
out_grid_desc_n_ho_wo_c
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
No_stride
,
Ho_stride
,
Wo_stride
,
Co_stride
));
const
auto
out_grid_desc_reducemraw
=
transform_tensor_descriptor
(
out_grid_desc_n_ho_wo_c
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
,
C
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
out_grid_desc_reducem
=
transform_tensor_descriptor
(
out_grid_desc_reducemraw
,
make_tuple
(
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
make_tuple
(
in_grid_desc_reducem_reducek
,
out_grid_desc_reducem
);
}
using
ABGridDescs
=
decltype
(
MakeABGridDescriptor_A_M_K_B_M
({},
{},
{},
{},
{},
{},
{},
{},
{}));
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I0
])
>
;
using
BGridDesc_M
=
remove_cvref_t
<
decltype
(
ABGridDescs
{}[
I1
])
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InDataType
*
p_in_dev
,
OutDataType
*
p_out_dev
,
IndexDataType
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>&
input_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>&
output_nchw_lengths
,
std
::
vector
<
ck
::
index_t
>&
input_nchw_stride
,
std
::
vector
<
ck
::
index_t
>&
output_nchw_stride
,
std
::
vector
<
ck
::
index_t
>&
,
// indices_nchw_stride
std
::
vector
<
ck
::
index_t
>&
window_spatial_yx_lengths
,
std
::
vector
<
ck
::
index_t
>&
window_yx_strides
,
std
::
vector
<
ck
::
index_t
>&
window_yx_dilations
,
std
::
vector
<
ck
::
index_t
>&
input_left_hw_pads
,
std
::
vector
<
ck
::
index_t
>&
input_right_hw_pads
)
:
p_in_dev_
{
p_in_dev
},
p_out_dev_
{
p_out_dev
},
p_out_indices_dev_
{
p_out_indices_dev
},
a_grid_desc_m_k_
{},
b_grid_desc_m_
{},
input_nchw_lengths_
{
input_nchw_lengths
},
output_nchw_lengths_
{
output_nchw_lengths
},
input_nchw_stride_
{
input_nchw_stride
},
output_nchw_stride_
{
output_nchw_stride
}
{
const
auto
descs
=
MakeABGridDescriptor_A_M_K_B_M
(
input_nchw_lengths
,
output_nchw_lengths
,
input_nchw_stride
,
output_nchw_stride
,
window_spatial_yx_lengths
,
window_yx_strides
,
window_yx_dilations
,
input_left_hw_pads
,
input_right_hw_pads
);
a_grid_desc_m_k_
=
descs
[
I0
];
b_grid_desc_m_
=
descs
[
I1
];
int32_t
reduceLength
=
window_spatial_yx_lengths
[
0
]
*
window_spatial_yx_lengths
[
1
];
std
::
tie
(
in_element_op_
,
acc_element_op_
)
=
reduce_unary_operator
<
ReduceOpId
,
true
,
true
>::
GetElementwiseOperator
(
reduceLength
);
}
const
InDataType
*
p_in_dev_
;
OutDataType
*
p_out_dev_
;
IndexDataType
*
p_out_indices_dev_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_M
b_grid_desc_m_
;
InElementwiseOperation
in_element_op_
;
AccElementwiseOperation
acc_element_op_
;
// for checking vector load/store
std
::
vector
<
ck
::
index_t
>
input_nchw_lengths_
;
std
::
vector
<
ck
::
index_t
>
output_nchw_lengths_
;
std
::
vector
<
ck
::
index_t
>
input_nchw_stride_
;
std
::
vector
<
ck
::
index_t
>
output_nchw_stride_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
// for NHWC, the dim C is the fastest dimension, and is not reduced.
// Hence, it is in M dimension for reduction kernel.
static
constexpr
index_t
InSrcOutDstVectorDim
=
0
;
// 0: M, 1: K
using
gridwise_reduce
=
GridwiseReduction_mk_to_m_threadwise
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
ReduceOpId
,
OutputIndex
,
IndexDataType
,
AGridDesc_M_K
,
BGridDesc_M
,
ReduceOperation
,
InElementwiseOperation
,
AccElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
false
,
// propagate_nan
BlockSize
,
Reduce
MThread
Cluster
Size
,
Reduce
KThread
Cluster
Size
,
ReduceMThreadSliceSize
,
ReduceKThreadSlice
Size
,
MThread
Slice
Size
,
KThread
Slice
Size
,
InSrcOutDstVectorDim
,
InSrcOutDstVector
Size
,
InSrcOutDstVectorSize
>
;
std
::
unique_ptr
<
BaseArgument
>
const
auto
kernel
=
kernel_reduce_threadwise
<
gridwise_reduce
,
OutputIndex
,
true
,
// pooling need to return global index
false
,
// don't have index input
InDataType
,
OutDataType
,
ComputeDataType
,
IndexDataType
,
AGridDesc_M_K
,
BGridDesc_M
,
InElementwiseOperation
,
AccElementwiseOperation
>
;
ck
::
index_t
M
=
arg
.
a_grid_desc_m_k_
.
GetLength
(
I0
);
const
index_t
grid_size
=
(
M
/
M_BlockTileSize
);
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_m_
,
arg
.
in_element_op_
,
arg
.
acc_element_op_
,
float
(
1
),
arg
.
p_in_dev_
,
nullptr
,
float
(
0
),
arg
.
p_out_dev_
,
arg
.
p_out_indices_dev_
);
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
// C should be fastest dimension
if
(
pArg
->
input_nchw_stride_
[
1
]
!=
1
)
return
false
;
for
(
int
i
=
0
;
i
<
InOutRank
;
++
i
)
{
if
(
pArg
->
input_nchw_stride_
[
i
]
==
1
&&
pArg
->
input_nchw_lengths_
[
i
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
if
(
pArg
->
output_nchw_stride_
[
i
]
==
1
&&
pArg
->
output_nchw_lengths_
[
i
]
%
InSrcOutDstVectorSize
!=
0
)
return
false
;
}
return
true
;
}
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_in_dev
,
void
*
p_out_dev
,
void
*
p_out_indices_dev
,
std
::
vector
<
ck
::
index_t
>
input_lengths
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
output_lengths
,
std
::
vector
<
ck
::
index_t
>
input_stride
,
std
::
vector
<
ck
::
index_t
>
output_stride
,
std
::
vector
<
ck
::
index_t
>
indices_stride
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_pads
,
std
::
vector
<
ck
::
index_t
>
input_right_pads
,
std
::
vector
<
ck
::
index_t
>
input_
nchw_
lengths
,
std
::
vector
<
ck
::
index_t
>
window_
yx_
lengths
,
std
::
vector
<
ck
::
index_t
>
output_
nchw_
lengths
,
std
::
vector
<
ck
::
index_t
>
input_
nchw_
stride
,
std
::
vector
<
ck
::
index_t
>
output_
nchw_
stride
,
std
::
vector
<
ck
::
index_t
>
indices_
nchw_
stride
,
std
::
vector
<
ck
::
index_t
>
window_
yx_
strides
,
std
::
vector
<
ck
::
index_t
>
window_
yx_
dilations
,
std
::
vector
<
ck
::
index_t
>
input_left_
hw_
pads
,
std
::
vector
<
ck
::
index_t
>
input_right_
hw_
pads
,
std
::
vector
<
ck
::
index_t
>
pooling_dims
)
override
{
static
constexpr
index_t
InOutRank
=
4
;
static
constexpr
index_t
WindowRank
=
2
;
if
(
input_lengths
.
size
()
!=
InOutRank
||
window_lengths
.
size
()
!=
WindowRank
||
input_lengths
.
size
()
!=
InOutRank
||
window_strides
.
size
()
!=
WindowRank
||
window_dilations
.
size
()
!=
WindowRank
||
input_left_pads
.
size
()
!=
WindowRank
||
input_right_pads
.
size
()
!=
WindowRank
)
if
(
input_nchw_lengths
.
size
()
!=
InOutRank
||
window_yx_lengths
.
size
()
!=
WindowRank
||
input_nchw_lengths
.
size
()
!=
InOutRank
||
window_yx_strides
.
size
()
!=
WindowRank
||
window_yx_dilations
.
size
()
!=
WindowRank
||
input_left_hw_pads
.
size
()
!=
WindowRank
||
input_right_hw_pads
.
size
()
!=
WindowRank
)
throw
std
::
runtime_error
(
"dimension is incorrect"
);
if
(
pooling_dims
!=
std
::
vector
<
ck
::
index_t
>
{
2
,
3
})
throw
std
::
runtime_error
(
"pooling_dims only support {2, 3} in pool2d so far"
);
// NCHW to NCDHW
input_lengths
.
insert
(
input_lengths
.
begin
()
+
2
,
1
);
output_lengths
.
insert
(
output_lengths
.
begin
()
+
2
,
1
);
input_stride
.
insert
(
input_stride
.
begin
()
+
2
,
0
);
output_stride
.
insert
(
output_stride
.
begin
()
+
2
,
0
);
indices_stride
.
insert
(
indices_stride
.
begin
()
+
2
,
0
);
// YX to ZYX
window_lengths
.
insert
(
window_lengths
.
begin
(),
1
);
window_strides
.
insert
(
window_strides
.
begin
(),
0
);
window_dilations
.
insert
(
window_dilations
.
begin
(),
0
);
input_left_pads
.
insert
(
input_left_pads
.
begin
(),
0
);
input_right_pads
.
insert
(
input_right_pads
.
begin
(),
0
);
pooling_dims
=
{
2
,
3
,
4
};
return
DevicePool3D
::
MakeArgumentPointer
(
p_in_dev
,
p_out_dev
,
p_out_indices_dev
,
input_lengths
,
window_lengths
,
output_lengths
,
input_stride
,
output_stride
,
indices_stride
,
window_strides
,
window_dilations
,
input_left_pads
,
input_right_pads
,
pooling_dims
);
if
(
output_nchw_stride
!=
indices_nchw_stride
)
throw
std
::
runtime_error
(
"output_nchw_stride need to be equal to indices_nchw_stride for now"
);
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_dev
),
static_cast
<
OutDataType
*>
(
p_out_dev
),
static_cast
<
IndexDataType
*>
(
p_out_indices_dev
),
input_nchw_lengths
,
output_nchw_lengths
,
input_nchw_stride
,
output_nchw_stride
,
indices_nchw_stride
,
window_yx_lengths
,
window_yx_strides
,
window_yx_dilations
,
input_left_hw_pads
,
input_right_hw_pads
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DevicePool2dFwd_NHWC_NHWC<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"InSrcOutDstVectorSize_"
<<
InSrcOutDstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
09d4c3a4
...
...
@@ -355,12 +355,39 @@ struct UnaryDivide
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
y
=
x
/
type_convert
<
T
>
(
divider_
);
};
template
<
>
__host__
__device__
void
operator
()
<
half_t
>
(
half_t
&
y
,
const
half_t
&
x
)
const
{
float
x_
=
type_convert
<
float
>
(
x
);
float
divider_f_
=
type_convert
<
float
>
(
divider_
);
y
=
type_convert
<
half_t
>
(
x_
/
divider_f_
);
};
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
>
(
bhalf_t
&
y
,
const
bhalf_t
&
x
)
const
{
float
x_
=
type_convert
<
float
>
(
x
);
float
divider_f_
=
type_convert
<
float
>
(
divider_
);
y
=
type_convert
<
bhalf_t
>
(
x_
/
divider_f_
);
};
template
<
>
__host__
__device__
void
operator
()
<
f8_t
>
(
f8_t
&
y
,
const
f8_t
&
x
)
const
{
float
x_
=
type_convert
<
float
>
(
x
);
float
divider_f_
=
type_convert
<
float
>
(
divider_
);
y
=
type_convert
<
f8_t
>
(
x_
/
divider_f_
);
};
int32_t
divider_
=
1
;
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
09d4c3a4
...
...
@@ -221,7 +221,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
__host__
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
KPad
,
index_t
StrideA
,
index_t
AK0
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
...
...
@@ -303,7 +303,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
K
,
index_t
KPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
,
index_t
BK0
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
...
...
@@ -576,12 +576,12 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
M
;
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideA
;
}
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
N
;
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideB
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
View file @
09d4c3a4
...
...
@@ -255,7 +255,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
__host__
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
KPad
,
index_t
StrideA
,
index_t
AK0
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
...
...
@@ -337,7 +337,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
}
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
K
,
index_t
KPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
,
index_t
BK0
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
...
...
@@ -647,12 +647,12 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
M
;
a_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideA
;
}
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
N
;
b_k_split_offset
=
blockIdx
.
z
*
karg
.
KRead
*
karg
.
StrideB
;
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp
View file @
09d4c3a4
...
...
@@ -315,7 +315,7 @@ struct ThreadwiseTensorSliceTransfer_v5r1
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_dst_access_idx
[
I0
]
;
index_t
tmp
=
0
;
static_for
<
0
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_dst_access_lengths
[
j
]
+
ordered_dst_access_idx
[
j
];
...
...
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
View file @
09d4c3a4
...
...
@@ -324,55 +324,55 @@ struct DppSelector
static
constexpr
auto
GetDpp
();
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
{
return
DppInstr
::
dpp8_f16_8x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
{
return
DppInstr
::
dpp8_f16_8x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
{
return
DppInstr
::
dpp8_f16_16x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
{
return
DppInstr
::
dpp8_f16_32x8x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
{
return
DppInstr
::
dpp8_f16_1x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
{
return
DppInstr
::
dpp8_f16_2x32x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
{
return
DppInstr
::
dpp8_f16_2x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
{
return
DppInstr
::
dpp8_f16_4x16x2
;
}
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
{
return
DppInstr
::
dpp8_f16_4x32x2
;
}
...
...
include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp
View file @
09d4c3a4
...
...
@@ -35,10 +35,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32f16>
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_16x16x32f16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_16x16x32f16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -57,10 +63,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16f16>
static
constexpr
index_t
k_per_blk
=
16
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_32x32x16f16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_32x32x16f16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -79,10 +91,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_16x16x32bf16>
static
constexpr
index_t
k_per_blk
=
8
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_16x16x32bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_16x16x32bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -101,10 +119,16 @@ struct smfmac<SmfmacInstr::smfmac_f32_32x32x16bf16>
static
constexpr
index_t
k_per_blk
=
16
;
static
constexpr
bool
is_k_reduction
=
true
;
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
int32_t
&
idx
,
FloatC
&
reg_c
)
const
template
<
index_t
MPerXdlops
,
index_t
NPerXdlops
,
index_t
idx_part
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
const
index_t
&
idx
,
FloatC
&
reg_c
)
const
{
intrin_smfmac_f32_32x32x16bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
(
a
,
b
,
idx
,
reg_c
);
intrin_smfmac_f32_32x32x16bf16
<
MPerXdlops
,
NPerXdlops
>::
Run
<
FloatC
,
idx_part
>
(
a
,
b
,
idx
,
reg_c
);
}
};
...
...
@@ -305,8 +329,8 @@ struct SparseXdlopsGemm
"base base_type must be half or bfloat16!"
);
static_for
<
0
,
KPack
/
smfmac_instr
.
k_per_blk
,
1
>
{}([
&
](
auto
k
)
{
smfmac_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
idx
[
k
],
p_c_thread
);
smfmac_instr
.
template
run
<
MPerXdlops
,
NPerXdlops
,
k
%
4
>(
p_a_wave
[
k
],
p_b_wave
[
k
],
idx
[
k
/
4
],
p_c_thread
);
});
}
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
09d4c3a4
...
...
@@ -415,7 +415,7 @@ struct WmmaSelector
static
constexpr
auto
GetWmma
();
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
;
...
...
@@ -425,7 +425,7 @@ struct WmmaSelector
}
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_bf16_gfx12
;
...
...
@@ -435,19 +435,19 @@ struct WmmaSelector
}
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
half_t
,
16
,
16
>
()
constexpr
auto
GetWmma
<
half_t
,
half_t
,
half_t
,
16
,
16
>
()
{
return
WmmaInstr
::
wmma_f16_16x16x16_f16
;
}
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
bhalf_t
,
16
,
16
>
()
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
bhalf_t
,
16
,
16
>
()
{
return
WmmaInstr
::
wmma_bf16_16x16x16_bf16
;
}
template
<
>
static
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
{
#ifdef __gfx12__
return
WmmaInstr
::
wmma_i32_16x16x16_iu8_gfx12
;
...
...
@@ -458,7 +458,7 @@ struct WmmaSelector
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
static
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
{
return
WmmaInstr
::
wmma_i32_16x16x16_iu4
;
}
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
09d4c3a4
...
...
@@ -651,97 +651,97 @@ struct MfmaSelector
static
constexpr
auto
GetMfma
();
template
<
>
static
constexpr
auto
GetMfma
<
double
,
16
,
16
>
()
constexpr
auto
GetMfma
<
double
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f64_16x16x4f64
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
64
,
64
>
()
constexpr
auto
GetMfma
<
float
,
64
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
32
,
64
>
()
constexpr
auto
GetMfma
<
float
,
32
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
16
,
64
>
()
constexpr
auto
GetMfma
<
float
,
16
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
8
,
64
>
()
constexpr
auto
GetMfma
<
float
,
8
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
4
,
64
>
()
constexpr
auto
GetMfma
<
float
,
4
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
32
,
32
>
()
constexpr
auto
GetMfma
<
float
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x2xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
float
,
16
,
16
>
()
constexpr
auto
GetMfma
<
float
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x4xf32
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
64
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
64
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
32
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
32
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
half_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x8f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
half_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x16f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
16
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
16
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
8
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
8
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
4
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
4
,
64
>
()
{
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bhalf_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
bhalf_t
,
32
,
32
>
()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return
MfmaInstr
::
mfma_f32_32x32x8bf16_1k
;
...
...
@@ -751,7 +751,7 @@ struct MfmaSelector
}
template
<
>
static
constexpr
auto
GetMfma
<
bhalf_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
bhalf_t
,
16
,
16
>
()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return
MfmaInstr
::
mfma_f32_16x16x16bf16_1k
;
...
...
@@ -762,72 +762,72 @@ struct MfmaSelector
#if defined(CK_USE_AMD_MFMA_GFX940)
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_i32_32x32x16i8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_i32_16x16x32i8
;
}
#else
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_i32_32x32x8i8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_i32_16x16x16i8
;
}
#endif
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16f8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16bf8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16f8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
,
bf8_t
>
()
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
,
bf8_t
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
{
return
MfmaInstr
::
mfma_f32_32x32x16bf8f8
;
}
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
,
f8_t
>
()
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
,
f8_t
>
()
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8f8
;
}
...
...
include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp
0 → 100644
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
namespace
ck
{
namespace
tensor_operation
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
index_t
NDimSpatial
,
index_t
MPerThread
,
index_t
NPerThread
>
struct
TransformConvNGCHWToNHWGC
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
MakeNGCHWTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
GStride
=
g_n_c_wis_strides
[
I0
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
&
CStride
=
g_n_c_wis_strides
[
I2
];
const
index_t
&
WiStride
=
g_n_c_wis_strides
[
I3
];
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
auto
MakeNHWGCTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
WiStride
=
G
*
C
;
const
index_t
GStride
=
C
;
const
index_t
CStride
=
1
;
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeNGCHWTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I4
];
const
index_t
&
GStride
=
g_n_c_wis_strides
[
I0
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
&
CStride
=
g_n_c_wis_strides
[
I2
];
const
index_t
&
HiStride
=
g_n_c_wis_strides
[
I3
];
const
index_t
&
WiStride
=
g_n_c_wis_strides
[
I4
];
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
auto
MakeNHWGCTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I4
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
HiStride
=
Wi
*
G
*
C
;
const
index_t
WiStride
=
G
*
C
;
const
index_t
GStride
=
C
;
const
index_t
CStride
=
1
;
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeNGCHWTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Di
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
I4
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I5
];
const
index_t
&
GStride
=
g_n_c_wis_strides
[
I0
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
&
CStride
=
g_n_c_wis_strides
[
I2
];
const
index_t
&
DiStride
=
g_n_c_wis_strides
[
I3
];
const
index_t
&
HiStride
=
g_n_c_wis_strides
[
I4
];
const
index_t
&
WiStride
=
g_n_c_wis_strides
[
I5
];
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Di
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
DiStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Di
,
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
auto
MakeNHWGCTransposeDesc
(
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_lengths
,
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides
)
{
const
index_t
&
G
=
g_n_c_wis_lengths
[
I0
];
const
index_t
&
N
=
g_n_c_wis_lengths
[
I1
];
const
index_t
&
C
=
g_n_c_wis_lengths
[
I2
];
const
index_t
&
Di
=
g_n_c_wis_lengths
[
I3
];
const
index_t
&
Hi
=
g_n_c_wis_lengths
[
I4
];
const
index_t
&
Wi
=
g_n_c_wis_lengths
[
I5
];
const
index_t
&
NStride
=
g_n_c_wis_strides
[
I1
];
const
index_t
DiStride
=
Hi
*
Wi
*
G
*
C
;
const
index_t
HiStride
=
Wi
*
G
*
C
;
const
index_t
WiStride
=
G
*
C
;
const
index_t
GStride
=
C
;
const
index_t
CStride
=
1
;
const
auto
desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
G
,
C
,
Di
,
Hi
,
Wi
),
make_tuple
(
NStride
,
GStride
,
CStride
,
DiStride
,
HiStride
,
WiStride
));
const
auto
merged_desc
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
G
,
C
)),
make_merge_transform
(
make_tuple
(
Di
,
Hi
,
Wi
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
device
::
PadTensorDescriptor
(
merged_desc
,
make_tuple
(
MPerThread
,
NPerThread
),
Sequence
<
true
,
true
>
{});
}
static
auto
TransposeStrides
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
g_n_c_wis_strides
)
{
if
constexpr
(
device
::
is_NGCHW_GKYXC_NGKHW
<
ALayout
,
BLayout
,
ELayout
>
()
||
device
::
is_NGCDHW_GKZYXC_NGKDHW
<
ALayout
,
BLayout
,
ELayout
>
())
{
std
::
array
<
index_t
,
NDimSpatial
+
3
>
g_n_c_wis_strides_transposed
;
const
auto
G
=
g_n_c_wis_lengths
[
I0
];
const
auto
C
=
g_n_c_wis_lengths
[
I2
];
g_n_c_wis_strides_transposed
[
I0
]
=
C
;
g_n_c_wis_strides_transposed
[
I1
]
=
g_n_c_wis_strides
[
I1
];
g_n_c_wis_strides_transposed
[
I2
]
=
I1
;
if
constexpr
(
NDimSpatial
==
2
)
{
g_n_c_wis_strides_transposed
[
I3
]
=
g_n_c_wis_lengths
[
I4
]
*
G
*
C
;
g_n_c_wis_strides_transposed
[
I4
]
=
G
*
C
;
}
else
if
constexpr
(
NDimSpatial
==
3
)
{
g_n_c_wis_strides_transposed
[
I3
]
=
g_n_c_wis_lengths
[
I4
]
*
g_n_c_wis_lengths
[
I5
]
*
G
*
C
;
g_n_c_wis_strides_transposed
[
I4
]
=
g_n_c_wis_lengths
[
I5
]
*
G
*
C
;
g_n_c_wis_strides_transposed
[
I5
]
=
G
*
C
;
}
return
g_n_c_wis_strides_transposed
;
}
else
{
// transpose not needed
return
g_n_c_wis_strides
;
}
}
};
}
// namespace tensor_operation
}
// namespace ck
include/ck/utility/amd_smfmac.hpp
View file @
09d4c3a4
...
...
@@ -9,16 +9,18 @@ namespace ck {
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_smfmac_f32_16x16x32f16
;
// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse
// indices from reg_idx
template
<
>
struct
intrin_smfmac_f32_16x16x32f16
<
16
,
16
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -34,13 +36,13 @@ struct intrin_smfmac_f32_16x16x32bf16;
template
<
>
struct
intrin_smfmac_f32_16x16x32bf16
<
16
,
16
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float4_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_16x16x32_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float4_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -56,13 +58,13 @@ struct intrin_smfmac_f32_32x32x16f16;
template
<
>
struct
intrin_smfmac_f32_32x32x16f16
<
32
,
32
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
half4_t
&
reg_a
,
const
half8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_f16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
@@ -78,13 +80,13 @@ struct intrin_smfmac_f32_32x32x16bf16;
template
<
>
struct
intrin_smfmac_f32_32x32x16bf16
<
32
,
32
>
{
template
<
class
FloatC
>
template
<
class
FloatC
,
index_t
abid
=
0
>
__device__
static
void
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
t32
_t
&
reg_idx
,
FloatC
&
reg_c
)
Run
(
const
bhalf4_t
&
reg_a
,
const
bhalf8_t
&
reg_b
,
const
in
dex
_t
&
reg_idx
,
FloatC
&
reg_c
)
{
#if defined(__gfx94__)
reg_c
.
template
AsType
<
float16_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_smfmac_f32_32x32x16_bf16
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
0
);
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float16_t
>()[
Number
<
0
>
{}],
reg_idx
,
0
,
abid
);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
...
...
include/ck/utility/reduction_operator.hpp
View file @
09d4c3a4
...
...
@@ -52,12 +52,28 @@ struct Add
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
half
_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8
_t
>::
value
,
"The data type is not supported by the Add accumulator!"
);
a
=
a
+
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
f8_t
>
(
a_
+
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
half_t
>
(
a_
+
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
...
...
@@ -112,12 +128,28 @@ struct Mul
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
half
_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8
_t
>::
value
,
"The data type is not supported by the Mul accumulator!"
);
a
=
a
*
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
f8_t
>
(
a_
*
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
a
=
type_convert
<
half_t
>
(
a_
*
b_
);
}
__host__
__device__
inline
constexpr
void
operator
()(
bhalf_t
&
a
,
bhalf_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
...
...
@@ -137,6 +169,16 @@ struct Max
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
bhalf_t
>
(
val
);
}
if
constexpr
(
is_same_v
<
T
,
f8_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
f8_t
>
(
val
);
}
if
constexpr
(
is_same_v
<
T
,
half_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Lowest
();
return
type_convert
<
half_t
>
(
val
);
}
else
{
return
NumericLimits
<
T
>::
Lowest
();
...
...
@@ -154,8 +196,7 @@ struct Max
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
...
...
@@ -171,12 +212,29 @@ struct Max
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
...
...
@@ -197,6 +255,30 @@ struct Max
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
{
a
=
b
;
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
{
a
=
b
;
changed
=
true
;
}
}
};
struct
Min
...
...
@@ -209,6 +291,16 @@ struct Min
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
bhalf_t
>
(
val
);
}
else
if
constexpr
(
is_same_v
<
T
,
half_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
half_t
>
(
val
);
}
else
if
constexpr
(
is_same_v
<
T
,
f8_t
>
)
{
float
val
=
NumericLimits
<
float
>::
Max
();
return
type_convert
<
f8_t
>
(
val
);
}
else
{
return
NumericLimits
<
T
>::
Max
();
...
...
@@ -227,8 +319,7 @@ struct Min
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Min accumulator!"
);
if
(
a
>
b
)
...
...
@@ -244,6 +335,24 @@ struct Min
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
...
...
@@ -270,6 +379,30 @@ struct Min
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
half_t
&
a
,
half_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
{
a
=
b
;
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
>
b_
)
{
a
=
b
;
changed
=
true
;
}
}
};
struct
AMax
...
...
@@ -299,6 +432,15 @@ struct AMax
a
=
b
;
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
...
...
@@ -313,6 +455,18 @@ struct AMax
changed
=
true
;
}
}
__host__
__device__
inline
constexpr
void
operator
()(
f8_t
&
a
,
f8_t
b
,
bool
&
changed
)
const
{
float
a_
=
type_convert
<
float
>
(
a
);
float
b_
=
type_convert
<
float
>
(
b
);
if
(
a_
<
b_
)
{
a
=
b
;
changed
=
true
;
}
}
};
template
<
typename
T
>
...
...
@@ -352,7 +506,8 @@ struct InMemoryDataOperationSupportedOnDataType<InMemoryDataOperationEnum::Set,
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
bhalf_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
||
is_same
<
DataType
,
f8_t
>::
value
;
};
template
<
typename
DataType
>
...
...
include/ck_tile/core/container/array.hpp
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <initializer_list>
#include <vector>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
...
...
@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
return
!
(
a
==
b
);
}
template
<
typename
T
,
index_t
N
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
std
::
vector
<
X
>&
x
)
{
array
<
T
,
N
>
arr
;
static_for
<
0
,
N
,
1
>
{}([
&
x
,
&
arr
](
auto
i
)
{
arr
(
i
)
=
x
[
i
];
});
return
arr
;
}
template
<
typename
T
,
index_t
N
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
X
&
x
)
{
...
...
include/ck_tile/host.hpp
View file @
09d4c3a4
...
...
@@ -5,6 +5,8 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp"
...
...
include/ck_tile/host/convolution_host_tensor_descriptor_helper.hpp
0 → 100644
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace
ck_tile
{
namespace
conv
{
namespace
detail
{
template
<
typename
OldLayout
>
CK_TILE_HOST
std
::
vector
<
std
::
size_t
>
get_layout_transpose_gnchw_to_old
()
{
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKW
>
)
{
return
{
0
,
1
,
2
,
3
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCHW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCYX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKHW
>
)
{
return
{
0
,
1
,
2
,
3
,
4
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCDHW
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCZYX
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKDHW
>
)
{
return
{
0
,
1
,
2
,
3
,
4
,
5
};
}
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWK
>
)
{
return
{
0
,
1
,
3
,
2
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKYXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWK
>
)
{
return
{
0
,
1
,
4
,
2
,
3
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKZYXC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWK
>
)
{
return
{
0
,
1
,
5
,
2
,
3
,
4
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGK
>
)
{
return
{
2
,
0
,
3
,
1
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGK
>
)
{
return
{
3
,
0
,
4
,
1
,
2
};
}
else
if
constexpr
(
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXGC
>
||
std
::
is_same_v
<
OldLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGK
>
)
{
return
{
4
,
0
,
5
,
1
,
2
,
3
};
}
else
{
printf
(
"%s
\n
"
,
__func__
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
}
}
// namespace detail
// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW
// regardless of physical layout
template
<
typename
InLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_input_host_tensor_descriptor_g_n_c_wis_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCW
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCHW
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNCDHW
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
>
||
std
::
is_same_v
<
InLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
input_spatial_lengths_
.
begin
(),
param
.
input_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
InLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
InLayout
>
());
}
// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX
// regardless of physical layout
template
<
typename
WeiLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_weight_host_tensor_descriptor_g_k_c_xs_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXC
>
)
{
if
(
param
.
G_
!=
1
)
{
throw
std
::
runtime_error
(
"wrong! G != 1"
);
}
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCX
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCYX
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKCZYX
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKYXC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GKZYXC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KXGC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KYXGC
>
||
std
::
is_same_v
<
WeiLayout
,
ck_tile
::
tensor_layout
::
convolution
::
KZYXGC
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
K_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
C_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
filter_spatial_lengths_
.
begin
(),
param
.
filter_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
WeiLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
WeiLayout
>
());
}
// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW
// regardless of physical layout
template
<
typename
OutLayout
>
CK_TILE_HOST
HostTensorDescriptor
make_output_host_tensor_descriptor_g_n_k_wos_packed
(
const
ck_tile
::
conv
::
ConvParam
&
param
)
{
std
::
vector
<
std
::
size_t
>
physical_lengths
;
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKW
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKHW
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNKDHW
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
end
(),
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
// separate from legacy code above
else
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNWK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNHWK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
GNDHWK
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
2
,
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
if
constexpr
(
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NWGK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NHWGK
>
||
std
::
is_same_v
<
OutLayout
,
ck_tile
::
tensor_layout
::
convolution
::
NDHWGK
>
)
{
physical_lengths
=
std
::
vector
<
std
::
size_t
>
{
static_cast
<
std
::
size_t
>
(
param
.
N_
),
static_cast
<
std
::
size_t
>
(
param
.
G_
),
static_cast
<
std
::
size_t
>
(
param
.
K_
)};
physical_lengths
.
insert
(
physical_lengths
.
begin
()
+
1
,
param
.
output_spatial_lengths_
.
begin
(),
param
.
output_spatial_lengths_
.
begin
()
+
param
.
num_dim_spatial_
);
}
else
{
printf
(
"%s
\n
"
,
__func__
);
printf
(
"%s
\n
"
,
OutLayout
::
name
);
throw
std
::
runtime_error
(
"wrong! unsupported layout"
);
}
return
transpose_host_tensor_descriptor_given_new2old
(
HostTensorDescriptor
(
physical_lengths
),
detail
::
get_layout_transpose_gnchw_to_old
<
OutLayout
>
());
}
}
// namespace conv
}
// namespace ck_tile
include/ck_tile/host/convolution_parameter.hpp
0 → 100644
View file @
09d4c3a4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <numeric>
#include <iterator>
#include <vector>
namespace
ck_tile
{
namespace
conv
{
struct
ConvParam
{
ConvParam
();
ConvParam
(
ck_tile
::
index_t
n_dim
,
ck_tile
::
index_t
group_count
,
ck_tile
::
index_t
n_batch
,
ck_tile
::
index_t
n_out_channels
,
ck_tile
::
index_t
n_in_channels
,
const
std
::
vector
<
ck_tile
::
index_t
>&
filters_len
,
const
std
::
vector
<
ck_tile
::
index_t
>&
input_len
,
const
std
::
vector
<
ck_tile
::
index_t
>&
strides
,
const
std
::
vector
<
ck_tile
::
index_t
>&
dilations
,
const
std
::
vector
<
ck_tile
::
index_t
>&
left_pads
,
const
std
::
vector
<
ck_tile
::
index_t
>&
right_pads
)
:
num_dim_spatial_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_dim
)),
G_
(
static_cast
<
ck_tile
::
long_index_t
>
(
group_count
)),
N_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_batch
)),
K_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_out_channels
)),
C_
(
static_cast
<
ck_tile
::
long_index_t
>
(
n_in_channels
)),
filter_spatial_lengths_
(
num_dim_spatial_
),
input_spatial_lengths_
(
num_dim_spatial_
),
output_spatial_lengths_
(
num_dim_spatial_
),
conv_filter_strides_
(
num_dim_spatial_
),
conv_filter_dilations_
(
num_dim_spatial_
),
input_left_pads_
(
num_dim_spatial_
),
input_right_pads_
(
num_dim_spatial_
)
{
if
(
static_cast
<
ck_tile
::
index_t
>
(
filter_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_strides_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_dilations_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_left_pads_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_right_pads_
.
size
())
!=
num_dim_spatial_
)
{
throw
(
std
::
runtime_error
(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"
));
}
for
(
ck_tile
::
index_t
i
=
0
;
i
<
num_dim_spatial_
;
++
i
)
{
filter_spatial_lengths_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
filters_len
[
i
]);
input_spatial_lengths_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
input_len
[
i
]);
conv_filter_strides_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
strides
[
i
]);
conv_filter_dilations_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
dilations
[
i
]);
input_left_pads_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
left_pads
[
i
]);
input_right_pads_
[
i
]
=
static_cast
<
ck_tile
::
long_index_t
>
(
right_pads
[
i
]);
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck_tile
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_filter_dilations_
[
i
]
+
1
;
output_spatial_lengths_
[
i
]
=
(
input_spatial_lengths_
[
i
]
+
input_left_pads_
[
i
]
+
input_right_pads_
[
i
]
-
x_eff
)
/
conv_filter_strides_
[
i
]
+
1
;
}
}
ConvParam
(
ck_tile
::
long_index_t
n_dim
,
ck_tile
::
long_index_t
group_count
,
ck_tile
::
long_index_t
n_batch
,
ck_tile
::
long_index_t
n_out_channels
,
ck_tile
::
long_index_t
n_in_channels
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
filters_len
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
input_len
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
strides
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
dilations
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
left_pads
,
const
std
::
vector
<
ck_tile
::
long_index_t
>&
right_pads
)
:
num_dim_spatial_
(
n_dim
),
G_
(
group_count
),
N_
(
n_batch
),
K_
(
n_out_channels
),
C_
(
n_in_channels
),
filter_spatial_lengths_
(
filters_len
),
input_spatial_lengths_
(
input_len
),
output_spatial_lengths_
(
num_dim_spatial_
),
conv_filter_strides_
(
strides
),
conv_filter_dilations_
(
dilations
),
input_left_pads_
(
left_pads
),
input_right_pads_
(
right_pads
)
{
if
(
static_cast
<
ck_tile
::
index_t
>
(
filter_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_spatial_lengths_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_strides_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
conv_filter_dilations_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_left_pads_
.
size
())
!=
num_dim_spatial_
||
static_cast
<
ck_tile
::
index_t
>
(
input_right_pads_
.
size
())
!=
num_dim_spatial_
)
{
throw
(
std
::
runtime_error
(
"ConvParam::ConvParam: "
"parameter size is different from number of declared dimensions!"
));
}
for
(
ck_tile
::
index_t
i
=
0
;
i
<
num_dim_spatial_
;
++
i
)
{
// XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const
ck_tile
::
long_index_t
x_eff
=
(
filter_spatial_lengths_
[
i
]
-
1
)
*
conv_filter_dilations_
[
i
]
+
1
;
output_spatial_lengths_
[
i
]
=
(
input_spatial_lengths_
[
i
]
+
input_left_pads_
[
i
]
+
input_right_pads_
[
i
]
-
x_eff
)
/
conv_filter_strides_
[
i
]
+
1
;
}
}
ck_tile
::
long_index_t
num_dim_spatial_
;
ck_tile
::
long_index_t
G_
;
ck_tile
::
long_index_t
N_
;
ck_tile
::
long_index_t
K_
;
ck_tile
::
long_index_t
C_
;
std
::
vector
<
ck_tile
::
long_index_t
>
filter_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
output_spatial_lengths_
;
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_strides_
;
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_dilations_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_left_pads_
;
std
::
vector
<
ck_tile
::
long_index_t
>
input_right_pads_
;
std
::
vector
<
ck_tile
::
long_index_t
>
GetOutputSpatialLengths
()
const
{
return
output_spatial_lengths_
;
}
std
::
size_t
GetFlops
()
const
{
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
return
static_cast
<
std
::
size_t
>
(
2
)
*
G_
*
N_
*
K_
*
C_
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
next
(
std
::
begin
(
output_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
())
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
std
::
next
(
std
::
begin
(
filter_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
());
}
template
<
typename
InDataType
>
std
::
size_t
GetInputByte
()
const
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return
sizeof
(
InDataType
)
*
(
G_
*
N_
*
C_
*
std
::
accumulate
(
std
::
begin
(
input_spatial_lengths_
),
std
::
next
(
std
::
begin
(
input_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
()));
}
template
<
typename
WeiDataType
>
std
::
size_t
GetWeightByte
()
const
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return
sizeof
(
WeiDataType
)
*
(
G_
*
K_
*
C_
*
std
::
accumulate
(
std
::
begin
(
filter_spatial_lengths_
),
std
::
next
(
std
::
begin
(
filter_spatial_lengths_
),
num_dim_spatial_
),
1
,
std
::
multiplies
<>
()));
}
template
<
typename
OutDataType
>
std
::
size_t
GetOutputByte
()
const
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return
sizeof
(
OutDataType
)
*
(
G_
*
N_
*
K_
*
std
::
accumulate
(
std
::
begin
(
output_spatial_lengths_
),
std
::
end
(
output_spatial_lengths_
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
()));
}
template
<
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
>
std
::
size_t
GetByte
()
const
{
return
GetInputByte
<
InDataType
>
()
+
GetWeightByte
<
WeiDataType
>
()
+
GetOutputByte
<
OutDataType
>
();
}
};
ConvParam
::
ConvParam
()
:
ConvParam
::
ConvParam
(
2
,
1
,
128
,
256
,
192
,
{
3
,
3
},
{
71
,
71
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
})
{
}
CK_TILE_HOST
std
::
string
get_conv_param_parser_helper_msg
()
{
std
::
string
msg
;
msg
+=
"Following arguments (depending on number of spatial dims):
\n
"
" Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)
\n
"
" G, N, K, C,
\n
"
" <filter spatial dimensions>, (ie Y, X for 2D)
\n
"
" <input image spatial dimensions>, (ie Hi, Wi for 2D)
\n
"
" <strides>, (ie Sy, Sx for 2D)
\n
"
" <dilations>, (ie Dy, Dx for 2D)
\n
"
" <left padding>, (ie LeftPy, LeftPx for 2D)
\n
"
" <right padding>, (ie RightPy, RightPx for 2D)
\n
"
;
return
msg
;
}
CK_TILE_HOST
ck_tile
::
conv
::
ConvParam
parse_conv_param
(
int
num_dim_spatial
,
int
arg_idx
,
char
*
const
argv
[])
{
const
ck_tile
::
long_index_t
G
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
N
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
K
=
std
::
stol
(
argv
[
arg_idx
++
]);
const
ck_tile
::
long_index_t
C
=
std
::
stol
(
argv
[
arg_idx
++
]);
std
::
vector
<
ck_tile
::
long_index_t
>
filter_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_spatial_lengths
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_strides
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
conv_filter_dilations
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_left_pads
(
num_dim_spatial
);
std
::
vector
<
ck_tile
::
long_index_t
>
input_right_pads
(
num_dim_spatial
);
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
filter_spatial_lengths
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_spatial_lengths
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
conv_filter_strides
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
conv_filter_dilations
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_left_pads
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
for
(
int
i
=
0
;
i
<
num_dim_spatial
;
++
i
)
{
input_right_pads
[
i
]
=
std
::
stol
(
argv
[
arg_idx
++
]);
}
return
ck_tile
::
conv
::
ConvParam
{
num_dim_spatial
,
G
,
N
,
K
,
C
,
filter_spatial_lengths
,
input_spatial_lengths
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
}
}
// namespace conv
}
// namespace ck_tile
Prev
1
2
3
4
5
6
7
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment