gaoqiong / composable_kernel
"vscode:/vscode.git/clone" did not exist on "ce1e69dd57105409f55d8bebfacea9b01a536808"
Commit c2976d7a, authored Mar 04, 2022 by ltqin

    Merge branch 'develop' into ck_conv_bww_fp16

Parents: e46ea9fd, 0619ebf7

Changes: 34. Showing 20 changed files with 1550 additions and 1287 deletions (+1550, -1287).
composable_kernel/include/config.hpp   +6 / -0
composable_kernel/include/tensor_operation/element_wise_operation.hpp   +1 / -0
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp   +51 / -386
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp   +38 / -284
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r1.hpp   +26 / -164
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r2.hpp   +27 / -172
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r3.hpp   +29 / -182
composable_kernel/include/utility/dynamic_buffer.hpp   +25 / -0
composable_kernel/include/utility/magic_division.hpp   +19 / -10
composable_kernel/include/utility/tensor_space_filling_curve.hpp   +22 / -7
device_operation/CMakeLists.txt   +17 / -4
device_operation/include/convolution_backward_data_specialization.hpp   +17 / -0
device_operation/include/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp   +814 / -0
device_operation/include/device_conv_bwd_data.hpp   +47 / -0
device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp   +83 / -0
device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp   +85 / -0
device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp   +82 / -0
device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp   +83 / -0
device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp   +39 / -39
device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp   +39 / -39
composable_kernel/include/config.hpp

@@ -151,6 +151,12 @@
 #define CK_WORKAROUND_SWDEV_XXXXXX_THREAD_WISE_COPY_V1R5_TYPE_CONVERT_ISSUE 1
 #endif

+// workaround for verification failure, due to compiler regression, for conv bwd-data fp16 using
+// some tuning parameters
+#ifndef CK_WORKAROUND_SWDEV_325164
+#define CK_WORKAROUND_SWDEV_325164 1
+#endif
+
 namespace ck {

 enum InMemoryDataOperationEnum_t
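For context, workaround macros like this are normally consumed as compile-time switches elsewhere in the kernels. The following is a minimal, self-contained sketch (not part of this commit) of that pattern; the tuning-parameter predicate and the KPerBlock threshold are hypothetical.

// Hedged sketch only: illustrates how a CK_WORKAROUND_* switch typically gates a
// code path for the tuning parameters hit by a compiler regression.
#ifndef CK_WORKAROUND_SWDEV_325164
#define CK_WORKAROUND_SWDEV_325164 1
#endif

template <int KPerBlock>
constexpr bool use_fallback_copy_path()
{
#if CK_WORKAROUND_SWDEV_325164
    // hypothetical condition standing in for "the affected tuning parameters"
    return (KPerBlock % 8) != 0;
#else
    return false;
#endif
}

static_assert(!use_fallback_copy_path<64>(), "64 is assumed to be an unaffected tile size");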
composable_kernel/include/tensor_operation/element_wise_operation.hpp

 #ifndef CK_ELEMENT_WISE_OPERATION_HPP
 #define CK_ELEMENT_WISE_OPERATION_HPP

 #include "data_type.hpp"
+#include "data_type.hpp"
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp

@@ -4,6 +4,7 @@
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
+#include "tensor_space_filling_curve.hpp"

 namespace ck {

@@ -67,8 +68,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
 The DstCoordStep alias is dropped:
     using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
-    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

     __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
                                                             const Index& dst_slice_origin_idx,

@@ -85,16 +84,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
 Run() no longer takes destination step hacks:
-    template <typename SrcSliceOriginIdx,
-              typename SrcBuffer,
-              typename DstBuffer,
-              typename DstStepHacks>
-    __device__ void Run(const SrcDesc&,
-                        const SrcSliceOriginIdx&,
-                        const SrcBuffer& src_buf,
-                        const DstDesc& dst_desc,
-                        DstBuffer& dst_buf,
-                        const DstStepHacks& dst_step_hacks)
+    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
+    __device__ void Run(const SrcDesc&,
+                        const SrcSliceOriginIdx&,
+                        const SrcBuffer& src_buf,
+                        const DstDesc& dst_desc,
+                        DstBuffer& dst_buf)

@@ -108,9 +103,6 @@ struct ThreadwiseTensorSliceTransfer_v1r3
 The local I0/I1 Number constants are removed.

@@ -119,85 +111,26 @@ struct ThreadwiseTensorSliceTransfer_v1r3
 Inside Run(), the hand-rolled traversal (access_lengths, dim_access_order,
 ordered_access_lengths, per-dimension forward/backward coordinate steps, the
 static_ford sweep with its forward_sweep, dst_data_idx and move_on_dim bookkeeping,
 and the Set/AtomicAdd/Add branches) is replaced by a space-filling curve:

+        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+                                                    DimAccessOrder,
+                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;
+
+        // TODO: Use SpaceFillingCurve::ScalarsPerAccess instead of DstScalarPerVector?
+        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
+                      "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
+
+        typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
+        using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
+
+        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
+
+        static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
+            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
+
+            // copy data from src_buf into dst_vector
+            // TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve?
+            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+                constexpr index_t src_offset = src_desc.CalculateOffset(
+                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                SrcData dst_v;

@@ -212,69 +145,18 @@ struct ThreadwiseTensorSliceTransfer_v1r3
+            // copy data from dst_vector into dst_buf
+            dst_buf.template Update<DstInMemOp, dst_vector_t>(
+                dst_coord_.GetOffset(),
+                is_dst_valid,
+                dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
+
+            if constexpr(idx_1d.value != num_accesses - 1)
+            {
+                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
+
+                move_tensor_coordinate(
+                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
+            }
+        });

@@ -287,82 +169,20 @@ struct ThreadwiseTensorSliceTransfer_v1r3
 The forwarding Run() overload that built all-zero dst step hacks is removed, and
 GetDstCoordinateResetStep() asks the curve for the step from the last access back to
 the first:

+    __device__ static constexpr auto GetDstCoordinateResetStep()
+    {
+        // scalar per access on each dim
+        // TODO: don't use lambda_scalar_per_access
+        constexpr auto dst_scalar_per_access = generate_sequence(
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+
+        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+                                                    DimAccessOrder,
+                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;
+
+        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
+
+        constexpr auto reset_step =
+            SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{});
+
+        return reset_step;
+    }

@@ -383,7 +203,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3
 private:
     DstCoord dst_coord_;

     const DstElementwiseOperation dst_element_op_;
-}; // namespace ck
+}; // struct ThreadwiseTensorSliceTransfer_v1r3

ThreadwiseTensorSliceTransfer_v2 gets the same rewrite:

@@ -428,16 +248,12 @@ struct ThreadwiseTensorSliceTransfer_v2
 Run() drops the SrcStepHacks template parameter and the src_step_hacks argument.

@@ -464,80 +277,19 @@ struct ThreadwiseTensorSliceTransfer_v2
 The per-dimension source forward/backward steps and the static_ford sweep are replaced
 by SpaceFillingCurve<SliceLengths, DimAccessOrder, remove_cv_t<decltype(src_scalar_per_access)>>;
 each of the num_accesses iterations loads one src_vector at the multi-index returned by
 SpaceFillingCurve::GetIndex(idx_1d).

@@ -555,38 +307,13 @@ struct ThreadwiseTensorSliceTransfer_v2
 Between accesses, src_coord_ is advanced once with the curve's forward step
 (SpaceFillingCurve::GetForwardStep(idx_1d)) instead of the per-dimension
 forward_sweep/move_on_dim moves.

@@ -599,82 +326,20 @@ struct ThreadwiseTensorSliceTransfer_v2
 The forwarding Run() overload that built all-zero src step hacks is removed, and
 GetSrcCoordinateResetStep() now returns
 SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{}).
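The net effect across these structs is that the iteration order formerly produced by the forward_sweep/backward-step bookkeeping is now owned by the SpaceFillingCurve helper. As a rough illustration, here is a minimal host-side sketch of that back-and-forth (snake) order; the 3x4 access-length window and the exact ordering are illustrative assumptions, not taken from tensor_space_filling_curve.hpp.

#include <cstdio>

int main()
{
    // assumed access lengths per dimension of a slice window
    constexpr int len0 = 3, len1 = 4;

    for(int a = 0; a < len0 * len1; ++a)
    {
        const int i0 = a / len1;
        // even rows sweep forward, odd rows sweep backward, so consecutive
        // accesses always differ by a single +/-1 step in one dimension
        const int i1 = (i0 % 2 == 0) ? (a % len1) : (len1 - 1 - a % len1);
        std::printf("access %2d -> (%d, %d)\n", a, i0, i1);
    }
    return 0;
}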
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v3r1.hpp

@@ -5,6 +5,7 @@
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
 #include "static_tensor.hpp"
+#include "tensor_space_filling_curve.hpp"

 namespace ck {

@@ -123,73 +124,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 In RunRead(), the per-dimension forward/backward steps, the static_ford sweep and its
 forward_sweep/src_data_idx bookkeeping are replaced by a curve over the source access
 order:

+        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+                                                    SrcDimAccessOrder,
+                                                    remove_cv_t<decltype(src_scalar_per_access)>>;
+
+        // loop over space-filling curve
+        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
+
+        static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
+            constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);

            constexpr auto src_data_idx_seq = generate_sequence_v2(
                [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});

@@ -218,39 +162,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
+            // move coordinate
+            if constexpr(idx_1d.value != num_accesses - 1)
+            {
+                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
+
+                move_tensor_coordinate(
+                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
+            }
+        });

@@ -374,73 +292,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
@@ -470,39 +330,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 RunWrite() receives the same treatment on the destination side: a
 SpaceFillingCurve<SliceLengths, DstDimAccessOrder, remove_cv_t<decltype(dst_scalar_per_access)>>
 supplies dst_data_idx (and dst_data_idx_seq), and dst_coord_ is advanced between
 accesses with make_tensor_coordinate_step(dst_desc, SpaceFillingCurve::GetForwardStep(idx_1d)).

@@ -522,55 +356,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
@@ -580,55 +374,15 @@ struct ThreadwiseTensorSliceTransfer_v3r1
 GetSrcCoordinateResetStep() and GetDstCoordinateResetStep() no longer negate the data
 index of the last sweep iteration; both now return
 SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{}).
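Conceptually, the reset step returned by GetStepBetween(num_accesses - 1, 0) is just the multi-index difference between the first access and the last one, applied once so the coordinate can be reused for the next slice window. A small host-side sketch under the same made-up 3x4 window as above (the helper and numbers are assumptions, not CK code):

#include <array>
#include <cstdio>

// snake-order multi-index of the a-th access over a (len0 x len1) window
static std::array<int, 2> snake_index(int a, int len1)
{
    const int i0 = a / len1;
    const int i1 = (i0 % 2 == 0) ? (a % len1) : (len1 - 1 - a % len1);
    return {i0, i1};
}

int main()
{
    constexpr int len0 = 3, len1 = 4;
    const auto last  = snake_index(len0 * len1 - 1, len1);
    const auto first = snake_index(0, len1);

    // step that moves the coordinate from the last access back to the first
    std::printf("reset step = (%d, %d)\n", first[0] - last[0], first[1] - last[1]);
    return 0;
}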
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r1.hpp

@@ -4,6 +4,7 @@
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
+#include "tensor_space_filling_curve.hpp"

 namespace ck {

@@ -40,9 +41,6 @@ struct ThreadwiseTensorSliceTransfer_v6r1
 The coordinate-step aliases are dropped:
     using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
     using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
-    using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
-    using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

     static constexpr auto I0 = Number<0>{};

@@ -79,70 +77,14 @@ struct ThreadwiseTensorSliceTransfer_v6r1
 In Run(), the make_forward_steps/make_backward_steps lambdas and the static_ford sweep
 over the slice window are replaced by the curve:

+        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+                                                    DimAccessOrder,
+                                                    remove_cv_t<decltype(scalar_per_access)>>;
+
+        // loop over space-filling curve
+        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
+
+        static_for<0, num_accesses, 1>{}([&](auto idx_1d) {
            using src_vector_type = vector_type_maker_t<SrcData, ScalarPerVector>;
            using src_vector_t    = typename src_vector_type::type;

@@ -168,59 +110,20 @@ struct ThreadwiseTensorSliceTransfer_v6r1
 The Set/AtomicAdd branches collapse into one Update call, and the same forward step is
 applied to both coordinates:

+            // copy data from dst_vector into dst_buf
+            dst_buf.template Update<DstInMemOp, dst_vector_t>(
+                dst_coord_.GetOffset(),
+                is_dst_valid,
+                dst_vector_container.template AsType<dst_vector_t>()[I0]);
+
+            // move coordinate
+            if constexpr(idx_1d.value != num_accesses - 1)
+            {
+                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
+
+                move_tensor_coordinate(
+                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
+                move_tensor_coordinate(
+                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
+            }
+        });

@@ -243,59 +146,18 @@ struct ThreadwiseTensorSliceTransfer_v6r1
 GetCoordinateResetStep() is rewritten the same way:

+        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
+                                                    DimAccessOrder,
+                                                    remove_cv_t<decltype(scalar_per_access)>>;
+
+        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();
+
+        constexpr auto reset_step =
+            SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{});
+
+        return reset_step;

@@ -332,7 +194,7 @@ struct ThreadwiseTensorSliceTransfer_v6r1
     SrcCoord src_coord_;
     DstCoord dst_coord_;

     const ElementwiseOperation element_op_;
 };

 } // namespace ck
 #endif
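The forward step that v6r1 applies to both src_coord_ and dst_coord_ is, conceptually, the difference between consecutive curve indices, so exactly one dimension moves per iteration. A host-side sketch under the same assumed snake order and 3x4 window (illustrative only, not the CK implementation):

#include <array>
#include <cstdio>

// snake-order multi-index of the a-th access over a (len0 x len1) window
static std::array<int, 2> snake_index(int a, int len1)
{
    const int i0 = a / len1;
    const int i1 = (i0 % 2 == 0) ? (a % len1) : (len1 - 1 - a % len1);
    return {i0, i1};
}

int main()
{
    constexpr int len0 = 3, len1 = 4;

    for(int a = 0; a + 1 < len0 * len1; ++a)
    {
        const auto cur = snake_index(a, len1);
        const auto nxt = snake_index(a + 1, len1);
        // only one component is ever non-zero, and it is +1 or -1
        std::printf("forward step after access %2d = (%+d, %+d)\n",
                    a, nxt[0] - cur[0], nxt[1] - cur[1]);
    }
    return 0;
}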
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r2.hpp
View file @
c2976d7a
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include "common_header.hpp"
#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_descriptor_helper.hpp"
#include "tensor_space_filling_curve.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -44,10 +45,6 @@ struct ThreadwiseTensorSliceTransfer_v6r2
...
@@ -44,10 +45,6 @@ struct ThreadwiseTensorSliceTransfer_v6r2
using
Src1Coord
=
decltype
(
make_tensor_coordinate
(
Src1Desc
{},
Index
{}));
using
Src1Coord
=
decltype
(
make_tensor_coordinate
(
Src1Desc
{},
Index
{}));
using
DstCoord
=
decltype
(
make_tensor_coordinate
(
DstDesc
{},
Index
{}));
using
DstCoord
=
decltype
(
make_tensor_coordinate
(
DstDesc
{},
Index
{}));
using
Src0CoordStep
=
decltype
(
make_tensor_coordinate_step
(
Src0Desc
{},
Index
{}));
using
Src1CoordStep
=
decltype
(
make_tensor_coordinate_step
(
Src1Desc
{},
Index
{}));
using
DstCoordStep
=
decltype
(
make_tensor_coordinate_step
(
DstDesc
{},
Index
{}));
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
__device__
constexpr
ThreadwiseTensorSliceTransfer_v6r2
(
const
Src0Desc
&
src0_desc
,
__device__
constexpr
ThreadwiseTensorSliceTransfer_v6r2
(
const
Src0Desc
&
src0_desc
,
...
@@ -96,72 +93,14 @@ struct ThreadwiseTensorSliceTransfer_v6r2
...
@@ -96,72 +93,14 @@ struct ThreadwiseTensorSliceTransfer_v6r2
constexpr
auto
scalar_per_access
=
generate_sequence
(
constexpr
auto
scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
VectorDim
,
ScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
VectorDim
,
ScalarPerVector
>
{},
Number
<
nDim
>
{});
constexpr
auto
access_lengths
=
SliceLengths
{}
/
scalar_per_access
;
using
SpaceFillingCurve
=
SpaceFillingCurve
<
SliceLengths
,
DimAccessOrder
,
constexpr
auto
dim_access_order
=
DimAccessOrder
{};
remove_cv_t
<
decltype
(
scalar_per_access
)
>>
;
constexpr
auto
ordered_access_lengths
=
container_reorder_given_new2old
(
access_lengths
,
dim_access_order
);
auto
make_forward_steps
=
[
&
](
auto
desc
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
Index
forward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
forward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
desc
,
forward_step_idx
);
},
Number
<
nDim
>
{});
};
auto
make_backward_steps
=
[
&
](
auto
desc
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
Index
backward_step_idx
;
static_for
<
0
,
nDim
,
1
>
{}([
&
](
auto
j
)
{
backward_step_idx
(
j
)
=
(
i
.
value
==
j
.
value
)
?
-
scalar_per_access
[
i
]
:
0
;
});
return
make_tensor_coordinate_step
(
desc
,
backward_step_idx
);
},
Number
<
nDim
>
{});
};
// make forward steps
const
auto
src0_forward_steps
=
make_forward_steps
(
src0_desc
);
const
auto
src1_forward_steps
=
make_forward_steps
(
src1_desc
);
const
auto
dst_forward_steps
=
make_forward_steps
(
dst_desc
);
// make backward steps
const
auto
src0_backward_steps
=
make_backward_steps
(
src0_desc
);
const
auto
src1_backward_steps
=
make_backward_steps
(
src1_desc
);
const
auto
dst_backward_steps
=
make_backward_steps
(
dst_desc
);
// loop over slice window
constexpr
auto
num_accesses
=
SpaceFillingCurve
::
GetNumOfAccess
();
static_ford
<
decltype
(
ordered_access_lengths
)
>
{}([
&
](
auto
ordered_access_idx
)
{
// judge move forward or move backward
constexpr
auto
forward_sweep
=
[
&
]()
{
StaticallyIndexedArray
<
bool
,
nDim
>
forward_sweep_
;
forward_sweep_
(
I0
)
=
true
;
static_for
<
1
,
nDim
,
1
>
{}([
&
](
auto
i
)
{
index_t
tmp
=
ordered_access_idx
[
I0
];
static_for
<
1
,
i
,
1
>
{}([
&
](
auto
j
)
{
tmp
=
tmp
*
ordered_access_lengths
[
j
]
+
ordered_access_idx
[
j
];
});
forward_sweep_
(
i
)
=
tmp
%
2
==
0
;
});
return
forward_sweep_
;
}();
// loop over space-filling curve
static_for
<
0
,
num_accesses
,
1
>
{}([
&
](
auto
idx_1d
)
{
using
src0_vector_type
=
vector_type_maker_t
<
Src0Data
,
ScalarPerVector
>
;
using
src0_vector_type
=
vector_type_maker_t
<
Src0Data
,
ScalarPerVector
>
;
using
src0_vector_t
=
typename
src0_vector_type
::
type
;
using
src0_vector_t
=
typename
src0_vector_type
::
type
;
...
@@ -197,65 +136,22 @@ struct ThreadwiseTensorSliceTransfer_v6r2
            coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
Removed:
            if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set)
            {
                dst_buf.template Set<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector_container.template AsType<dst_vector_t>()[I0]);
            }
            else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
            {
                dst_buf.template AtomicAdd<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector_container.template AsType<dst_vector_t>()[I0]);
            }

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim_(i) &=
                            ordered_access_idx[j] == ordered_access_lengths[j] - 1;
                    });
                });

                return move_on_dim_;
            }();

            // move coordinate
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_tensor_coordinate(src0_desc, src0_coord_, src0_forward_steps[dim_access_order[i]]);
                        move_tensor_coordinate(src1_desc, src1_coord_, src1_forward_steps[dim_access_order[i]]);
                        move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[dim_access_order[i]]);
                    }
                    else
                    {
                        move_tensor_coordinate(src0_desc, src0_coord_, src0_backward_steps[dim_access_order[i]]);
                        move_tensor_coordinate(src1_desc, src1_coord_, src1_backward_steps[dim_access_order[i]]);
                        move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[dim_access_order[i]]);
                    }
                }
            });
        });

Added:
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector_container.template AsType<dst_vector_t>()[I0]);

            // move coordinate
            if constexpr(idx_1d.value != num_accesses - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

                move_tensor_coordinate(
                    src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step));
                move_tensor_coordinate(
                    src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step));
                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
        });

        // move coordinate back to slice origin (or not)
...
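To make the new traversal concrete, here is a minimal host-side model of the snake ("zig-zag") visiting order that such a space-filling curve produces for a 2D access grid. The grid size, names, and the main() driver are illustrative only and are not part of the composable_kernel API; the real traversal happens at compile time on device.

    // Host-side model of the zig-zag visiting order over a 2D access grid.
    #include <array>
    #include <cstdio>
    #include <vector>

    int main()
    {
        // Access grid after dividing SliceLengths by the scalars per access,
        // e.g. SliceLengths = {4, 16} with 4 scalars per vector on dim 1 -> {4, 4}.
        const std::array<int, 2> access_lengths{4, 4};

        std::vector<std::array<int, 2>> visit; // multi-index of each 1D access
        for(int i0 = 0; i0 < access_lengths[0]; ++i0)
        {
            // Even rows sweep forward, odd rows sweep backward, so two consecutive
            // accesses always differ by a single +/-1 step along one dimension.
            if(i0 % 2 == 0)
                for(int i1 = 0; i1 < access_lengths[1]; ++i1) visit.push_back({i0, i1});
            else
                for(int i1 = access_lengths[1] - 1; i1 >= 0; --i1) visit.push_back({i0, i1});
        }

        // "GetForwardStep(i)" corresponds to visit[i + 1] - visit[i].
        for(size_t i = 0; i + 1 < visit.size(); ++i)
            std::printf("access %2zu: idx=(%d,%d) forward step=(%d,%d)\n",
                        i, visit[i][0], visit[i][1],
                        visit[i + 1][0] - visit[i][0], visit[i + 1][1] - visit[i][1]);
    }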
@@ -286,59 +182,18 @@ struct ThreadwiseTensorSliceTransfer_v6r2
    __device__ static constexpr auto GetCoordinateResetStep()
    {
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

Removed:
        constexpr auto access_lengths = SliceLengths{} / scalar_per_access;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        // judge move forward or move backward during the last iteration
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_access_lengths[I0] - 1;

                static_for<1, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
                });

                forward_sweep_(i) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate data index after last iteration in Run(), if it has not being reset
        constexpr auto data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
            });

            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   scalar_per_access;
        }();

        //
        constexpr auto reset_data_step = [&]() {
            Index reset_data_step_;

            static_for<0, nDim, 1>{}([&](auto i) { reset_data_step_(i) = -data_idx[i]; });

            return reset_data_step_;
        }();

        return reset_data_step;

Added:
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();

        constexpr auto reset_step =
            SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{});

        return reset_step;
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
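For a concrete feel for the reset step, continuing the illustrative 4x4 access grid from the sketch above: the snake ordering ends at access index (3, 0), so GetStepBetween(Number<15>{}, Number<0>{}) returns (0, 0) - (3, 0) = (-3, 0), scaled by the scalars per access along each dimension, which is exactly the step that moves the coordinate back to the slice origin. The numbers here are an example, not values taken from the source.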
...
composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v6r3.hpp
View file @ c2976d7a

@@ -4,6 +4,7 @@
 #include "common_header.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
Added:
 #include "tensor_space_filling_curve.hpp"

 namespace ck {

@@ -48,11 +49,6 @@ struct ThreadwiseTensorSliceTransfer_v6r3
    using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{}));
    using DstCoord  = decltype(make_tensor_coordinate(DstDesc{}, Index{}));

Removed:
    using Src0CoordStep = decltype(make_tensor_coordinate_step(Src0Desc{}, Index{}));
    using Src1CoordStep = decltype(make_tensor_coordinate_step(Src1Desc{}, Index{}));
    using Src2CoordStep = decltype(make_tensor_coordinate_step(Src2Desc{}, Index{}));
    using DstCoordStep  = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));

    static constexpr auto I0 = Number<0>{};

    __device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
...
@@ -112,74 +108,14 @@ struct ThreadwiseTensorSliceTransfer_v6r3
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

Removed: the same per-dimension machinery that was removed from ThreadwiseTensorSliceTransfer_v6r2
above (access_lengths, dim_access_order, ordered_access_lengths, the make_forward_steps /
make_backward_steps helpers, the forward and backward steps for src0_desc, src1_desc, src2_desc
and dst_desc, and the static_ford sweep with its forward_sweep parity lambda).

Added:
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        // loop over space-filling curve
        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_accesses, 1>{}([&](auto idx_1d) {

Unchanged context:
            using src0_vector_type = vector_type_maker_t<Src0Data, ScalarPerVector>;
            using src0_vector_t    = typename src0_vector_type::type;

@@ -224,72 +160,24 @@ struct ThreadwiseTensorSliceTransfer_v6r3
            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            // copy data from dst_vector into dst_buf
Removed:
            if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set)
            {
                dst_buf.template Set<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector_container.template AsType<dst_vector_t>()[I0]);
            }
            else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd)
            {
                dst_buf.template AtomicAdd<dst_vector_t>(
                    dst_coord_.GetOffset(),
                    is_dst_valid,
                    dst_vector_container.template AsType<dst_vector_t>()[I0]);
            }

Removed as well: the move_on_dim lambda and the per-dimension forward/backward coordinate moves
for src0_coord_, src1_coord_, src2_coord_ and dst_coord_ (same pattern as in v6r2 above).

Added:
            dst_buf.template Update<DstInMemOp, dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector_container.template AsType<dst_vector_t>()[I0]);

            // move coordinate
            if constexpr(idx_1d.value != num_accesses - 1)
            {
                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);

                move_tensor_coordinate(
                    src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step));
                move_tensor_coordinate(
                    src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step));
                move_tensor_coordinate(
                    src2_desc, src2_coord_, make_tensor_coordinate_step(src2_desc, forward_step));
                move_tensor_coordinate(
                    dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
            }
        });

        // move coordinate back to slice origin (or not)
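Because src0, src1, src2 and dst all share the same SliceLengths and DimAccessOrder, a single compile-time step from the curve describes the move for every operand; Run only has to materialize it once per descriptor with make_tensor_coordinate_step. That is also why the per-descriptor *CoordStep typedefs and the precomputed forward/backward step tables above could be removed.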
...
@@ -328,59 +216,18 @@ struct ThreadwiseTensorSliceTransfer_v6r3
    __device__ static constexpr auto GetCoordinateResetStep()
    {
        // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
        constexpr auto scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});

Removed: the same hand-rolled reset computation that was removed from v6r2's
GetCoordinateResetStep() above (access_lengths, ordered_access_lengths, the forward_sweep
parity lambda, the data index after the last iteration, and the negated reset_data_step).

Added:
        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(scalar_per_access)>>;

        constexpr auto num_accesses = SpaceFillingCurve::GetNumOfAccess();

        constexpr auto reset_step =
            SpaceFillingCurve::GetStepBetween(Number<num_accesses - 1>{}, Number<0>{});

        return reset_step;
    }

    // src_slice_origin_step_idx need to be known at compile-time, for performance reason
...
composable_kernel/include/utility/dynamic_buffer.hpp
View file @ c2976d7a

@@ -3,6 +3,7 @@
 #include "amd_buffer_addressing.hpp"
 #include "c_style_pointer_cast.hpp"
Added:
 #include "config.hpp"
 #include "enable_if.hpp"

 namespace ck {

@@ -108,6 +109,30 @@ struct DynamicBuffer
    }

Added:
    template <InMemoryDataOperationEnum_t Op,
              typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == InMemoryDataOperationEnum_t::Set)
        {
            this->template Set<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum_t::AtomicAdd)
        {
            this->template AtomicAdd<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum_t::Add)
        {
            auto tmp = this->template Get<X>(i, is_valid_element);
            this->template Set<X>(i, is_valid_element, x + tmp);
            // tmp += x;
            // this->template Set<X>(i, is_valid_element, tmp);
        }
    }

Unchanged context: the existing member template constrained on matching scalar_type of X and T.
...
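The Set / AtomicAdd / Add dispatch above is easiest to see on a plain host container. The sketch below is a simplified stand-in, assuming a std::vector in place of the real address-space-aware buffer; MemOp, update() and the test values are invented for the illustration and are not part of the library.

    // Minimal host-side model of the Update<Op> dispatch added to DynamicBuffer.
    #include <cassert>
    #include <vector>

    enum class MemOp { Set, AtomicAdd, Add };

    template <MemOp Op, typename T>
    void update(std::vector<T>& buf, int i, bool is_valid, T x)
    {
        if(!is_valid) return;
        if constexpr(Op == MemOp::Set)
            buf[i] = x;          // plain store
        else if constexpr(Op == MemOp::AtomicAdd)
            buf[i] += x;         // stands in for the atomic read-modify-write on device
        else if constexpr(Op == MemOp::Add)
            buf[i] = buf[i] + x; // read-modify-write without atomicity
    }

    int main()
    {
        std::vector<float> c(4, 1.0f);
        update<MemOp::Set>(c, 0, true, 5.0f);
        update<MemOp::Add>(c, 1, true, 2.5f);
        update<MemOp::AtomicAdd>(c, 2, true, 0.5f);
        assert(c[0] == 5.0f && c[1] == 3.5f && c[2] == 1.5f);
    }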
composable_kernel/include/utility/magic_division.hpp
View file @ c2976d7a

@@ -25,21 +25,30 @@ struct MagicDivision
    // uint32_t
    __host__ __device__ static constexpr auto CalculateMagicNumbers(uint32_t divisor)
    {
Removed:
        // assert(divisior >= 1 && divisior <= INT32_MAX);

        uint32_t shift = 0;

        for(shift = 0; shift < 32; ++shift)
        {
            if((1U << shift) >= divisor)
            {
                break;
            }
        }

        uint64_t one        = 1;
        uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
        // assert(multiplier <= 0xffffffffUL);

        return make_tuple(uint32_t(multiplier), shift);

Added:
        // WARNING: magic division is only applicable for division inside this range.
        // You should use the return value of CalculateMagicNumbers, if division is not inside this
        // range. The "else" logic below is to quiet down run-time error.
        if(divisor >= 1 && divisor <= INT32_MAX)
        {
            uint32_t shift = 0;

            for(shift = 0; shift < 32; ++shift)
            {
                if((1U << shift) >= divisor)
                {
                    break;
                }
            }

            uint64_t one        = 1;
            uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
            // assert(multiplier <= 0xffffffffUL);

            return make_tuple(uint32_t(multiplier), shift);
        }
        else
        {
            return make_tuple(uint32_t(0), uint32_t(0));
        }
    }

    __host__ __device__ static constexpr uint32_t CalculateMagicMultiplier(uint32_t divisor)
...
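A quick host-side way to sanity-check the magic numbers produced above. The division routine paired with them is assumed here to be of the "high 32 bits of n * multiplier, plus n, then shift" form; that pairing is an assumption of this sketch and is not shown in this hunk.

    // Host-side check: magic_div(n, m, s) should equal n / d for the (m, s) built
    // with the same formula as CalculateMagicNumbers above.
    #include <cassert>
    #include <cstdint>
    #include <utility>

    std::pair<uint32_t, uint32_t> calculate_magic(uint32_t divisor)
    {
        uint32_t shift = 0;
        for(; shift < 32; ++shift)
            if((1U << shift) >= divisor) break;
        const uint64_t one        = 1;
        const uint64_t multiplier = ((one << 32) * ((one << shift) - divisor)) / divisor + 1;
        return {uint32_t(multiplier), shift};
    }

    uint32_t magic_div(uint32_t n, uint32_t multiplier, uint32_t shift)
    {
        const uint32_t hi = uint32_t((uint64_t(n) * multiplier) >> 32); // high half of the product
        return (hi + n) >> shift;
    }

    int main()
    {
        for(uint32_t d : {1u, 3u, 7u, 10u, 255u, 1000u})
        {
            auto [m, s] = calculate_magic(d);
            for(uint32_t n : {0u, 1u, 9u, 100u, 12345u, 1u << 20})
                assert(magic_div(n, m, s) == n / d);
        }
    }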
composable_kernel/include/utility/tensor_space_filling_curve.hpp
View file @ c2976d7a

Added:
#ifndef TENSOR_SPACE_FILLING_CURVE_HPP
#define TENSOR_SPACE_FILLING_CURVE_HPP

#include "math.hpp"
#include "sequence.hpp"
Added:
#include "sequence_helper.hpp"
#include "tensor_adaptor.hpp"
#include "statically_indexed_array_multi_index.hpp"
#include "tuple_helper.hpp"

@@ -37,13 +41,25 @@ struct SpaceFillingCurve
               ScalarPerVector;
    }

Added:
    template <index_t AccessIdx1dBegin, index_t AccessIdx1dEnd>
    static __device__ __host__ constexpr auto GetStepBetween(Number<AccessIdx1dBegin>,
                                                             Number<AccessIdx1dEnd>)
    {
        static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative");
        static_assert(AccessIdx1dBegin < GetNumOfAccess(), "1D index should be larger than 0");
        static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative");
        static_assert(AccessIdx1dEnd < GetNumOfAccess(), "1D index should be larger than 0");

        constexpr auto idx_begin = GetIndex(Number<AccessIdx1dBegin>{});
        constexpr auto idx_end   = GetIndex(Number<AccessIdx1dEnd>{});
        return idx_end - idx_begin;
    }

    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr auto GetForwardStep(Number<AccessIdx1d>)
    {
Removed:
        constexpr auto idx_curr = GetIndex(Number<AccessIdx1d>{});
        constexpr auto idx_next = GetIndex(Number<AccessIdx1d + 1>{});
        return idx_next - idx_curr;
Added:
        static_assert(AccessIdx1d < GetNumOfAccess(), "1D index should be larger than 0");
        return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d + 1>{});
    }

    template <index_t AccessIdx1d>
@@ -51,9 +67,7 @@ struct SpaceFillingCurve
    {
        static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
Removed:
        constexpr auto idx_curr = GetIndex(Number<AccessIdx1d>{});
        constexpr auto idx_prev = GetIndex(Number<AccessIdx1d - 1>{});
        return idx_prev - idx_curr;
Added:
        return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d - 1>{});
    }

    template <index_t AccessIdx1d>
@@ -129,3 +143,4 @@ struct SpaceFillingCurve
};
} // namespace ck
Added:
#endif
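With every step expressed through GetStepBetween, the helpers are related by construction: GetForwardStep(i) equals GetStepBetween(i, i + 1), GetBackwardStep(i) equals GetStepBetween(i, i - 1), and the reset step used by the transfer classes above is GetStepBetween(num_accesses - 1, 0). In particular, GetForwardStep(i) is the negation of GetBackwardStep(i + 1), since both are differences of the same two multi-indices returned by GetIndex.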
device_operation/CMakeLists.txt
View file @ c2976d7a

@@ -101,16 +101,25 @@ set(DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp;
)

Added:
# device_conv2d_bwd_data_instance
set(DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
    ${PROJECT_SOURCE_DIR}/device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
)

The existing add_library() calls for device_gemm_instance, device_gemm_bias_2d_instance,
device_gemm_bias_relu_instance, device_gemm_bias_relu_add_instance, device_batched_gemm_instance,
device_conv1d_fwd_instance, device_conv2d_fwd_instance and the conv2d_fwd_bias_relu variants are
kept (device_gemm_bias_2d_instance is only reordered), and a new target is added:

add_library(device_conv2d_bwd_data_instance SHARED ${DEVICE_CONV2D_BWD_DATA_INSTANCE_SOURCE})

@@ -122,6 +131,7 @@ target_include_directories(device_conv2d_fwd_instance SYSTEM PUBLIC $<BUILD_INTE
Added:
target_include_directories(device_conv2d_bwd_data_instance SYSTEM PUBLIC $<BUILD_INTERFACE:${HALF_INCLUDE_DIR}>)

@@ -133,6 +143,7 @@ target_compile_features(device_conv2d_fwd_instance PUBLIC)
Added:
target_compile_features(device_conv2d_bwd_data_instance PUBLIC)

@@ -144,6 +155,7 @@ set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT
Added:
set_target_properties(device_conv2d_bwd_data_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)

@@ -155,3 +167,4 @@ install(TARGETS device_conv2d_fwd_instance LIBRARY DESTINATION lib)
Added:
install(TARGETS device_conv2d_bwd_data_instance LIBRARY DESTINATION lib)
device_operation/include/convolution_backward_data_specialization.hpp  0 → 100644
View file @ c2976d7a

#ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
#define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION

namespace ck {
namespace tensor_operation {
namespace device {

enum ConvolutionBackwardDataSpecialization_t
{
    Default,
    Filter1x1Stride1Pad0,
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
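A minimal sketch of how a device op can branch on this enum at compile time; the enum is redeclared locally so the snippet is self-contained, and chosen_path() is an invented stand-in for the descriptor-building code in the header that follows.

    #include <cstdio>

    enum ConvolutionBackwardDataSpecialization_t { Default, Filter1x1Stride1Pad0 };

    template <ConvolutionBackwardDataSpecialization_t Spec>
    const char* chosen_path()
    {
        // The real header builds different tensor descriptors in each branch.
        if constexpr(Spec == Filter1x1Stride1Pad0)
            return "1x1 / stride 1 / pad 0: backward data reduces to a single plain GEMM";
        else
            return "general: decompose into YTilda * XTilda GEMMs over filter phases";
    }

    int main()
    {
        std::printf("%s\n", chosen_path<Filter1x1Stride1Pad0>());
        std::printf("%s\n", chosen_path<Default>());
    }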
device_operation/include/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp  0 → 100644
View file @ c2976d7a

(New file, shown here in condensed form; repetitive transform_tensor_descriptor boilerplate and
long parameter lists are summarized in comments.)

#ifndef DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV2D_BWD_DATA_XDL_NHWC_KYXC_NHWK_HPP

#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_base.hpp"
#include "device_conv_bwd_data.hpp"
#include "convolution_backward_data_specialization.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
template <typename InDataType, typename WeiDataType, typename OutDataType, typename AccDataType,
          typename InElementwiseOperation, typename WeiElementwiseOperation,
          typename OutElementwiseOperation,
          ConvolutionBackwardDataSpecialization_t ConvBackwardDataSpecialization,
          ck::index_t BlockSize, ck::index_t MPerBlock, ck::index_t NPerBlock,
          ck::index_t K0PerBlock, ck::index_t K1,
          ck::index_t MPerXdl, ck::index_t NPerXdl, ck::index_t MXdlPerWave, ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder, typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim, ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1, bool ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder, typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim, ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1, bool BBlockLdsAddExtraN,
          ck::index_t CThreadTransferSrcDstVectorDim, ck::index_t CThreadTransferDstScalarPerVector>
struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
    : public DeviceConvBwdData<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
{
    using DeviceOp = DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;

    using ADataType = OutDataType;
    using BDataType = WeiDataType;
    using CDataType = InDataType;
    // TODO make A/B datatype different
    using ABDataType = InDataType;

    static constexpr index_t NDimSpatial = 2;

    static constexpr auto I0 = Number<0>{};  // ... likewise I1 through I5

    static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) %
                      ABlockTransferSrcScalarPerVector == 0);
    static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) %
                      BBlockTransferSrcScalarPerVector == 0);

    static constexpr auto K1Number     = Number<K1>{};
    static constexpr auto GemmK1Number = K1Number;

    static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
        ck::index_t N, ck::index_t K, ck::index_t C,
        std::vector<ck::index_t> input_spatial_lengths,
        std::vector<ck::index_t> filter_spatial_lengths,
        std::vector<ck::index_t> output_spatial_lengths,
        std::vector<ck::index_t> conv_filter_strides,
        std::vector<ck::index_t> conv_filter_dilations,
        std::vector<ck::index_t> input_left_pads,
        std::vector<ck::index_t> input_right_pads,
        index_t i_ytilda, index_t i_xtilda)
    {
        using namespace ck;

        // Hi/Wi, Ho/Wo, Y/X, InLeftPadH/W, InRightPadH/W, ConvStrideH/W and ConvDilationH/W are
        // unpacked from the vectors above; K0 = K / K1. Naive packed descriptors are built for
        // out[N, Ho, Wo, K], wei[K, Y, X, C] and in[N, Hi, Wi, C].

        if constexpr(ConvBackwardDataSpecialization ==
                     ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
        {
            // A: output tensor (N*Ho*Wo, K) unmerged into (K0, N*Ho*Wo, K1)
            // B: weight tensor (K, C) unmerged into (K0, C, K1)
            // C: input tensor embedded with the conv strides and merged into (N*Ho*Wo, C)
            return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
                              wei_gemmk0_gemmn_gemmk1_grid_desc,
                              in_gemmm_gemmn_grid_desc);
        }
        else
        {
            const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
            const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);

            const auto YTilda = ConvStrideH / GcdStrideDilationH;
            const auto XTilda = ConvStrideW / GcdStrideDilationW;

            const auto YDot = math::integer_divide_ceil(Y, YTilda);
            const auto XDot = math::integer_divide_ceil(X, XTilda);

            const auto HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
            const auto WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);

            // only work on HTilda and WTilda that contribute to non-padding area of input tensor
            const auto IHTildaSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadH - ConvDilationH * (YTilda - I1)), ConvStrideH);
            const auto IWTildaSliceBegin = math::integer_divide_floor(
                math::max(I0, InLeftPadW - ConvDilationW * (XTilda - I1)), ConvStrideW);

            const auto IHTildaSliceEnd = math::min(
                HTilda, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
            const auto IWTildaSliceEnd = math::min(
                WTilda, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);

            const auto HTildaSlice = IHTildaSliceEnd - IHTildaSliceBegin;
            const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin;

            // GemmK is different for each GEMM
            const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda);
            const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda);

            // A: output tensor, padded, embedded with (YDot, HTilda) / (XDot, WTilda), sliced and
            //    unmerged into K0/K1, then merged to
            //    (GemmK0 = YDotSlice*XDotSlice*K0, GemmM = N*HTildaSlice*WTildaSlice, GemmK1 = K1)
            // B: weight tensor, embedded with (YDot, YTilda) / (XDot, XTilda), sliced, frozen at
            //    (i_ytilda, i_xtilda) and merged to (GemmK0, GemmN = C, GemmK1)
            // C: input tensor, padded, embedded with (YTilda, HTilda) / (XTilda, WTilda), frozen at
            //    (i_ytilda, i_xtilda), sliced and merged to (GemmM, GemmN)
            return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
                              wei_gemmk0_gemmn_gemmk1_grid_desc,
                              in_gemmm_gemmn_grid_desc);
        }
    } // function end

    using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
        1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 0, 0));

    using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
    using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
    using CGridDesc_M_N     = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
        BlockSize,
        ABDataType, // TODO: distinguish A/B datatype
        AccDataType, CDataType, InMemoryDataOperationEnum_t::Set,
        AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N,
        InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation,
        MPerBlock, NPerBlock, K0PerBlock, MPerXdl, NPerXdl, K1, MXdlPerWave, NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun
        BBlockLdsAddExtraN,
        Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder
        7,                                // CThreadTransferSrcDstVectorDim
        CThreadTransferDstScalarPerVector>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(InDataType* p_in_grid, const WeiDataType* p_wei_grid, const OutDataType* p_out_grid,
                 ck::index_t N, ck::index_t K, ck::index_t C,
                 std::vector<ck::index_t> input_spatial_lengths,
                 std::vector<ck::index_t> filter_spatial_lengths,
                 std::vector<ck::index_t> output_spatial_lengths,
                 std::vector<ck::index_t> conv_filter_strides,
                 std::vector<ck::index_t> conv_filter_dilations,
                 std::vector<ck::index_t> input_left_pads,
                 std::vector<ck::index_t> input_right_pads,
                 ck::index_t M01, ck::index_t N01,
                 InElementwiseOperation in_element_op,
                 WeiElementwiseOperation wei_element_op,
                 OutElementwiseOperation out_element_op)
            : p_a_grid_{p_out_grid}, p_b_grid_{p_wei_grid}, p_c_grid_{p_in_grid},
              M01_{M01}, N01_{N01},
              a_element_op_{out_element_op}, b_element_op_{wei_element_op}, c_element_op_{in_element_op},
              Conv_N_{N}, Conv_K_{K}, Conv_C_{C},
              input_spatial_lengths_{input_spatial_lengths},
              filter_spatial_lengths_{filter_spatial_lengths},
              output_spatial_lengths_{output_spatial_lengths},
              conv_filter_strides_{conv_filter_strides},
              conv_filter_dilations_{conv_filter_dilations},
              input_left_pads_{input_left_pads},
              input_right_pads_{input_right_pads}
        {
            const index_t ConvStrideH   = conv_filter_strides[0];
            const index_t ConvStrideW   = conv_filter_strides[1];
            const index_t ConvDilationH = conv_filter_dilations[0];
            const index_t ConvDilationW = conv_filter_dilations[1];

            const auto YTilda = ConvStrideH / math::gcd(ConvStrideH, ConvDilationH);
            const auto XTilda = ConvStrideW / math::gcd(ConvStrideW, ConvDilationW);

            // one GEMM (one descriptor set) per filter phase (i_ytilda, i_xtilda)
            for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda)
            {
                for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda)
                {
                    const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
                        N, K, C, input_spatial_lengths, filter_spatial_lengths,
                        output_spatial_lengths, conv_filter_strides, conv_filter_dilations,
                        input_left_pads, input_right_pads, i_ytilda, i_xtilda);

                    a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
                    b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
                    c_grid_desc_m_n_container_.push_back(descs[I2]);

                    if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2], M01_, N01_))
                    {
                        c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
                            GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
                        block_2_ctile_map_container_.push_back(
                            GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01));
                    }
                }
            }
        }

        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
        std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
        std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
        std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
        std::vector<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
            c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_;
        std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
        index_t M01_;
        index_t N01_;
        OutElementwiseOperation a_element_op_;
        WeiElementwiseOperation b_element_op_;
        InElementwiseOperation c_element_op_;
        // for checking IsSupportedArgument()
        index_t Conv_N_;
        index_t Conv_K_;
        index_t Conv_C_;
        std::vector<ck::index_t> input_spatial_lengths_;
        std::vector<ck::index_t> filter_spatial_lengths_;
        std::vector<ck::index_t> output_spatial_lengths_;
        std::vector<ck::index_t> conv_filter_strides_;
        std::vector<ck::index_t> conv_filter_dilations_;
        std::vector<ck::index_t> input_left_pads_;
        std::vector<ck::index_t> input_right_pads_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, int nrepeat = 1)
        {
            nrepeat        = 1;
            float ave_time = 0;

            for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
            {
                // (debug prints of the a/b/c grid descriptor lengths for this GEMM omitted here)

                if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
                                                arg.b_grid_desc_k0_n_k1_container_[i],
                                                arg.c_grid_desc_m_n_container_[i],
                                                arg.M01_, arg.N01_))
                {
                    throw std::runtime_error(
                        "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
                }

                const index_t grid_size =
                    GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]);

                const auto K0 = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0);
                const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

                if(has_main_k0_block_loop)
                {
                    const auto kernel = kernel_gemm_xdlops_v2r3<
                        GridwiseGemm,
                        ADataType, // TODO: distiguish A/B datatype
                        CDataType,
                        remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
                        remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
                        OutElementwiseOperation, WeiElementwiseOperation, InElementwiseOperation,
                        remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
                        true>;

                    ave_time += launch_and_time_kernel(
                        kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0,
                        arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_,
                        arg.a_grid_desc_k0_m_k1_container_[i],
                        arg.b_grid_desc_k0_n_k1_container_[i],
                        arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i],
                        arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
                        arg.block_2_ctile_map_container_[i]);
                }
                else
                {
                    // same kernel instantiation and launch with the last template argument false
                }
            }
            return ave_time;
        }

        float Run(const BaseArgument* p_arg, int nrepeat = 1) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if constexpr(ConvBackwardDataSpecialization ==
                     ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0)
        {
            // check if it's 1x1, stride=1, pad=0 conv
            if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
                 arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
                 arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
                 arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
            {
                return false;
            }
        }

        // vector load A/B matrix from global memory
        if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 &&
             arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
             arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
        {
            return false;
        }

        // vector store C matrix into global memory
        if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0))
        {
            return false;
        }

        // Gridwise GEMM size
        for(int i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
                                            arg.b_grid_desc_k0_n_k1_container_[i],
                                            arg.c_grid_desc_m_n_container_[i],
                                            arg.M01_, arg.N01_))
            {
                return false;
            }
        }
        return true;
    }

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(InDataType* p_in_grid, const WeiDataType* p_wei_grid,
                             const OutDataType* p_out_grid,
                             ck::index_t N, ck::index_t K, ck::index_t C,
                             std::vector<ck::index_t> input_spatial_lengths,
                             std::vector<ck::index_t> filter_spatial_lengths,
                             std::vector<ck::index_t> output_spatial_lengths,
                             std::vector<ck::index_t> conv_filter_strides,
                             std::vector<ck::index_t> conv_filter_dilations,
                             std::vector<ck::index_t> input_left_pads,
                             std::vector<ck::index_t> input_right_pads,
                             InElementwiseOperation in_element_op,
                             WeiElementwiseOperation wei_element_op,
                             OutElementwiseOperation out_element_op)
    {
        return Argument{p_in_grid, p_wei_grid, p_out_grid, N, K, C,
                        input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths,
                        conv_filter_strides, conv_filter_dilations,
                        input_left_pads, input_right_pads,
                        1, 1, in_element_op, wei_element_op, out_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // Same parameters as MakeArgument, but taking void pointers (as required by the base class).
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(void* p_in_grid, const void* p_wei_grid, const void* p_out_grid,
                        ck::index_t N, ck::index_t K, ck::index_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads,
                        InElementwiseOperation in_element_op,
                        WeiElementwiseOperation wei_element_op,
                        OutElementwiseOperation out_element_op) override
    {
        return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
                                          static_cast<const WeiDataType*>(p_wei_grid),
                                          static_cast<const OutDataType*>(p_out_grid),
                                          N, K, C, input_spatial_lengths, filter_spatial_lengths,
                                          output_spatial_lengths, conv_filter_strides,
                                          conv_filter_dilations, input_left_pads, input_right_pads,
                                          1, 1, in_element_op, wei_element_op, out_element_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
        // clang-format off
        str << "DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
            << "<" << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " << K0PerBlock << ">";
        // clang-format on
        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
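The general-path decomposition above is driven by a handful of integer quantities. The snippet below just reproduces that arithmetic on the host for one illustrative configuration (3x3 filter, stride 2, dilation 1, 14x14 output); the concrete numbers are examples, not values from the source.

    #include <cstdio>
    #include <numeric>

    int ceil_div(int a, int b) { return (a + b - 1) / b; }

    int main()
    {
        const int Y = 3, X = 3;                          // filter
        const int ConvStrideH = 2, ConvStrideW = 2;
        const int ConvDilationH = 1, ConvDilationW = 1;
        const int Ho = 14, Wo = 14;                      // output spatial size

        const int YTilda = ConvStrideH / std::gcd(ConvStrideH, ConvDilationH); // 2
        const int XTilda = ConvStrideW / std::gcd(ConvStrideW, ConvDilationW); // 2
        const int YDot   = ceil_div(Y, YTilda);                                // 2
        const int XDot   = ceil_div(X, XTilda);                                // 2
        const int HTilda = Ho + ceil_div(ConvDilationH * (Y - 1), ConvStrideH); // 15
        const int WTilda = Wo + ceil_div(ConvDilationW * (X - 1), ConvStrideW); // 15

        std::printf("YTilda=%d XTilda=%d -> %d GEMMs; YDot=%d XDot=%d HTilda=%d WTilda=%d\n",
                    YTilda, XTilda, YTilda * XTilda, YDot, XDot, HTilda, WTilda);
    }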
device_operation/include/device_conv_bwd_data.hpp  0 → 100644
View file @ c2976d7a

#ifndef DEVICE_CONV_BWD_DATA_HPP
#define DEVICE_CONV_BWD_DATA_HPP

#include <iostream>
#include "device_base.hpp"
#include "element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
struct DeviceConvBwdData : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(void* p_in,
                        const void* p_wei,
                        const void* p_out,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t C,
                        std::vector<ck::index_t> input_spatial_lengths,
                        std::vector<ck::index_t> filter_spatial_lengths,
                        std::vector<ck::index_t> output_spatial_lengths,
                        std::vector<ck::index_t> conv_filter_strides,
                        std::vector<ck::index_t> conv_filter_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads,
                        InElementwiseOperation in_element_op,
                        WeiElementwiseOperation wei_element_op,
                        OutElementwiseOperation out_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename InElementwiseOperation,
          typename WeiElementwiseOperation,
          typename OutElementwiseOperation>
using DeviceConvBwdDataPtr = std::unique_ptr<
    DeviceConvBwdData<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp (new file, mode 100644)
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_bwd_data_instance {

using BF16 = ushort;
using F32  = float;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvBwdDataDefault =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;

static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;

// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k].
// Parameters that are identical for every instance in this file (K0PerBlock 4, K1 8, 32x32 XDL
// tiles, the A/B block-transfer orders, scalar-per-vector widths and LDS padding, and the
// C thread-transfer settings) are folded into the alias below; the specialization, block and
// tile sizes, thread-cluster lengths and the B-side source vector width vary per instance.
template <ConvolutionBackwardDataSpecialization_t ConvSpec,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ThreadClusterLengths,
          ck::index_t BBlockTransferSrcScalarPerVector>
using device_conv2d_bwd_data_xdl_bf16_instance =
    DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
        BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec,
        BlockSize, MPerBlock, NPerBlock, 4, 8, 32, 32, MXdlPerWave, NXdlPerWave,
        ThreadClusterLengths, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        ThreadClusterLengths, S<2, 0, 1>, S<0, 2, 1>, 1,
        BBlockTransferSrcScalarPerVector, 8, true, 7, 1>;

using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple<
    // clang-format off
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 256, 256, 128, 4, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 256, 128, 256, 2, 4, S<4, 64, 1>, 4>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 128, 128, 128, 4, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 256, 128, 128, 2, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 128, 128, 64, 2, 2, S<4, 32, 1>, 2>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 128, 64, 128, 2, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 64, 64, 64, 2, 2, S<4, 16, 1>, 4>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 256, 128, 64, 2, 1, S<4, 64, 1>, 1>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 256, 64, 128, 1, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 128, 128, 32, 2, 1, S<4, 32, 1>, 1>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 128, 32, 128, 1, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 64, 64, 32, 2, 1, S<4, 16, 1>, 2>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataDefault, 64, 32, 64, 1, 2, S<4, 16, 1>, 2>
    // clang-format on
    >;

// Same tile configurations, specialized for 1x1 filters with stride 1 and zero padding.
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple<
    // clang-format off
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 2, 4, S<4, 64, 1>, 4>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 2, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 2, 2, S<4, 32, 1>, 2>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 2, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 2, 2, S<4, 16, 1>, 4>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 2, 1, S<4, 64, 1>, 1>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 1, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 2, 1, S<4, 32, 1>, 1>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 1, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 2, 1, S<4, 16, 1>, 2>,
    device_conv2d_bwd_data_xdl_bf16_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 1, 2, S<4, 16, 1>, 2>
    // clang-format on
    >;

void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances{});
    add_device_operation_instances(
        instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{});
}

} // namespace device_conv2d_bwd_data_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
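For reference, this is how the registration function defined above is typically consumed on the host side. The sketch is illustrative only: the main() driver, compiling it as a separate translation unit, and the re-declaration block are assumptions made for the example; the declaration simply mirrors the definition in the file above, and only names that appear in this commit (the add_... registration call, DeviceConvBwdDataPtr, PassThrough, GetTypeString()) are used.

#include <iostream>
#include <vector>
#include "device_conv_bwd_data.hpp"
#include "element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_bwd_data_instance {
// Declaration mirroring the definition above; a profiler-side header would normally provide it.
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
    std::vector<DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
                                     ck::tensor_operation::element_wise::PassThrough,
                                     ck::tensor_operation::element_wise::PassThrough>>& instances);
} // namespace device_conv2d_bwd_data_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

int main()
{
    namespace ctd     = ck::tensor_operation::device;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    std::vector<ctd::DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>> instances;

    ctd::device_conv2d_bwd_data_instance::
        add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(instances);

    // Each entry is a type-erased DeviceConvBwdData; GetTypeString() reports its tile shape.
    for(const auto& op : instances)
        std::cout << op->GetTypeString() << std::endl;

    return 0;
}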
device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp (new file, mode 100644)
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_bwd_data_instance {

using F16 = ck::half_t;
using F32 = float;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvBwdDataDefault =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;

static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;

// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k].
// Shared parameters (K0PerBlock 4, K1 8, 32x32 XDL tiles, block-transfer orders and
// scalar widths, C thread-transfer settings) are folded into the alias below.
template <ConvolutionBackwardDataSpecialization_t ConvSpec,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ThreadClusterLengths,
          ck::index_t BBlockTransferSrcScalarPerVector>
using device_conv2d_bwd_data_xdl_f16_instance =
    DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
        F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec,
        BlockSize, MPerBlock, NPerBlock, 4, 8, 32, 32, MXdlPerWave, NXdlPerWave,
        ThreadClusterLengths, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
        ThreadClusterLengths, S<2, 0, 1>, S<0, 2, 1>, 1,
        BBlockTransferSrcScalarPerVector, 8, true, 7, 1>;

using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances = std::tuple<
    // clang-format off
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 256, 256, 128, 4, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 256, 128, 256, 2, 4, S<4, 64, 1>, 4>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 128, 128, 128, 4, 2, S<4, 32, 1>, 4>,
#if !CK_WORKAROUND_SWDEV_325164
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 256, 128, 128, 2, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 256, 64, 128, 1, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 128, 32, 128, 1, 2, S<4, 32, 1>, 4>,
#endif
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 128, 128, 64, 2, 2, S<4, 32, 1>, 2>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 128, 64, 128, 2, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 64, 64, 64, 2, 2, S<4, 16, 1>, 4>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 256, 128, 64, 2, 1, S<4, 64, 1>, 1>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 128, 128, 32, 2, 1, S<4, 32, 1>, 1>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 64, 64, 32, 2, 1, S<4, 16, 1>, 2>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataDefault, 64, 32, 64, 1, 2, S<4, 16, 1>, 2>
    // clang-format on
    >;

// Same tile configurations, specialized for 1x1 filters with stride 1 and zero padding.
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple<
    // clang-format off
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 2, 4, S<4, 64, 1>, 4>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 2, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 2, 2, S<4, 32, 1>, 2>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 2, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 2, 2, S<4, 16, 1>, 4>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 2, 1, S<4, 64, 1>, 1>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 1, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 2, 1, S<4, 32, 1>, 1>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 1, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 2, 1, S<4, 16, 1>, 2>,
    device_conv2d_bwd_data_xdl_f16_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 1, 2, S<4, 16, 1>, 2>
    // clang-format on
    >;

void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances{});
    add_device_operation_instances(
        instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{});
}

} // namespace device_conv2d_bwd_data_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instance.cpp (new file, mode 100644)
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_bwd_data_instance {

using F32 = float;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvBwdDataDefault =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;

static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;

// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k].
// Shared parameters (K0PerBlock 4, K1 4, 32x32 XDL tiles, block-transfer orders and
// scalar widths, C thread-transfer settings) are folded into the alias below.
template <ConvolutionBackwardDataSpecialization_t ConvSpec,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ThreadClusterLengths,
          ck::index_t BBlockTransferSrcScalarPerVector>
using device_conv2d_bwd_data_xdl_f32_instance =
    DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
        F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvSpec,
        BlockSize, MPerBlock, NPerBlock, 4, 4, 32, 32, MXdlPerWave, NXdlPerWave,
        ThreadClusterLengths, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true,
        ThreadClusterLengths, S<2, 0, 1>, S<0, 2, 1>, 1,
        BBlockTransferSrcScalarPerVector, 4, true, 7, 1>;

using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances = std::tuple<
    // clang-format off
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 256, 256, 128, 4, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 256, 128, 256, 2, 4, S<4, 64, 1>, 4>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 128, 128, 128, 4, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 256, 128, 128, 2, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 128, 128, 64, 2, 2, S<4, 32, 1>, 2>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 128, 64, 128, 2, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 64, 64, 64, 2, 2, S<4, 16, 1>, 4>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 256, 128, 64, 2, 1, S<4, 64, 1>, 1>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 256, 64, 128, 1, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 128, 128, 32, 2, 1, S<4, 32, 1>, 1>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 128, 32, 128, 1, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 64, 64, 32, 2, 1, S<4, 16, 1>, 2>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataDefault, 64, 32, 64, 1, 2, S<4, 16, 1>, 4>
    // clang-format on
    >;

// Same tile configurations, specialized for 1x1 filters with stride 1 and zero padding.
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple<
    // clang-format off
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 2, 4, S<4, 64, 1>, 4>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 2, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 2, 2, S<4, 32, 1>, 2>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 2, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 2, 2, S<4, 16, 1>, 4>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 2, 1, S<4, 64, 1>, 1>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 1, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 2, 1, S<4, 32, 1>, 1>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 1, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 2, 1, S<4, 16, 1>, 2>,
    device_conv2d_bwd_data_xdl_f32_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 1, 2, S<4, 16, 1>, 4>
    // clang-format on
    >;

void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances{});
    add_device_operation_instances(
        instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{});
}

} // namespace device_conv2d_bwd_data_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
device_operation/src/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp (new file, mode 100644)
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace device_conv2d_bwd_data_instance {

using DataType = int8_t;
using AccType  = int32_t;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvBwdDataDefault =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Default;

static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization_t::Filter1x1Stride1Pad0;

// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k].
// Shared parameters (K0PerBlock 4, K1 16, 32x32 XDL tiles, block-transfer orders and
// scalar widths, C thread-transfer settings) are folded into the alias below.
template <ConvolutionBackwardDataSpecialization_t ConvSpec,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ThreadClusterLengths,
          ck::index_t BBlockTransferSrcScalarPerVector>
using device_conv2d_bwd_data_xdl_int8_instance =
    DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
        DataType, DataType, DataType, AccType, PassThrough, PassThrough, PassThrough, ConvSpec,
        BlockSize, MPerBlock, NPerBlock, 4, 16, 32, 32, MXdlPerWave, NXdlPerWave,
        ThreadClusterLengths, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true,
        ThreadClusterLengths, S<2, 0, 1>, S<0, 2, 1>, 1,
        BBlockTransferSrcScalarPerVector, 16, true, 7, 1>;

using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple<
    // clang-format off
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 256, 256, 128, 4, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 256, 128, 256, 2, 4, S<4, 64, 1>, 4>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 128, 128, 128, 4, 2, S<4, 32, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 256, 128, 128, 2, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 128, 128, 64, 2, 2, S<4, 32, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 128, 64, 128, 2, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 64, 64, 64, 2, 2, S<4, 16, 1>, 4>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 256, 128, 64, 2, 1, S<4, 64, 1>, 1>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 256, 64, 128, 1, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 128, 128, 32, 2, 1, S<4, 32, 1>, 1>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 128, 32, 128, 1, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 64, 64, 32, 2, 1, S<4, 16, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataDefault, 64, 32, 64, 1, 2, S<4, 16, 1>, 4>
    // clang-format on
    >;

// Same tile configurations, specialized for 1x1 filters with stride 1 and zero padding.
using device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple<
    // clang-format off
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 256, 128, 4, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 256, 2, 4, S<4, 64, 1>, 4>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 128, 4, 2, S<4, 32, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 2, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 64, 2, 2, S<4, 32, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 64, 128, 2, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 64, 2, 2, S<4, 16, 1>, 4>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 64, 2, 1, S<4, 64, 1>, 1>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 256, 64, 128, 1, 2, S<4, 64, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 128, 32, 2, 1, S<4, 32, 1>, 1>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 128, 32, 128, 1, 2, S<4, 32, 1>, 4>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 64, 32, 2, 1, S<4, 16, 1>, 2>,
    device_conv2d_bwd_data_xdl_int8_instance<ConvBwdDataFilter1x1Stride1Pad0, 64, 32, 64, 1, 2, S<4, 16, 1>, 4>
    // clang-format on
    >;

void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
    std::vector<DeviceConvBwdDataPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances{});
    add_device_operation_instances(
        instances, device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{});
}

} // namespace device_conv2d_bwd_data_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
@@ -32,19 +32,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances = std::tuple<
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
    // clang-format on
    >;
@@ -54,19 +54,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances = std::tuple<
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
    // clang-format on
    >;
@@ -76,19 +76,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances = std::tuple
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<ushort, ushort, ushort, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
    // clang-format on
    >;
device_operation/src/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
@@ -32,19 +32,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances = std::tuple<
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>
    // clang-format on
    >;
@@ -54,19 +54,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances = std::tuple<
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>
    // clang-format on
    >;
@@ -76,19 +76,19 @@ using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances = std::tuple
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128,  64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128,  64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0,  64,  64,  64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256, 128,  64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 256,  64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128, 128,  32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 128,  32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0,  64,  64,  32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
    DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0,  64,  32,  64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>
    // clang-format on
    >;
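These per-datatype instance files end the same way: a std::tuple of fully specialized device-op types (here device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances), which is presumably expanded elsewhere into a runtime list of operations to try. The snippet below is a minimal, self-contained sketch of that tuple-expansion pattern using only the standard library; BaseOp, DummyInstanceA/B, and add_instances are hypothetical stand-ins, not composable_kernel's actual API.

// Minimal sketch (C++17): expand a std::tuple of concrete instance types
// into a vector of pointers to a common base interface.
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

struct BaseOp
{
    virtual ~BaseOp()                         = default;
    virtual std::string GetTypeString() const = 0;
};

struct DummyInstanceA : BaseOp
{
    std::string GetTypeString() const override { return "DummyInstanceA"; }
};

struct DummyInstanceB : BaseOp
{
    std::string GetTypeString() const override { return "DummyInstanceB"; }
};

// Stand-in for a tuple alias like device_conv2d_fwd_..._int8_instances.
using dummy_instances = std::tuple<DummyInstanceA, DummyInstanceB>;

// Default-construct one object per tuple element and collect them.
template <typename Tuple>
void add_instances(std::vector<std::unique_ptr<BaseOp>>& ops)
{
    std::apply(
        [&](auto... instances) {
            (ops.push_back(std::make_unique<decltype(instances)>()), ...);
        },
        Tuple{});
}

int main()
{
    std::vector<std::unique_ptr<BaseOp>> ops;
    add_instances<dummy_instances>(ops);

    for(const auto& op : ops)
        std::cout << op->GetTypeString() << '\n';
}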
...