Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a5137505
Unverified
Commit
a5137505
authored
Jan 06, 2025
by
arai713
Committed by
GitHub
Jan 06, 2025
Browse files
Merge branch 'develop' into codegen_hiprtc
parents
208a1dab
888317e6
Changes
259
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1253 additions
and
728 deletions
+1253
-728
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
+247
-7
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
...tion/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
+52
-22
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+1
-1
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
...operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
+749
-526
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+2
-1
include/ck/utility/amd_ck_fp8.hpp
include/ck/utility/amd_ck_fp8.hpp
+18
-6
include/ck/utility/amd_inline_asm.hpp
include/ck/utility/amd_inline_asm.hpp
+22
-1
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+37
-0
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+4
-2
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+1
-1
include/ck/utility/static_buffer.hpp
include/ck/utility/static_buffer.hpp
+4
-2
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+14
-1
include/ck_tile/README.md
include/ck_tile/README.md
+3
-0
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+1
-0
include/ck_tile/core/arch/amd_buffer_addressing.hpp
include/ck_tile/core/arch/amd_buffer_addressing.hpp
+2
-2
include/ck_tile/core/container/meta_data_buffer.hpp
include/ck_tile/core/container/meta_data_buffer.hpp
+3
-3
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+1
-0
include/ck_tile/core/utility/amd_address_space.hpp
include/ck_tile/core/utility/amd_address_space.hpp
+37
-0
include/ck_tile/host/arg_parser.hpp
include/ck_tile/host/arg_parser.hpp
+44
-2
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+11
-151
No files found.
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
a5137505
...
@@ -1007,6 +1007,13 @@ struct ThreadwiseTensorSliceTransfer_v4
...
@@ -1007,6 +1007,13 @@ struct ThreadwiseTensorSliceTransfer_v4
using
SrcCoordStep
=
decltype
(
make_tensor_coordinate_step
(
SrcDesc
{},
Index
{}));
using
SrcCoordStep
=
decltype
(
make_tensor_coordinate_step
(
SrcDesc
{},
Index
{}));
static
constexpr
index_t
PackedSize
=
[]()
{
if
constexpr
(
is_same_v
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>
)
return
2
;
else
return
1
;
}();
__device__
constexpr
ThreadwiseTensorSliceTransfer_v4
(
const
Index
&
src_ref_idx
)
__device__
constexpr
ThreadwiseTensorSliceTransfer_v4
(
const
Index
&
src_ref_idx
)
:
src_ref_coord_
(
make_tensor_coordinate
(
SrcDesc
{},
src_ref_idx
))
:
src_ref_coord_
(
make_tensor_coordinate
(
SrcDesc
{},
src_ref_idx
))
{
{
...
@@ -1015,6 +1022,11 @@ struct ThreadwiseTensorSliceTransfer_v4
...
@@ -1015,6 +1022,11 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert
(
SliceLengths
::
At
(
Number
<
SrcVectorDim
>
{})
%
SrcScalarPerVector
==
0
,
static_assert
(
SliceLengths
::
At
(
Number
<
SrcVectorDim
>
{})
%
SrcScalarPerVector
==
0
,
"wrong! Not divisible"
);
"wrong! Not divisible"
);
if
constexpr
(
is_same_v
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>
)
{
static_assert
(
SrcScalarPerVector
%
PackedSize
==
0
,
"pk data N cannot be 1"
);
}
}
}
template
<
typename
SrcRefToOriginDisplacement
,
template
<
typename
SrcRefToOriginDisplacement
,
...
@@ -1109,7 +1121,7 @@ struct ThreadwiseTensorSliceTransfer_v4
...
@@ -1109,7 +1121,7 @@ struct ThreadwiseTensorSliceTransfer_v4
move_tensor_coordinate
(
src_desc
,
src_data_coord
,
src_ref_to_data_disp_coord_step
);
move_tensor_coordinate
(
src_desc
,
src_data_coord
,
src_ref_to_data_disp_coord_step
);
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
>
src_tmp_vector
;
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
/
PackedSize
>
src_tmp_vector
;
using
src_vector_t
=
typename
decltype
(
src_tmp_vector
)
::
type
;
using
src_vector_t
=
typename
decltype
(
src_tmp_vector
)
::
type
;
...
@@ -1120,7 +1132,8 @@ struct ThreadwiseTensorSliceTransfer_v4
...
@@ -1120,7 +1132,8 @@ struct ThreadwiseTensorSliceTransfer_v4
if
constexpr
(
SrcBuffer
::
IsDynamicBuffer
())
if
constexpr
(
SrcBuffer
::
IsDynamicBuffer
())
{
{
src_tmp_vector
.
template
AsType
<
src_vector_t
>()(
Number
<
0
>
{})
=
src_tmp_vector
.
template
AsType
<
src_vector_t
>()(
Number
<
0
>
{})
=
src_buf
.
template
Get
<
src_vector_t
>(
src_data_coord
.
GetOffset
(),
is_src_valid
);
src_buf
.
template
Get
<
src_vector_t
>(
src_data_coord
.
GetOffset
()
/
PackedSize
,
is_src_valid
);
}
}
else
if
constexpr
(
SrcBuffer
::
IsStaticBuffer
())
else
if
constexpr
(
SrcBuffer
::
IsStaticBuffer
())
{
{
...
@@ -1133,9 +1146,236 @@ struct ThreadwiseTensorSliceTransfer_v4
...
@@ -1133,9 +1146,236 @@ struct ThreadwiseTensorSliceTransfer_v4
});
});
}
}
if
constexpr
(
is_same
<
remove_cvref_t
<
SrcData
>
,
f8_t
>::
value
&&
if
constexpr
(
is_same
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>::
value
)
is_same
<
remove_cvref_t
<
DstData
>
,
half_t
>::
value
&&
{
SrcScalarPerVector
%
2
==
0
)
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
constexpr
index_t
pack_size
=
8
;
static_assert
(
SrcScalarPerVector
%
pack_size
==
0
,
""
);
using
src_v_t
=
typename
vector_type_maker_t
<
SrcData
,
pack_size
/
PackedSize
>::
type
;
using
dst_v_t
=
typename
vector_type_maker_t
<
DstData
,
pack_size
>::
type
;
static_for
<
0
,
SrcScalarPerVector
/
pack_size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
tensor_operation
::
element_wise
::
PassThroughPack8
{}(
dst_tmp_vector
.
template
AsType
<
dst_v_t
>()(
i
),
src_tmp_vector
.
template
AsType
<
src_v_t
>()[
i
]);
});
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
});
}
else
if
constexpr
(
is_same
<
remove_cvref_t
<
SrcData
>
,
f8_t
>::
value
&&
is_same
<
remove_cvref_t
<
DstData
>
,
half_t
>::
value
&&
SrcScalarPerVector
%
2
==
0
)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
constexpr
index_t
pack_size
=
2
;
using
dst_v_t
=
typename
vector_type_maker_t
<
DstData
,
pack_size
>::
type
;
using
src_v_t
=
typename
vector_type_maker_t
<
SrcData
,
pack_size
>::
type
;
static_for
<
0
,
SrcScalarPerVector
/
pack_size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
tensor_operation
::
element_wise
::
PassThroughPack2
{}(
dst_tmp_vector
.
template
AsType
<
dst_v_t
>()(
i
),
src_tmp_vector
.
template
AsType
<
src_v_t
>()[
i
]);
});
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
});
}
else
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
// TODO: if SrcData and DstData are vector types, then static_cast may not compile
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
dst_tmp_vector
.
template
AsType
<
DstData
>()(
i
)
=
type_convert
<
DstData
>
(
src_tmp_vector
.
template
AsType
<
SrcData
>()[
i
]);
});
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
});
}
});
}
// Fuse scale
template
<
typename
SrcRefToOriginDisplacement
,
typename
DstOriginIdx
,
typename
SrcBuffer
,
typename
DstBuffer
>
__device__
void
Run
(
const
SrcDesc
&
,
const
SrcRefToOriginDisplacement
&
,
const
SrcBuffer
&
src_buf
,
const
DstData
&
scale
,
const
DstDesc
&
,
const
DstOriginIdx
&
,
DstBuffer
&
dst_buf
)
const
{
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc and DstDesc need to known at compile-time"
);
static_assert
(
is_same
<
remove_cvref_t
<
typename
SrcBuffer
::
type
>
,
remove_cvref_t
<
SrcData
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
DstBuffer
::
type
>
,
remove_cvref_t
<
DstData
>>::
value
,
"wrong! SrcBuffer or DstBuffer data type is wrong"
);
static_assert
(
DstBuffer
::
IsStaticBuffer
(),
"wrong! DstBuffer need to be StaticBuffer"
);
static_assert
(
is_known_at_compile_time
<
remove_cvref_t
<
SrcRefToOriginDisplacement
>>::
value
&&
is_known_at_compile_time
<
remove_cvref_t
<
DstOriginIdx
>>::
value
,
"wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
"at compile-time"
);
// SrcDesc and DstDesc are known at compile-time
constexpr
auto
src_desc
=
remove_cvref_t
<
SrcDesc
>
{};
constexpr
auto
dst_desc
=
remove_cvref_t
<
DstDesc
>
{};
// SrcOriginToRefDistance and DstOriginToRefDistance are known at compile-time
constexpr
auto
src_ref_to_origin_disp_idx
=
to_multi_index
(
SrcRefToOriginDisplacement
{});
constexpr
auto
dst_origin_idx
=
to_multi_index
(
DstOriginIdx
{});
// scalar per access of each dim
constexpr
auto
src_scalar_per_access
=
generate_sequence_v2
(
[
&
](
auto
i
)
constexpr
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
Number
<
SrcScalarPerVector
>
{};
}
else
{
return
Number
<
1
>
{};
}
},
Number
<
nDim
>
{});
// scalar step (if stepping on SrcVectorDim) of each dim
constexpr
auto
src_scalar_step_in_vector
=
generate_sequence_v2
(
[
&
](
auto
i
)
constexpr
{
if
constexpr
(
i
==
SrcVectorDim
)
{
return
Number
<
1
>
{};
}
else
{
return
Number
<
0
>
{};
}
},
Number
<
nDim
>
{});
constexpr
auto
access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
dim_access_order
=
DimAccessOrder
{};
constexpr
auto
ordered_access_lengths
=
container_reorder_given_new2old
(
access_lengths
,
dim_access_order
);
static_ford
<
decltype
(
ordered_access_lengths
)
>
{}([
&
](
auto
ordered_access_idx
)
{
#if 0
// TODO: unable to compile
// position in slice window
constexpr auto data_to_origin_disp_idx =
container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
src_scalar_per_access;
#else
// position in slice window
constexpr
auto
data_to_origin_disp_idx
=
ordered_access_idx
.
ReorderGivenOld2New
(
dim_access_order
)
*
src_scalar_per_access
;
#endif
// src coordinate
constexpr
auto
src_ref_to_data_disp_idx
=
src_ref_to_origin_disp_idx
+
data_to_origin_disp_idx
;
constexpr
auto
src_ref_to_data_disp_coord_step
=
make_tensor_coordinate_step
(
src_desc
,
src_ref_to_data_disp_idx
);
auto
src_data_coord
=
src_ref_coord_
;
move_tensor_coordinate
(
src_desc
,
src_data_coord
,
src_ref_to_data_disp_coord_step
);
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
/
PackedSize
>
src_tmp_vector
;
using
src_vector_t
=
typename
decltype
(
src_tmp_vector
)
::
type
;
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_desc
,
src_data_coord
);
// copy data from src_buf into src_tmp_vector
if
constexpr
(
SrcBuffer
::
IsDynamicBuffer
())
{
src_tmp_vector
.
template
AsType
<
src_vector_t
>()(
Number
<
0
>
{})
=
src_buf
.
template
Get
<
src_vector_t
>(
src_data_coord
.
GetOffset
()
/
PackedSize
,
is_src_valid
);
}
else
if
constexpr
(
SrcBuffer
::
IsStaticBuffer
())
{
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
src_offset
=
src_desc
.
CalculateOffset
(
src_ref_to_origin_disp_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
src_tmp_vector
.
template
AsType
<
SrcData
>()(
i
)
=
src_buf
[
Number
<
src_offset
>
{}];
});
}
if
constexpr
(
is_same
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>::
value
)
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
dst_tmp_vector
;
vector_type
<
DstData
,
2
>
scale_vector
;
scale_vector
.
template
AsType
<
DstData
>()(
Number
<
0
>
{})
=
scale
;
scale_vector
.
template
AsType
<
DstData
>()(
Number
<
1
>
{})
=
scale
;
constexpr
index_t
pack_size
=
8
;
static_assert
(
SrcScalarPerVector
%
pack_size
==
0
,
""
);
using
src_v_t
=
typename
vector_type_maker_t
<
SrcData
,
pack_size
/
PackedSize
>::
type
;
using
dst_v_t
=
typename
vector_type_maker_t
<
DstData
,
pack_size
>::
type
;
using
scale_v_t
=
typename
vector_type_maker_t
<
DstData
,
2
>::
type
;
static_for
<
0
,
SrcScalarPerVector
/
pack_size
,
1
>
{}([
&
](
auto
i
)
{
ck
::
tensor_operation
::
element_wise
::
DequantPack8
{}(
dst_tmp_vector
.
template
AsType
<
dst_v_t
>()(
i
),
src_tmp_vector
.
template
AsType
<
src_v_t
>()[
i
],
scale_vector
.
template
AsType
<
scale_v_t
>()[
Number
<
0
>
{}]);
});
// copy data from dst_tmp_vector into dst_buf
static_for
<
0
,
SrcScalarPerVector
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
dst_offset
=
dst_desc
.
CalculateOffset
(
dst_origin_idx
+
data_to_origin_disp_idx
+
i
*
src_scalar_step_in_vector
);
dst_buf
(
Number
<
dst_offset
>
{})
=
dst_tmp_vector
.
template
AsType
<
DstData
>()[
i
];
});
}
else
if
constexpr
(
is_same
<
remove_cvref_t
<
SrcData
>
,
f8_t
>::
value
&&
is_same
<
remove_cvref_t
<
DstData
>
,
half_t
>::
value
&&
SrcScalarPerVector
%
2
==
0
)
{
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
// DstData)
...
@@ -1304,7 +1544,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
...
@@ -1304,7 +1544,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
ElementwiseOperation
element_op_
;
ElementwiseOperation
element_op_
;
};
};
// Specilized for
WMMA-Navi3
// Speci
a
lized for
gfx11
// A single Wave32 is composed of two rows
// A single Wave32 is composed of two rows
// Data exchange allowed between these two rows
// Data exchange allowed between these two rows
// This RowLane Dst buf will be filled from two Src buf
// This RowLane Dst buf will be filled from two Src buf
...
@@ -1439,7 +1679,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
...
@@ -1439,7 +1679,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
ElementwiseOperation
element_op_
{};
ElementwiseOperation
element_op_
{};
};
};
// Specilized for
WMMA-Navi4
// Speci
a
lized for
gfx12
template
<
typename
SrcData
,
template
<
typename
SrcData
,
typename
DstData
,
typename
DstData
,
typename
SrcDesc
,
typename
SrcDesc
,
...
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp
View file @
a5137505
...
@@ -31,8 +31,8 @@ template <typename SliceLengths,
...
@@ -31,8 +31,8 @@ template <typename SliceLengths,
typename
DstDimAccessOrder
,
typename
DstDimAccessOrder
,
index_t
SrcVectorDim
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
DstVectorDim
,
index_t
SrcScalarPerVector
,
index_t
SrcScalarPerVector
_
,
index_t
DstScalarPerVector
,
index_t
DstScalarPerVector
_
,
index_t
SrcScalarStrideInVector
,
index_t
SrcScalarStrideInVector
,
index_t
DstScalarStrideInVector
,
index_t
DstScalarStrideInVector
,
bool
SrcResetCoordinateAfterRun
,
// control whether to move back src coordinate after each
bool
SrcResetCoordinateAfterRun
,
// control whether to move back src coordinate after each
...
@@ -55,6 +55,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -55,6 +55,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
index_t
PackedSize
=
[]()
{
if
constexpr
(
is_same_v
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>
)
return
2
;
else
return
1
;
}();
static
constexpr
auto
SrcScalarPerVector
=
Number
<
SrcScalarPerVector_
/
PackedSize
>
{};
static
constexpr
auto
DstScalarPerVector
=
Number
<
DstScalarPerVector_
/
PackedSize
>
{};
__device__
constexpr
ThreadwiseTensorSliceTransfer_v3r1
(
__device__
constexpr
ThreadwiseTensorSliceTransfer_v3r1
(
const
SrcDesc
&
src_desc
,
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin
,
const
Index
&
src_slice_origin
,
...
@@ -67,6 +77,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -67,6 +77,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_element_op_
(
src_element_op
),
src_element_op_
(
src_element_op
),
dst_element_op_
(
dst_element_op
)
dst_element_op_
(
dst_element_op
)
{
{
if
constexpr
(
is_same_v
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>
)
{
static_assert
(
is_same_v
<
remove_cvref_t
<
SrcData
>
,
remove_cvref_t
<
DstData
>>
,
"SrcData != DstData"
);
static_assert
(
SrcScalarPerVector_
%
PackedSize
==
0
&&
DstScalarPerVector_
%
PackedSize
==
0
,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"
);
static_assert
(
SrcVectorDim
==
DstVectorDim
,
"pk_i4_t does not support transpose"
);
}
}
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin_idx
)
...
@@ -95,11 +116,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -95,11 +116,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// scalar per access on each dim
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
// TODO: don't use lambda_scalar_per_access
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
_
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
static_assert
(
SliceLengths
::
At
(
SrcVectorDim
)
%
SrcScalarPerVector
==
0
,
static_assert
(
SliceLengths
::
At
(
SrcVectorDim
)
%
(
SrcScalarPerVector
_
)
==
0
,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"
);
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"
);
constexpr
auto
src_dim_access_order
=
SrcDimAccessOrder
{};
constexpr
auto
src_dim_access_order
=
SrcDimAccessOrder
{};
...
@@ -180,9 +201,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -180,9 +201,6 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using
src_vector_type
=
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
>
;
using
src_vector_type
=
vector_type_maker_t
<
SrcData
,
SrcScalarPerVector
>
;
using
src_vector_t
=
typename
src_vector_type
::
type
;
using
src_vector_t
=
typename
src_vector_type
::
type
;
auto
src_vector_container
=
src_vector_type
{
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
(),
true
)};
using
dst_vector_type
=
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
;
using
dst_vector_type
=
vector_type_maker_t
<
DstData
,
SrcScalarPerVector
>
;
using
dst_vector_t
=
typename
dst_vector_type
::
type
;
using
dst_vector_t
=
typename
dst_vector_type
::
type
;
dst_vector_type
op_r_v
;
dst_vector_type
op_r_v
;
...
@@ -193,17 +211,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -193,17 +211,22 @@ struct ThreadwiseTensorSliceTransfer_v3r1
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack8_invocable
)
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack8_invocable
)
return
math
::
min
(
8
,
SrcScalarPerVector
);
return
math
::
min
(
8
,
SrcScalarPerVector
);
}
}
if
constexpr
(
is_detected
<
is_pack4_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
else
if
constexpr
(
is_detected
<
is_pack4_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
{
{
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack4_invocable
)
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack4_invocable
)
return
math
::
min
(
4
,
SrcScalarPerVector
);
return
math
::
min
(
4
,
SrcScalarPerVector
);
}
}
if
constexpr
(
is_detected
<
is_pack2_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
else
if
constexpr
(
is_detected
<
is_pack2_invocable_t
,
decltype
(
src_element_op_
)
>::
value
)
{
{
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack2_invocable
)
if
constexpr
(
decltype
(
src_element_op_
)
::
is_pack2_invocable
)
return
math
::
min
(
2
,
SrcScalarPerVector
);
return
math
::
min
(
2
,
SrcScalarPerVector
);
}
}
return
1
;
else
{
return
1
;
}
};
};
constexpr
index_t
elem_op_vec_len
=
get_elem_op_vec_len
();
constexpr
index_t
elem_op_vec_len
=
get_elem_op_vec_len
();
...
@@ -211,6 +234,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -211,6 +234,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
using
src_elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
using
src_elem_op_vec_t
=
typename
vector_type
<
SrcData
,
elem_op_vec_len
>::
type
;
using
dst_elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
using
dst_elem_op_vec_t
=
typename
vector_type
<
DstData
,
elem_op_vec_len
>::
type
;
auto
src_vector_container
=
src_vector_type
{
src_buf
.
template
Get
<
src_vector_t
>(
src_coord_
.
GetOffset
()
/
PackedSize
,
true
)};
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
idx
)
{
static_for
<
0
,
SrcScalarPerVector
/
elem_op_vec_len
,
1
>
{}([
&
](
auto
idx
)
{
// apply the src elementwise op and convert to DstData under the hood if needed
// apply the src elementwise op and convert to DstData under the hood if needed
src_element_op_
(
op_r_v
.
template
AsType
<
dst_elem_op_vec_t
>()(
idx
),
src_element_op_
(
op_r_v
.
template
AsType
<
dst_elem_op_vec_t
>()(
idx
),
...
@@ -276,10 +302,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -276,10 +302,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
dst_thread_scratch_
(
idx
)
=
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
];
dst_thread_scratch_
(
idx
)
=
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
];
});
});
#else
#else
// OOB Check
// OOB Check
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
_
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
...
@@ -350,6 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -350,6 +375,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
(
is_same
<
f8_t
,
remove_cvref_t
<
DstData
>>::
value
&&
(
is_same
<
f8_t
,
remove_cvref_t
<
DstData
>>::
value
&&
SrcScalarPerVector
%
4
==
0
&&
DstScalarPerVector
%
4
==
0
)))
SrcScalarPerVector
%
4
==
0
&&
DstScalarPerVector
%
4
==
0
)))
{
{
static_assert
(
!
is_same_v
<
remove_cvref_t
<
SrcData
>
,
pk_i4_t
>
,
"in-register transpose is not supported for pk_i4_t"
);
// each transpose does
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_
...
@@ -410,7 +437,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -410,7 +437,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
}
}
else
else
{
{
static_ford
<
SliceLengths
>
{}([
&
](
auto
idx
)
{
constexpr
auto
packed_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
PackedSize
>
{},
Number
<
nDim
>
{});
constexpr
auto
packed_access_lengths
=
SliceLengths
{}
/
packed_per_access
;
static_ford
<
decltype
(
packed_access_lengths
)
>
{}([
&
](
auto
idx
)
{
dst_thread_scratch_
(
idx
)
=
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
];
dst_thread_scratch_
(
idx
)
=
src_thread_scratch_tuple_
[
thread_scratch_id
][
idx
];
});
});
}
}
...
@@ -438,7 +470,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -438,7 +470,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// src scalar per access on each dim
// src scalar per access on each dim
// TODO: don't use this
// TODO: don't use this
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
_
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
...
@@ -526,13 +558,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -526,13 +558,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// apply DstElementwiseOperation
// apply DstElementwiseOperation
dst_element_op_
(
dst_v
,
dst_vector_container
.
template
AsType
<
DstData
>()[
i
]);
dst_element_op_
(
dst_v
,
dst_vector_container
.
template
AsType
<
DstData
>()[
i
]);
dst_vector_container
.
template
AsType
<
DstData
>()(
i
)
=
dst_v
;
});
});
// copy data from dst_vector_container to dst_buf
// copy data from dst_vector_container to dst_buf
dst_buf
.
template
Set
<
dst_vector_t
>(
dst_buf
.
template
Set
<
dst_vector_t
>(
dst_coord_
.
GetOffset
(),
dst_coord_
.
GetOffset
()
/
PackedSize
,
is_dst_valid
,
is_dst_valid
,
dst_vector_container
.
template
AsType
<
dst_vector_t
>()[
I0
]);
dst_vector_container
.
template
AsType
<
dst_vector_t
>()[
I0
]);
...
@@ -586,7 +616,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -586,7 +616,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// scalar per access on each dim
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
// TODO: don't use lambda_scalar_per_access
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
_
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
...
@@ -644,7 +674,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -644,7 +674,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// scalar per access on each dim
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
// TODO: don't use lambda_scalar_per_access
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
_
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
...
@@ -730,7 +760,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -730,7 +760,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
__device__
static
constexpr
auto
GetSrcThreadScratchDescriptor
()
__device__
static
constexpr
auto
GetSrcThreadScratchDescriptor
()
{
{
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
_
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
...
@@ -779,7 +809,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -779,7 +809,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
__device__
static
constexpr
auto
GetSrcOOBThreadScratchDescriptor
()
__device__
static
constexpr
auto
GetSrcOOBThreadScratchDescriptor
()
{
{
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
constexpr
auto
src_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
SrcVectorDim
,
SrcScalarPerVector
_
>
{},
Number
<
nDim
>
{});
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
constexpr
auto
src_access_lengths
=
SliceLengths
{}
/
src_scalar_per_access
;
...
@@ -790,7 +820,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
...
@@ -790,7 +820,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{
{
// 1st stage of transforms
// 1st stage of transforms
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
constexpr
auto
dst_scalar_per_access
=
generate_sequence
(
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
>
{},
Number
<
nDim
>
{});
detail
::
lambda_scalar_per_access
<
DstVectorDim
,
DstScalarPerVector
_
>
{},
Number
<
nDim
>
{});
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
constexpr
auto
dst_access_lengths
=
SliceLengths
{}
/
dst_scalar_per_access
;
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
a5137505
...
@@ -307,7 +307,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
...
@@ -307,7 +307,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
// Wave mode dependent property
// Wave mode dependent property
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
// * Fixed
in Navi3x
, Will be wave mode dependent on
Navi4x
// * Fixed
for gfx11
, Will be wave mode dependent on
gfx12
// static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
// static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// * num_acc_vgprs_per_wave along M direction
// * num_acc_vgprs_per_wave along M direction
...
...
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
View file @
a5137505
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -13,245 +13,614 @@
...
@@ -13,245 +13,614 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
{
template
<
template
<
index_t
NDimSpatial
,
index_t
NDimSpatial
,
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
ConvBwdDataSpecialization
,
index_t
AK1
,
index_t
BK1
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
bool
DoPadGemmM
,
bool
DoPadGemmN
,
typename
ALayout
,
typename
ALayout
,
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
ConvBwdDataSpecialization
>
typename
BLayout
,
constexpr
auto
make_out_grid_desc
(
const
index_t
N
,
typename
CLayout
,
const
index_t
Do
,
bool
SplitN
=
false
,
const
index_t
Ho
,
typename
ADataType
=
float
,
const
index_t
Wo
,
typename
CDataType
=
float
,
const
index_t
K
,
index_t
NumGroupsToMerge
=
1
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_strides
)
typename
IndexType
=
index_t
>
struct
TransformConvBwdDataToGemm_v1
{
{
const
auto
KStride
=
Number
<
1
>
{};
private:
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGK
>
)
static
constexpr
auto
NonSpatialDimsNum
=
Number
<
3
>
{};
{
const
index_t
NStride
=
out_g_n_k_wos_strides
[
1
];
const
index_t
HiStride
=
out_g_n_k_wos_strides
[
3
];
const
index_t
WiStride
=
out_g_n_k_wos_strides
[
4
];
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Ho
*
Wo
,
K
),
static
constexpr
auto
DIdx
=
NonSpatialDimsNum
;
make_tuple
(
WiStride
,
KStride
));
static
constexpr
auto
HIdx
=
}
NDimSpatial
==
2
?
NonSpatialDimsNum
:
Number
<
NonSpatialDimsNum
+
1
>
{};
else
static
constexpr
auto
WIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
+
1
>
{}
:
Number
<
NonSpatialDimsNum
+
2
>
{};
static
constexpr
auto
ZIdx
=
NonSpatialDimsNum
;
static
constexpr
auto
YIdx
=
NDimSpatial
==
2
?
NonSpatialDimsNum
:
Number
<
NonSpatialDimsNum
+
1
>
{};
static
constexpr
auto
XIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
+
1
>
{}
:
Number
<
NonSpatialDimsNum
+
2
>
{};
template
<
typename
ConvDimsType
>
static
long_index_t
calculate_element_space_size_impl
(
const
ConvDimsType
&
lengths
,
const
ConvDimsType
&
strides
,
index_t
i
)
{
long_index_t
acc
=
1
;
for
(;
i
<
(
NDimSpatial
+
3
);
i
++
)
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Ho
,
Wo
,
K
),
acc
+=
make_tuple
(
NStride
,
HiStride
,
WiStride
,
KS
tride
)
);
static_cast
<
long_index_t
>
(
lengths
[
i
]
-
I1
)
*
static_cast
<
long_index_t
>
(
s
tride
s
[
i
]
);
}
}
return
acc
;
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NDHWGK
>
)
template
<
typename
ConvDimsType
>
static
IndexType
GetSplitedNSize
(
const
ConvDimsType
&
a_g_n_k_wos_lengths
,
const
ConvDimsType
&
a_g_n_k_wos_strides
,
const
ConvDimsType
&
c_g_n_c_wis_lengths
,
const
ConvDimsType
&
c_g_n_c_wis_strides
)
{
{
const
index_t
NStride
=
out_g_n_k_wos_strides
[
1
];
const
long_index_t
a_element_space_size
=
const
index_t
DoStride
=
out_g_n_k_wos_strides
[
3
];
calculate_element_space_size_impl
(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
I1
);
const
index_t
HoStride
=
out_g_n_k_wos_strides
[
4
];
const
long_index_t
c_element_space_size
=
const
index_t
WoStride
=
out_g_n_k_wos_strides
[
5
];
calculate_element_space_size_impl
(
c_g_n_c_wis_lengths
,
c_g_n_c_wis_strides
,
I1
);
if
constexpr
(
ConvBwdDataSpecialization
==
const
long_index_t
element_space_size
=
math
::
max
(
a_element_space_size
*
sizeof
(
ADataType
),
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
c_element_space_size
*
sizeof
(
CDataType
));
Filter1x1Stride1Pad0
)
constexpr
long_index_t
TwoGB
=
(
long_index_t
{
1
}
<<
31
);
const
IndexType
N
=
a_g_n_k_wos_lengths
[
I1
];
if
(
element_space_size
>
TwoGB
)
{
{
// Minimum divisor of N to not exceed 2GB
const
auto
divisor
=
math
::
integer_divide_ceil
(
element_space_size
,
TwoGB
);
return
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
),
if
(
divisor
<=
static_cast
<
double
>
(
N
))
make_tuple
(
WoStride
,
KStride
));
{
// Find least divisor of N larger than element_space_size / TwoGB
// Iterate up to sqrt(N). There are no divisors above this value.
for
(
IndexType
least_divisor
=
divisor
;
least_divisor
*
least_divisor
<=
N
;
least_divisor
++
)
{
if
(
N
%
least_divisor
==
0
)
{
return
N
/
least_divisor
;
}
}
// Not found, process one Convolution N per block
return
1
;
}
else
{
// Not possible to support even after split N.
// Too large tensor.
return
N
;
}
}
}
else
else
{
{
return
make_naive_tensor_descriptor
(
// Split N is not needed.
make_tuple
(
N
,
Do
,
Ho
,
Wo
,
K
),
return
N
;
make_tuple
(
NStride
,
DoStride
,
HoStride
,
WoStride
,
KStride
));
}
}
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>
)
public:
__host__
__device__
constexpr
TransformConvBwdDataToGemm_v1
()
{}
template
<
typename
TransformConvBwdDataToGemm_v1Base
>
__host__
__device__
TransformConvBwdDataToGemm_v1
(
const
TransformConvBwdDataToGemm_v1Base
&
transform_conv_bwd_data_to_gemm_base
)
:
N_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
N_
)},
Di_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
Di_
)},
Hi_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
Hi_
)},
Wi_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
Wi_
)},
Do_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
Do_
)},
Ho_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
Ho_
)},
Wo_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
Wo_
)},
Z_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
Z_
)},
Y_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
Y_
)},
X_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
X_
)},
K_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
K_
)},
C_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
C_
)},
DiStride_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
DiStride_
)},
HiStride_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
HiStride_
)},
WiStride_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
WiStride_
)},
DoStride_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
DoStride_
)},
HoStride_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
HoStride_
)},
WoStride_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
WoStride_
)},
CStrideTensorB_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
CStrideTensorB_
)},
CStrideTensorC_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
CStrideTensorC_
)},
KStrideTensorA_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
KStrideTensorA_
)},
KStrideTensorB_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
KStrideTensorB_
)},
NStrideTensorA_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
NStrideTensorA_
)},
NStrideTensorC_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
NStrideTensorC_
)},
ConvStrideD_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
ConvStrideD_
)},
ConvStrideH_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
ConvStrideH_
)},
ConvStrideW_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
ConvStrideW_
)},
ConvDilationD_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
ConvDilationD_
)},
ConvDilationH_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
ConvDilationH_
)},
ConvDilationW_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
ConvDilationW_
)},
InLeftPadD_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
InLeftPadD_
)},
InLeftPadH_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
InLeftPadH_
)},
InLeftPadW_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
InLeftPadW_
)},
InRightPadD_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
InRightPadD_
)},
InRightPadH_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
InRightPadH_
)},
InRightPadW_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
InRightPadW_
)},
IdxZTilde_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
IdxZTilde_
)},
IdxYTilde_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
IdxYTilde_
)},
IdxXTilde_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
IdxXTilde_
)},
GcdStrideDilationD_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
GcdStrideDilationD_
)},
GcdStrideDilationH_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
GcdStrideDilationH_
)},
GcdStrideDilationW_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
GcdStrideDilationW_
)},
ZTilde_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
ZTilde_
)},
YTilde_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
YTilde_
)},
XTilde_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
XTilde_
)},
DTilde_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
DTilde_
)},
HTilde_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
HTilde_
)},
WTilde_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
WTilde_
)},
ZDot_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
ZDot_
)},
YDot_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
YDot_
)},
XDot_
{
static_cast
<
IndexType
>
(
transform_conv_bwd_data_to_gemm_base
.
XDot_
)}
{
{
// assume packed
}
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
template
<
typename
ConvDimsType
,
typename
ConvSpatialDimsType
>
Filter1x1Stride1Pad0
)
__host__
__device__
TransformConvBwdDataToGemm_v1
(
const
ConvDimsType
&
a_g_n_k_wos_lengths
,
const
ConvDimsType
&
a_g_n_k_wos_strides
,
const
ConvDimsType
&
b_g_k_c_xs_lengths
,
const
ConvDimsType
&
b_g_k_c_xs_strides
,
const
ConvDimsType
&
c_g_n_c_wis_lengths
,
const
ConvDimsType
&
c_g_n_c_wis_strides
,
const
ConvSpatialDimsType
&
conv_filter_strides
,
const
ConvSpatialDimsType
&
conv_filter_dilations
,
const
ConvSpatialDimsType
&
input_left_pads
,
const
ConvSpatialDimsType
&
input_right_pads
,
const
ConvSpatialDimsType
&
tildes
)
:
Hi_
{
c_g_n_c_wis_lengths
[
HIdx
]},
Wi_
{
c_g_n_c_wis_lengths
[
WIdx
]},
Ho_
{
a_g_n_k_wos_lengths
[
HIdx
]},
Wo_
{
a_g_n_k_wos_lengths
[
WIdx
]},
Y_
{
b_g_k_c_xs_lengths
[
YIdx
]},
X_
{
b_g_k_c_xs_lengths
[
XIdx
]},
K_
{
a_g_n_k_wos_lengths
[
I2
]},
C_
{
b_g_k_c_xs_lengths
[
I2
]},
HiStride_
{
c_g_n_c_wis_strides
[
HIdx
]},
WiStride_
{
c_g_n_c_wis_strides
[
WIdx
]},
HoStride_
{
a_g_n_k_wos_strides
[
HIdx
]},
WoStride_
{
a_g_n_k_wos_strides
[
WIdx
]},
CStrideTensorB_
{
b_g_k_c_xs_strides
[
I2
]},
CStrideTensorC_
{
c_g_n_c_wis_strides
[
I2
]},
KStrideTensorA_
{
a_g_n_k_wos_strides
[
I2
]},
KStrideTensorB_
{
b_g_k_c_xs_strides
[
I1
]},
NStrideTensorA_
{
a_g_n_k_wos_strides
[
I1
]},
NStrideTensorC_
{
c_g_n_c_wis_strides
[
I1
]},
ConvStrideH_
{
conv_filter_strides
[
HIdx
-
NonSpatialDimsNum
]},
ConvStrideW_
{
conv_filter_strides
[
WIdx
-
NonSpatialDimsNum
]},
ConvDilationH_
{
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
]},
ConvDilationW_
{
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
]},
InLeftPadH_
{
input_left_pads
[
HIdx
-
NonSpatialDimsNum
]},
InLeftPadW_
{
input_left_pads
[
WIdx
-
NonSpatialDimsNum
]},
InRightPadH_
{
input_right_pads
[
HIdx
-
NonSpatialDimsNum
]},
InRightPadW_
{
input_right_pads
[
WIdx
-
NonSpatialDimsNum
]},
IdxYTilde_
{
tildes
[
YIdx
-
NonSpatialDimsNum
]},
IdxXTilde_
{
tildes
[
XIdx
-
NonSpatialDimsNum
]}
{
static_assert
(
is_same_v
<
ConvSpatialDimsType
,
std
::
array
<
IndexType
,
NDimSpatial
>>
||
is_same_v
<
ConvSpatialDimsType
,
ck
::
Array
<
IndexType
,
NDimSpatial
>>
);
static_assert
(
is_same_v
<
ConvDimsType
,
std
::
array
<
IndexType
,
NDimSpatial
+
I3
>>
||
is_same_v
<
ConvDimsType
,
ck
::
Array
<
IndexType
,
NDimSpatial
+
I3
>>
);
if
constexpr
(
SplitN
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Ho
*
Wo
,
K
));
N_
=
GetSplitedNSize
(
a_g_n_k_wos_lengths
,
a_g_n_k_wos_strides
,
c_g_n_c_wis_lengths
,
c_g_n_c_wis_strides
);
}
}
else
else
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Ho
,
Wo
,
K
))
;
N_
=
c_g_n_c_wis_lengths
[
I1
]
;
}
}
}
if
constexpr
(
NDimSpatial
==
3
)
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNDHWK
>
)
{
// assume packed
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
K
));
Di_
=
c_g_n_c_wis_lengths
[
DIdx
];
Do_
=
a_g_n_k_wos_lengths
[
DIdx
];
Z_
=
b_g_k_c_xs_lengths
[
ZIdx
];
DiStride_
=
c_g_n_c_wis_strides
[
DIdx
];
DoStride_
=
a_g_n_k_wos_strides
[
DIdx
];
ConvStrideD_
=
conv_filter_strides
[
DIdx
-
NonSpatialDimsNum
];
ConvDilationD_
=
conv_filter_dilations
[
DIdx
-
NonSpatialDimsNum
];
InLeftPadD_
=
input_left_pads
[
DIdx
-
NonSpatialDimsNum
];
InRightPadD_
=
input_right_pads
[
DIdx
-
NonSpatialDimsNum
];
IdxZTilde_
=
tildes
[
ZIdx
-
NonSpatialDimsNum
];
GcdStrideDilationD_
=
math
::
gcd
(
ConvStrideD_
,
ConvDilationD_
);
ZTilde_
=
ConvStrideD_
/
GcdStrideDilationD_
;
DTilde_
=
Do_
+
math
::
integer_divide_ceil
(
ConvDilationD_
*
(
Z_
-
I1
),
ConvStrideD_
);
ZDot_
=
math
::
integer_divide_ceil
(
Z_
,
ZTilde_
);
}
}
else
else
{
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
,
K
));
Di_
=
Do_
=
Z_
=
ZTilde_
=
ConvStrideD_
=
DTilde_
=
ZDot_
=
1
;
InLeftPadD_
=
InRightPadD_
=
DiStride_
=
DoStride_
=
IdxZTilde_
=
0
;
}
}
}
else
{
throw
std
::
runtime_error
(
"wrong! unsupported layout: "
+
ALayout
::
name
());
}
}
template
<
typename
BLayout
>
GcdStrideDilationH_
=
math
::
gcd
(
ConvStrideH_
,
ConvDilationH_
);
constexpr
auto
make_wei_grid_desc
(
GcdStrideDilationW_
=
math
::
gcd
(
ConvStrideW_
,
ConvDilationW_
);
const
index_t
K
,
const
index_t
Z
,
const
index_t
Y
,
const
index_t
X
,
const
index_t
C
)
{
if
constexpr
(
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKYXC
>
)
YTilde_
=
ConvStrideH_
/
GcdStrideDilationH_
;
{
XTilde_
=
ConvStrideW_
/
GcdStrideDilationW_
;
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Y
,
X
,
C
));
}
else
if
constexpr
(
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
Z
,
Y
,
X
,
C
));
}
else
{
throw
std
::
runtime_error
(
"wrong! unsupported layout: "
+
BLayout
::
name
());
}
}
template
<
index_t
NDimSpatial
,
typename
CLayout
>
constexpr
auto
make_in_grid_desc
(
const
index_t
N
,
const
index_t
Di
,
const
index_t
Hi
,
const
index_t
Wi
,
const
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_strides
)
{
if
constexpr
(
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
HTilde_
=
Ho_
+
math
::
integer_divide_ceil
(
ConvDilationH_
*
(
Y_
-
I1
),
ConvStrideH_
);
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
WTilde_
=
Wo_
+
math
::
integer_divide_ceil
(
ConvDilationW_
*
(
X_
-
I1
),
ConvStrideW_
);
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_NHW_C
>
)
{
YDot_
=
math
::
integer_divide_ceil
(
Y_
,
YTilde_
);
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
XDot_
=
math
::
integer_divide_ceil
(
X_
,
XTilde_
);
make_tuple
(
in_g_n_c_wis_strides
[
1
],
in_g_n_c_wis_strides
[
3
],
in_g_n_c_wis_strides
[
4
],
in_g_n_c_wis_strides
[
2
]));
}
}
else
if
constexpr
(
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
)
#if 0 // At now not supported to split tensor
__host__ bool AreDescriptorsSmallerThan2GB() const
{
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
),
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
make_tuple
(
in_g_n_c_wis_strides
[
1
],
in_g_n_c_wis_strides
[
3
],
const long_index_t in_desc_space_size =
in_g_n_c_wis_strides
[
4
],
I1 + (N_ - I1) * NStrideTensorC_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ +
in_g_n_c_wis_strides
[
5
],
(Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorC_;
in_g_n_c_wis_strides
[
2
]));
const long_index_t out_desc_space_size =
I1 + (N_ - I1) * NStrideTensorA_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ +
(Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorA_;
bool is_a_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(ADataType)) <= TwoGB;
bool is_c_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(CDataType)) <= TwoGB;
return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB;
}
}
else
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
CDataType* c_grid_ptr_base) const
{
{
throw
std
::
runtime_error
(
"wrong! unsupported layout: "
+
CLayout
::
name
());
// Create copies
}
auto conv_to_gemm_transformer_left = *this;
}
auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate real filter size
const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1;
const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1;
const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1;
// Calculate start position in input for right tensor
const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_;
const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_;
const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_;
// Calculate last position in input for left tensor
const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff;
const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff;
const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff;
// Allow to split if whole left padding will be in left tensor and right padding in right
// tensor
const bool is_possible_to_split_d = Do_ != 1 &&
di_right_transformer_start_idx > InLeftPadD_ &&
di_left_transformer_end_idx <= (InLeftPadD_ + Di_);
const bool is_possible_to_split_h = Ho_ != 1 &&
hi_right_transformer_start_idx > InLeftPadH_ &&
hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_);
const bool is_possible_to_split_w = Wo_ != 1 &&
wi_right_transformer_start_idx > InLeftPadW_ &&
wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_);
if(is_possible_to_split_d)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Do_ = Do_ / 2;
conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_;
conv_to_gemm_transformer_right.Di_ =
math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_),
(conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff);
;
// Calcualte offsets
a_right_offset = (Do_ / 2) * DoStride_;
c_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_;
}
else if(is_possible_to_split_h)
{
conv_to_gemm_transformer_left.Ho_ = Ho_ / 2;
conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2;
}
// namespace
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
template
<
conv_to_gemm_transformer_left.InRightPadH_ = 0;
index_t
NDimSpatial
,
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
ConvBwdDataSpecialization
,
index_t
AK1
,
index_t
BK1
,
index_t
GemmMPerBlock
,
index_t
GemmNPerBlock
,
index_t
GemmKPerBlock
,
bool
DoPadGemmM
,
bool
DoPadGemmN
>
struct
TransformConvBwdDataToGemm_v1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
NonSpatialDimsNum
=
Number
<
3
>
{};
conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_;
conv_to_gemm_transformer_right.Hi_ =
math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_),
(conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff);
a_right_offset = (Ho_ / 2) * HoStride_;
c_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_;
}
else if(is_possible_to_split_w)
{
conv_to_gemm_transformer_left.Wo_ = Wo_ / 2;
conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2;
static
constexpr
auto
DIdx
=
Number
<
NonSpatialDimsNum
>
{};
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
static
constexpr
auto
HIdx
=
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
>
{}
:
Number
<
NonSpatialDimsNum
+
1
>
{};
static
constexpr
auto
WIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
+
1
>
{}
:
Number
<
NonSpatialDimsNum
+
2
>
{};
static
constexpr
auto
ZIdx
=
Number
<
NonSpatialDimsNum
>
{};
conv_to_gemm_transformer_left.InRightPadW_ = 0;
static
constexpr
auto
YIdx
=
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
>
{}
:
Number
<
NonSpatialDimsNum
+
1
>
{};
static
constexpr
auto
XIdx
=
NDimSpatial
==
2
?
Number
<
NonSpatialDimsNum
+
1
>
{}
:
Number
<
NonSpatialDimsNum
+
2
>
{};
template
<
typename
ALayout
,
conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_;
typename
std
::
enable_if
<
(
NDimSpatial
==
2
||
NDimSpatial
==
3
)
&&
conv_to_gemm_transformer_right.Wi_ =
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>
||
math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_),
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNDHWK
>
||
(conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff);
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGK
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NDHWGK
>
),
a_right_offset = (Wo_ / 2) * WoStride_;
bool
>::
type
=
false
>
c_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_;
static
auto
MakeADescriptor_AK0_M_AK1
(
}
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_lengths
,
// Return left transform, right transformer, right offset to Input and right offset to
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_strides
,
// Output
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
wei_g_k_c_xs_lengths
,
return ck::make_tuple(conv_to_gemm_transformer_left,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* wei_g_k_c_xs_strides */
,
conv_to_gemm_transformer_right,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_lengths
,
a_grid_ptr_base + a_right_offset,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* in_g_n_c_wis_strides */
,
c_grid_ptr_base + c_right_offset);
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
}
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
__host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base,
const
std
::
array
<
index_t
,
NDimSpatial
>&
/* input_right_pads */
,
CDataType* c_grid_ptr_base) const
const
std
::
array
<
index_t
,
NDimSpatial
>&
tildes
)
{
{
index_t
i_ztilde
=
tildes
[
ZIdx
-
NonSpatialDimsNum
];
// Create copies
index_t
i_ytilde
=
tildes
[
YIdx
-
NonSpatialDimsNum
];
auto conv_to_gemm_transformer_left = *this;
index_t
i_xtilde
=
tildes
[
XIdx
-
NonSpatialDimsNum
];
auto conv_to_gemm_transformer_right = *this;
IndexType a_right_offset = 0;
IndexType c_right_offset = 0;
// Calculate start position in input for right tensor
const IndexType do_right_transformer_start_idx = math::integer_divide_ceil((Di_ / 2) + InLeftPadD_ - ((Z_ - 1) * ConvDilationD_), ConvStrideD_);
const IndexType ho_right_transformer_start_idx = math::integer_divide_ceil((Hi_ / 2) + InLeftPadH_ - ((Y_ - 1) * ConvDilationH_), ConvStrideH_);
const IndexType wo_right_transformer_start_idx = math::integer_divide_ceil((Wi_ / 2) + InLeftPadW_ - ((X_ - 1) * ConvDilationW_), ConvStrideW_);
// Calculate last position in input for left tensor
const IndexType do_left_transformer_end_idx = math::integer_divide_ceil((Di_ / 2 - 1) + InLeftPadD_, ConvStrideD_);
const IndexType ho_left_transformer_end_idx = math::integer_divide_ceil((Hi_ / 2 - 1) + InLeftPadH_, ConvStrideH_);
const IndexType wo_left_transformer_end_idx = math::integer_divide_ceil((Wi_ / 2 - 1) + InLeftPadW_, ConvStrideW_);
if(Di_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Di_ = Di_ / 2;
conv_to_gemm_transformer_right.Di_ = Di_ - Di_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_;
conv_to_gemm_transformer_right.InLeftPadD_ = 0;
// // Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadD_ = 0;
conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_;
// Calculate new input size
conv_to_gemm_transformer_left.Do_ = do_left_transformer_end_idx;
conv_to_gemm_transformer_right.Do_ = Do_ - do_right_transformer_start_idx;
;
// Calcualte offsets
a_right_offset = do_right_transformer_start_idx * DoStride_;
c_right_offset = (Di_ / 2) * DiStride_;
}
else if(Hi_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Hi_ = Hi_ / 2;
conv_to_gemm_transformer_right.Hi_ = Hi_ - Hi_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_;
conv_to_gemm_transformer_right.InLeftPadH_ = 0;
// // Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadH_ = 0;
conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_;
// Calculate new input size
conv_to_gemm_transformer_left.Ho_ = ho_left_transformer_end_idx ;
conv_to_gemm_transformer_right.Ho_ = Ho_ - ho_right_transformer_start_idx ;
;
// Calcualte offsets
a_right_offset = ho_right_transformer_start_idx * HoStride_;
c_right_offset = (Hi_ / 2) * HiStride_;
}
else if(Wi_!=1)
{
// Apply new sizes
// Split output on half
conv_to_gemm_transformer_left.Wi_ = Wi_ / 2;
conv_to_gemm_transformer_right.Wi_ = Wi_ - Wi_ / 2;
// Assign left padding to left convolution
conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_;
conv_to_gemm_transformer_right.InLeftPadW_ = 0;
// Assign right padding to right convolution
conv_to_gemm_transformer_left.InRightPadW_ = 0;
conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_;
// Calculate new input size
conv_to_gemm_transformer_left.Wo_ = wo_left_transformer_end_idx;
conv_to_gemm_transformer_right.Wo_ = Wo_ - wo_right_transformer_start_idx;
;
// Calcualte offsets
a_right_offset = wo_right_transformer_start_idx * WoStride_;
c_right_offset = (Wi_ / 2) * WiStride_;
}
// Return left transform, right transformer, right offset to Input and right offset to
// Output
return ck::make_tuple(conv_to_gemm_transformer_left,
conv_to_gemm_transformer_right,
a_grid_ptr_base + a_right_offset,
c_grid_ptr_base + c_right_offset);
}
#endif
const
index_t
N
=
in_g_n_c_wis_lengths
[
1
];
__host__
__device__
auto
MakeOutGridDesc
()
const
const
index_t
K
=
wei_g_k_c_xs_lengths
[
1
];
{
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGK
>
)
{
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
const
index_t
Di
=
NDimSpatial
==
3
?
in_g_n_c_wis_lengths
[
DIdx
]
:
1
;
return
make_naive_tensor_descriptor
(
make_tuple
(
N_
*
Ho_
*
Wo_
,
K_
),
const
index_t
Hi
=
in_g_n_c_wis_lengths
[
HIdx
];
make_tuple
(
WoStride_
,
KStrideTensorA_
));
const
index_t
Wi
=
in_g_n_c_wis_lengths
[
WIdx
];
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N_
,
Ho_
,
Wo_
,
K_
),
make_tuple
(
NStrideTensorA_
,
HoStride_
,
WoStride_
,
KStrideTensorA_
));
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NDHWGK
>
)
{
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
const
index_t
Do
=
NDimSpatial
==
3
?
out_g_n_k_wos_lengths
[
DIdx
]
:
1
;
return
make_naive_tensor_descriptor
(
make_tuple
(
N_
*
Do_
*
Ho_
*
Wo_
,
K_
),
const
index_t
Ho
=
out_g_n_k_wos_lengths
[
HIdx
];
make_tuple
(
WoStride_
,
KStrideTensorA_
));
const
index_t
Wo
=
out_g_n_k_wos_lengths
[
WIdx
];
}
else
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N_
,
Do_
,
Ho_
,
Wo_
,
K_
),
make_tuple
(
NStrideTensorA_
,
DoStride_
,
HoStride_
,
WoStride_
,
KStrideTensorA_
));
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWK
>
)
{
// assume packed
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
N_
*
Ho_
*
Wo_
,
K_
));
}
else
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
N_
,
Ho_
,
Wo_
,
K_
));
}
}
else
if
constexpr
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNDHWK
>
)
{
// assume packed
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
N_
*
Do_
*
Ho_
*
Wo_
,
K_
));
}
else
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
N_
,
Do_
,
Ho_
,
Wo_
,
K_
));
}
}
else
{
throw
std
::
runtime_error
(
"wrong! unsupported layout: "
+
ALayout
::
name
());
}
}
const
index_t
Z
=
NDimSpatial
==
3
?
wei_g_k_c_xs_lengths
[
ZIdx
]
:
1
;
__host__
__device__
auto
MakeWeiGridDesc
()
const
const
index_t
Y
=
wei_g_k_c_xs_lengths
[
YIdx
];
{
const
index_t
X
=
wei_g_k_c_xs_lengths
[
XIdx
];
const
index_t
InLeftPadD
=
input_left_pads
[
DIdx
-
NonSpatialDimsNum
];
if
constexpr
(
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKYXC
>
)
const
index_t
InLeftPadH
=
input_left_pads
[
HIdx
-
NonSpatialDimsNum
];
{
const
index_t
InLeftPadW
=
input_left_pads
[
WIdx
-
NonSpatialDimsNum
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
K_
,
Y_
,
X_
,
C_
));
}
else
if
constexpr
(
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
K_
,
Z_
,
Y_
,
X_
,
C_
));
}
else
{
throw
std
::
runtime_error
(
"wrong! unsupported layout: "
+
BLayout
::
name
());
}
}
const
index_t
ConvStrideD
=
conv_filter_strides
[
DIdx
-
NonSpatialDimsNum
];
__host__
__device__
auto
MakeInGridDesc
()
const
const
index_t
ConvStrideH
=
conv_filter_strides
[
HIdx
-
NonSpatialDimsNum
];
{
const
index_t
ConvStrideW
=
conv_filter_strides
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
DIdx
-
NonSpatialDimsNum
];
if
constexpr
(
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
const
index_t
ConvDilationH
=
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
];
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
const
index_t
ConvDilationW
=
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
];
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_NHW_C
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N_
,
Hi_
,
Wi_
,
C_
),
make_tuple
(
NStrideTensorC_
,
HiStride_
,
WiStride_
,
CStrideTensorC_
));
}
else
if
constexpr
(
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N_
,
Di_
,
Hi_
,
Wi_
,
C_
),
make_tuple
(
NStrideTensorC_
,
DiStride_
,
HiStride_
,
WiStride_
,
CStrideTensorC_
));
}
else
{
throw
std
::
runtime_error
(
"wrong! unsupported layout: "
+
CLayout
::
name
());
}
}
template
<
typename
ALayout_
=
ALayout
,
typename
std
::
enable_if
<
(
NDimSpatial
==
2
||
NDimSpatial
==
3
)
&&
(
is_same_v
<
ALayout_
,
tensor_layout
::
convolution
::
GNHWK
>
||
is_same_v
<
ALayout_
,
tensor_layout
::
convolution
::
GNDHWK
>
||
is_same_v
<
ALayout_
,
tensor_layout
::
convolution
::
NHWGK
>
||
is_same_v
<
ALayout_
,
tensor_layout
::
convolution
::
NDHWGK
>
),
bool
>::
type
=
false
>
__host__
__device__
auto
MakeADescriptor_AK0_M_AK1
()
const
{
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
const
auto
out_grid_desc
=
const
auto
out_grid_desc
=
MakeOutGridDesc
();
make_out_grid_desc
<
NDimSpatial
,
ALayout
,
ConvBwdDataSpecialization
>
(
N
,
Do
,
Ho
,
Wo
,
K
,
out_g_n_k_wos_strides
);
if
constexpr
(
ConvBwdDataSpecialization
==
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
Filter1x1Stride1Pad0
)
{
{
const
index_t
AK0
=
math
::
integer_divide_ceil
(
K
,
AK1
);
const
index_t
AK0
=
math
::
integer_divide_ceil
(
K
_
,
AK1
);
// A: output tensor
// A: output tensor
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmak0_gemmmraw_gemmak1_grid_desc
=
transform_tensor_descriptor
(
out_grid_desc
,
out_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
*
Do
*
Ho
*
Wo
),
make_tuple
(
make_pass_through_transform
(
N
_
*
Do
_
*
Ho
_
*
Wo
_
),
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
))),
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}));
...
@@ -266,82 +635,63 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -266,82 +635,63 @@ struct TransformConvBwdDataToGemm_v1
}
}
else
else
{
{
const
auto
GcdStrideDilationD
=
math
::
gcd
(
ConvStrideD
,
ConvDilationD
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
ZTilde
=
ConvStrideD
/
GcdStrideDilationD
;
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
ZDot
=
math
::
integer_divide_ceil
(
Z
,
ZTilde
);
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilde
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilde
);
const
auto
DTilde
=
Do
+
math
::
integer_divide_ceil
(
ConvDilationD
*
(
Z
-
I1
),
ConvStrideD
);
const
auto
HTilde
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTilde
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const
auto
IDTildeSliceBegin
=
math
::
integer_divide_floor
(
const
auto
IDTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadD
-
ConvDilationD
*
(
ZTilde
-
I1
)),
ConvStrideD
);
math
::
max
(
I0
,
InLeftPadD
_
-
ConvDilationD
_
*
(
ZTilde
_
-
I1
)),
ConvStrideD
_
);
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTilde
-
I1
)),
ConvStrideH
);
math
::
max
(
I0
,
InLeftPadH
_
-
ConvDilationH
_
*
(
YTilde
_
-
I1
)),
ConvStrideH
_
);
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilde
-
I1
)),
ConvStrideW
);
math
::
max
(
I0
,
InLeftPadW
_
-
ConvDilationW
_
*
(
XTilde
_
-
I1
)),
ConvStrideW
_
);
const
auto
IDTildeSliceEnd
=
math
::
min
(
const
auto
IDTildeSliceEnd
=
math
::
min
(
DTilde
,
math
::
integer_divide_ceil
(
InLeftPadD
+
Di
-
I1
,
ConvStrideD
)
+
I1
);
DTilde
_
,
math
::
integer_divide_ceil
(
InLeftPadD
_
+
Di
_
-
I1
,
ConvStrideD
_
)
+
I1
);
const
auto
IHTildeSliceEnd
=
math
::
min
(
const
auto
IHTildeSliceEnd
=
math
::
min
(
HTilde
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
HTilde
_
,
math
::
integer_divide_ceil
(
InLeftPadH
_
+
Hi
_
-
I1
,
ConvStrideH
_
)
+
I1
);
const
auto
IWTildeSliceEnd
=
math
::
min
(
const
auto
IWTildeSliceEnd
=
math
::
min
(
WTilde
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
WTilde
_
,
math
::
integer_divide_ceil
(
InLeftPadW
_
+
Wi
_
-
I1
,
ConvStrideW
_
)
+
I1
);
const
auto
DTildeSlice
=
IDTildeSliceEnd
-
IDTildeSliceBegin
;
const
auto
DTildeSlice
=
IDTildeSliceEnd
-
IDTildeSliceBegin
;
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
// GemmK is different for each GEMM
// GemmK is different for each GEMM
const
auto
ZDotSlice
=
math
::
integer_divide_ceil
(
Z
-
i_zt
ilde
,
ZTilde
);
const
auto
ZDotSlice
=
math
::
integer_divide_ceil
(
Z
_
-
IdxZT
ilde
_
,
ZTilde
_
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_yt
ilde
,
YTilde
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
_
-
IdxYT
ilde
_
,
YTilde
_
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xt
ilde
,
XTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
_
-
IdxXT
ilde
_
,
XTilde
_
);
if
constexpr
(
NDimSpatial
==
2
)
if
constexpr
(
NDimSpatial
==
2
)
{
{
// A: output tensor
// A: output tensor
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
out_grid_desc
,
out_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
_
),
make_pad_transform
(
Ho
,
I0
,
I0
),
make_pad_transform
(
Ho
_
,
I0
,
I0
),
make_pad_transform
(
Wo
,
I0
,
I0
),
make_pad_transform
(
Wo
_
,
I0
,
I0
),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
K
_
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
out_n_ydot_htilde_xdot_wtilde_k_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_n_ydot_htilde_xdot_wtilde_k_grid_desc
=
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
_
),
make_embed_transform
(
make_tuple
(
YDot
,
HTilde
),
make_embed_transform
(
make_tuple
(
YDot
_
,
HTilde
_
),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_tuple
(
-
ConvDilationH
_
/
GcdStrideDilationH
_
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
WTilde
),
make_embed_transform
(
make_tuple
(
XDot
_
,
WTilde
_
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_tuple
(
-
ConvDilationW
_
/
GcdStrideDilationW
_
,
I1
)),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
K
_
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc
=
const
auto
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc
,
out_n_ydot_htilde_xdot_wtilde_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
_
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
YDot
_
,
I0
,
YDotSlice
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
HTilde
_
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
_
,
I0
,
XDotSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_slice_transform
(
WTilde
_
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
K
_
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -357,8 +707,8 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -357,8 +707,8 @@ struct TransformConvBwdDataToGemm_v1
const
auto
out_gemmk_gemmmraw_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk_gemmmraw_grid_desc
=
transform_tensor_descriptor
(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc
,
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K
_
)),
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
))),
make_merge_transform
(
make_tuple
(
N
_
,
HTildeSlice
,
WTildeSlice
))),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
0
,
2
,
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -385,11 +735,11 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -385,11 +735,11 @@ struct TransformConvBwdDataToGemm_v1
// A: output tensor
// A: output tensor
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_n_hop_wop_k_grid_desc
=
transform_tensor_descriptor
(
out_grid_desc
,
out_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
_
),
make_pad_transform
(
Do
,
I0
,
I0
),
make_pad_transform
(
Do
_
,
I0
,
I0
),
make_pad_transform
(
Ho
,
I0
,
I0
),
make_pad_transform
(
Ho
_
,
I0
,
I0
),
make_pad_transform
(
Wo
,
I0
,
I0
),
make_pad_transform
(
Wo
_
,
I0
,
I0
),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
K
_
)),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
make_tuple
(
...
@@ -398,17 +748,17 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -398,17 +748,17 @@ struct TransformConvBwdDataToGemm_v1
const
auto
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc
=
const
auto
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
out_n_hop_wop_k_grid_desc
,
out_n_hop_wop_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
_
),
make_embed_transform
(
make_embed_transform
(
make_tuple
(
ZDot
,
DTilde
),
make_tuple
(
ZDot
_
,
DTilde
_
),
make_tuple
(
-
ConvDilationD
/
GcdStrideDilationD
,
I1
)),
make_tuple
(
-
ConvDilationD
_
/
GcdStrideDilationD
_
,
I1
)),
make_embed_transform
(
make_embed_transform
(
make_tuple
(
YDot
,
HTilde
),
make_tuple
(
YDot
_
,
HTilde
_
),
make_tuple
(
-
ConvDilationH
/
GcdStrideDilationH
,
I1
)),
make_tuple
(
-
ConvDilationH
_
/
GcdStrideDilationH
_
,
I1
)),
make_embed_transform
(
make_embed_transform
(
make_tuple
(
XDot
,
WTilde
),
make_tuple
(
XDot
_
,
WTilde
_
),
make_tuple
(
-
ConvDilationW
/
GcdStrideDilationW
,
I1
)),
make_tuple
(
-
ConvDilationW
_
/
GcdStrideDilationW
_
,
I1
)),
make_pass_through_transform
(
K
)),
make_pass_through_transform
(
K
_
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -424,14 +774,15 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -424,14 +774,15 @@ struct TransformConvBwdDataToGemm_v1
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc
=
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc
,
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_slice_transform
(
ZDot
,
I0
,
ZDotSlice
),
make_pass_through_transform
(
N_
),
make_slice_transform
(
DTilde
,
IDTildeSliceBegin
,
DTildeSlice
),
make_slice_transform
(
ZDot_
,
I0
,
ZDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
DTilde_
,
IDTildeSliceBegin
,
DTildeSlice
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
YDot_
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
HTilde_
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_slice_transform
(
XDot_
,
I0
,
XDotSlice
),
make_pass_through_transform
(
K
)),
make_slice_transform
(
WTilde_
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
K_
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -452,8 +803,9 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -452,8 +803,9 @@ struct TransformConvBwdDataToGemm_v1
const
auto
out_gemmk_gemmmraw_grid_desc
=
transform_tensor_descriptor
(
const
auto
out_gemmk_gemmmraw_grid_desc
=
transform_tensor_descriptor
(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc
,
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc
,
make_tuple
(
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K
)),
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K_
)),
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
))),
make_merge_transform
(
make_tuple
(
N_
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
))),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{}),
make_tuple
(
Sequence
<
1
,
3
,
5
,
7
>
{},
Sequence
<
0
,
2
,
4
,
6
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -482,66 +834,31 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -482,66 +834,31 @@ struct TransformConvBwdDataToGemm_v1
}
}
}
}
template
<
typename
BLayout
,
template
<
typename
BLayout_
=
BLayout
,
typename
std
::
enable_if
<
(
NDimSpatial
==
2
||
NDimSpatial
==
3
)
&&
typename
std
::
enable_if
<
(
NDimSpatial
==
2
||
NDimSpatial
==
3
)
&&
(
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKYXC
>
||
(
is_same_v
<
BLayout
_
,
tensor_layout
::
convolution
::
GKYXC
>
||
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
),
is_same_v
<
BLayout
_
,
tensor_layout
::
convolution
::
GKZYXC
>
),
bool
>::
type
=
false
>
bool
>::
type
=
false
>
static
auto
MakeBDescriptor_BK0_N_BK1
(
__host__
__device__
auto
MakeBDescriptor_BK0_N_BK1
()
const
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* out_g_n_k_wos_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
wei_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* wei_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* in_g_n_c_wis_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
/* input_left_pads */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
/* input_right_pads */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
tildes
)
{
{
index_t
i_ztilde
=
tildes
[
ZIdx
-
NonSpatialDimsNum
];
index_t
i_ytilde
=
tildes
[
YIdx
-
NonSpatialDimsNum
];
index_t
i_xtilde
=
tildes
[
XIdx
-
NonSpatialDimsNum
];
const
index_t
N
=
in_g_n_c_wis_lengths
[
1
];
const
index_t
K
=
wei_g_k_c_xs_lengths
[
1
];
const
index_t
C
=
wei_g_k_c_xs_lengths
[
2
];
const
index_t
Do
=
NDimSpatial
==
3
?
out_g_n_k_wos_lengths
[
DIdx
]
:
1
;
const
index_t
Ho
=
out_g_n_k_wos_lengths
[
HIdx
];
const
index_t
Wo
=
out_g_n_k_wos_lengths
[
WIdx
];
const
index_t
Z
=
NDimSpatial
==
3
?
wei_g_k_c_xs_lengths
[
ZIdx
]
:
1
;
const
index_t
Y
=
wei_g_k_c_xs_lengths
[
YIdx
];
const
index_t
X
=
wei_g_k_c_xs_lengths
[
XIdx
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
DIdx
-
NonSpatialDimsNum
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
DIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
];
// assume packed
// assume packed
// k_y_x_c for 2d or k_z_y_x_c for 3d
// k_y_x_c for 2d or k_z_y_x_c for 3d
const
auto
wei_grid_desc
=
m
ake
_wei_g
rid
_d
esc
<
BLayout
>
(
K
,
Z
,
Y
,
X
,
C
);
const
auto
wei_grid_desc
=
M
ake
WeiG
rid
D
esc
(
);
if
constexpr
(
ConvBwdDataSpecialization
==
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
Filter1x1Stride1Pad0
)
{
{
const
index_t
BK0
=
math
::
integer_divide_ceil
(
K
,
BK1
);
const
index_t
BK0
=
math
::
integer_divide_ceil
(
K
_
,
BK1
);
// B: weight tensor
// B: weight tensor
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
const
auto
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc
=
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
C
)),
transform_tensor_descriptor
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
_
,
C
_
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Do
*
Ho
*
Wo
,
C
),
make_tuple
(
I0
,
I1
));
make_naive_tensor_descriptor
(
make_tuple
(
N
_
*
Do
_
*
Ho
_
*
Wo
_
,
C
_
),
make_tuple
(
I0
,
I1
));
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
const
auto
wei_gemmbk0_gemmn_gemmbk1_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
...
@@ -553,22 +870,10 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -553,22 +870,10 @@ struct TransformConvBwdDataToGemm_v1
}
}
else
else
{
{
const
auto
GcdStrideDilationD
=
math
::
gcd
(
ConvStrideD
,
ConvDilationD
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
ZTilde
=
ConvStrideD
/
GcdStrideDilationD
;
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
ZDot
=
math
::
integer_divide_ceil
(
Z
,
ZTilde
);
const
auto
YDot
=
math
::
integer_divide_ceil
(
Y
,
YTilde
);
const
auto
XDot
=
math
::
integer_divide_ceil
(
X
,
XTilde
);
// GemmK is different for each GEMM
// GemmK is different for each GEMM
const
auto
ZDotSlice
=
math
::
integer_divide_ceil
(
Z
-
i_zt
ilde
,
ZTilde
);
const
auto
ZDotSlice
=
math
::
integer_divide_ceil
(
Z
_
-
IdxZT
ilde
_
,
ZTilde
_
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
-
i_yt
ilde
,
YTilde
);
const
auto
YDotSlice
=
math
::
integer_divide_ceil
(
Y
_
-
IdxYT
ilde
_
,
YTilde
_
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
-
i_xt
ilde
,
XTilde
);
const
auto
XDotSlice
=
math
::
integer_divide_ceil
(
X
_
-
IdxXT
ilde
_
,
XTilde
_
);
// B weight tensor
// B weight tensor
if
constexpr
(
NDimSpatial
==
2
)
if
constexpr
(
NDimSpatial
==
2
)
...
@@ -576,23 +881,23 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -576,23 +881,23 @@ struct TransformConvBwdDataToGemm_v1
const
auto
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
=
transform_tensor_descriptor
(
wei_grid_desc
,
wei_grid_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
K
),
make_pass_through_transform
(
K
_
),
make_embed_transform
(
make_tuple
(
YDot
,
YTilde
),
make_embed_transform
(
make_tuple
(
YDot
_
,
YTilde
_
),
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_tuple
(
ConvStrideH
_
/
GcdStrideDilationH
_
,
I1
)),
make_embed_transform
(
make_tuple
(
XDot
,
XTilde
),
make_embed_transform
(
make_tuple
(
XDot
_
,
XTilde
_
),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_tuple
(
ConvStrideW
_
/
GcdStrideDilationW
_
,
I1
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
wei_k_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_k_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
,
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_tuple
(
make_pass_through_transform
(
K
_
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
YDot
_
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
_
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_yt
ilde
),
make_freeze_transform
(
IdxYT
ilde
_
),
make_freeze_transform
(
i_xt
ilde
),
make_freeze_transform
(
IdxXT
ilde
_
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
3
>
{},
...
@@ -608,8 +913,8 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -608,8 +913,8 @@ struct TransformConvBwdDataToGemm_v1
const
auto
wei_gemmk_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_gemmk_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
wei_k_ydotslice_xdotslice_c_grid_desc
,
wei_k_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
YDotSlice
,
XDotSlice
,
K
_
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
1
,
2
,
0
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
1
,
2
,
0
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -636,15 +941,17 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -636,15 +941,17 @@ struct TransformConvBwdDataToGemm_v1
const
auto
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc
=
const
auto
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
wei_grid_desc
,
wei_grid_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
K_
),
make_pass_through_transform
(
K
),
make_embed_transform
(
make_embed_transform
(
make_tuple
(
ZDot
,
ZTilde
),
make_tuple
(
ZDot_
,
ZTilde_
),
make_tuple
(
ConvStrideD
/
GcdStrideDilationD
,
I1
)),
make_tuple
(
ConvStrideD_
/
GcdStrideDilationD_
,
I1
)),
make_embed_transform
(
make_tuple
(
YDot
,
YTilde
),
make_embed_transform
(
make_tuple
(
ConvStrideH
/
GcdStrideDilationH
,
I1
)),
make_tuple
(
YDot_
,
YTilde_
),
make_embed_transform
(
make_tuple
(
XDot
,
XTilde
),
make_tuple
(
ConvStrideH_
/
GcdStrideDilationH_
,
I1
)),
make_tuple
(
ConvStrideW
/
GcdStrideDilationW
,
I1
)),
make_embed_transform
(
make_pass_through_transform
(
C
)),
make_tuple
(
XDot_
,
XTilde_
),
make_tuple
(
ConvStrideW_
/
GcdStrideDilationW_
,
I1
)),
make_pass_through_transform
(
C_
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -659,14 +966,14 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -659,14 +966,14 @@ struct TransformConvBwdDataToGemm_v1
const
auto
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc
=
const
auto
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc
,
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_tuple
(
make_pass_through_transform
(
K
_
),
make_slice_transform
(
ZDot
,
I0
,
ZDotSlice
),
make_slice_transform
(
ZDot
_
,
I0
,
ZDotSlice
),
make_slice_transform
(
YDot
,
I0
,
YDotSlice
),
make_slice_transform
(
YDot
_
,
I0
,
YDotSlice
),
make_slice_transform
(
XDot
,
I0
,
XDotSlice
),
make_slice_transform
(
XDot
_
,
I0
,
XDotSlice
),
make_freeze_transform
(
i_zt
ilde
),
make_freeze_transform
(
IdxZT
ilde
_
),
make_freeze_transform
(
i_yt
ilde
),
make_freeze_transform
(
IdxYT
ilde
_
),
make_freeze_transform
(
i_xt
ilde
),
make_freeze_transform
(
IdxXT
ilde
_
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
3
>
{},
...
@@ -686,8 +993,9 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -686,8 +993,9 @@ struct TransformConvBwdDataToGemm_v1
const
auto
wei_gemmk_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
const
auto
wei_gemmk_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc
,
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K
)),
make_tuple
(
make_pass_through_transform
(
C
)),
make_merge_transform
(
make_tuple
(
ZDotSlice
,
YDotSlice
,
XDotSlice
,
K_
)),
make_pass_through_transform
(
C_
)),
make_tuple
(
Sequence
<
1
,
2
,
3
,
0
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
1
,
2
,
3
,
0
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -716,66 +1024,20 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -716,66 +1024,20 @@ struct TransformConvBwdDataToGemm_v1
}
}
}
}
template
<
typename
CLayout
,
template
<
typename
std
::
enable_if
<
(
NDimSpatial
==
2
||
NDimSpatial
==
3
)
&&
typename
CLayout_
=
CLayout
,
(
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GNHWC
>
||
typename
std
::
enable_if
<
(
NDimSpatial
==
2
||
NDimSpatial
==
3
)
&&
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
||
(
is_same_v
<
CLayout_
,
tensor_layout
::
convolution
::
GNHWC
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
CLayout_
,
tensor_layout
::
convolution
::
GNDHWC
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
||
is_same_v
<
CLayout_
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_NHW_C
>
),
is_same_v
<
CLayout_
,
tensor_layout
::
convolution
::
NDHWGC
>
||
bool
>::
type
=
false
>
is_same_v
<
CLayout_
,
tensor_layout
::
convolution
::
G_NHW_C
>
),
static
auto
bool
>::
type
=
false
>
MakeCDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_lengths
,
__host__
__device__
auto
MakeCDescriptor_M_N
()
const
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* out_g_n_k_wos_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
wei_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* wei_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
tildes
)
{
{
index_t
i_ztilde
=
tildes
[
ZIdx
-
NonSpatialDimsNum
];
index_t
i_ytilde
=
tildes
[
YIdx
-
NonSpatialDimsNum
];
index_t
i_xtilde
=
tildes
[
XIdx
-
NonSpatialDimsNum
];
const
index_t
N
=
in_g_n_c_wis_lengths
[
1
];
const
index_t
C
=
wei_g_k_c_xs_lengths
[
2
];
const
index_t
Di
=
NDimSpatial
==
3
?
in_g_n_c_wis_lengths
[
DIdx
]
:
1
;
const
index_t
Hi
=
in_g_n_c_wis_lengths
[
HIdx
];
const
index_t
Wi
=
in_g_n_c_wis_lengths
[
WIdx
];
const
index_t
Do
=
NDimSpatial
==
3
?
out_g_n_k_wos_lengths
[
DIdx
]
:
1
;
const
index_t
Ho
=
out_g_n_k_wos_lengths
[
HIdx
];
const
index_t
Wo
=
out_g_n_k_wos_lengths
[
WIdx
];
const
index_t
Z
=
NDimSpatial
==
3
?
wei_g_k_c_xs_lengths
[
ZIdx
]
:
1
;
const
index_t
Y
=
wei_g_k_c_xs_lengths
[
YIdx
];
const
index_t
X
=
wei_g_k_c_xs_lengths
[
XIdx
];
const
index_t
InLeftPadD
=
input_left_pads
[
DIdx
-
NonSpatialDimsNum
];
const
index_t
InLeftPadH
=
input_left_pads
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
InLeftPadW
=
input_left_pads
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
InRightPadD
=
input_right_pads
[
DIdx
-
NonSpatialDimsNum
];
const
index_t
InRightPadH
=
input_right_pads
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
InRightPadW
=
input_right_pads
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
DIdx
-
NonSpatialDimsNum
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
WIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
DIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
HIdx
-
NonSpatialDimsNum
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
WIdx
-
NonSpatialDimsNum
];
// assume strided
// assume strided
// n_hi_wi_c for 2d n_di_hi_wi_c for 3d
// n_hi_wi_c for 2d n_di_hi_wi_c for 3d
const
auto
in_grid_desc
=
const
auto
in_grid_desc
=
MakeInGridDesc
();
make_in_grid_desc
<
NDimSpatial
,
CLayout
>
(
N
,
Di
,
Hi
,
Wi
,
C
,
in_g_n_c_wis_strides
);
if
constexpr
(
ConvBwdDataSpecialization
==
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
...
@@ -787,10 +1049,10 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -787,10 +1049,10 @@ struct TransformConvBwdDataToGemm_v1
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_grid_desc
,
in_grid_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
_
),
make_embed_transform
(
make_tuple
(
I1
,
Ho
),
make_tuple
(
I1
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
I1
,
Ho
_
),
make_tuple
(
I1
,
ConvStrideH
_
)),
make_embed_transform
(
make_tuple
(
I1
,
Wo
),
make_tuple
(
I1
,
ConvStrideW
)),
make_embed_transform
(
make_tuple
(
I1
,
Wo
_
),
make_tuple
(
I1
,
ConvStrideW
_
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
...
@@ -798,8 +1060,8 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -798,8 +1060,8 @@ struct TransformConvBwdDataToGemm_v1
in_n_y_ho_x_wo_c_grid_desc
,
in_n_y_ho_x_wo_c_grid_desc
,
make_tuple
(
make_freeze_transform
(
I0
),
make_tuple
(
make_freeze_transform
(
I0
),
make_freeze_transform
(
I0
),
make_freeze_transform
(
I0
),
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_merge_transform
(
make_tuple
(
N
_
,
Ho
_
,
Wo
_
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<>
{},
Sequence
<>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -818,11 +1080,11 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -818,11 +1080,11 @@ struct TransformConvBwdDataToGemm_v1
const
auto
in_n_x_do_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_x_do_y_ho_x_wo_c_grid_desc
=
transform_tensor_descriptor
(
in_grid_desc
,
in_grid_desc
,
make_tuple
(
make_tuple
(
make_pass_through_transform
(
N
),
make_pass_through_transform
(
N
_
),
make_embed_transform
(
make_tuple
(
I1
,
Do
),
make_tuple
(
I1
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
I1
,
Do
_
),
make_tuple
(
I1
,
ConvStrideD
_
)),
make_embed_transform
(
make_tuple
(
I1
,
Ho
),
make_tuple
(
I1
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
I1
,
Ho
_
),
make_tuple
(
I1
,
ConvStrideH
_
)),
make_embed_transform
(
make_tuple
(
I1
,
Wo
),
make_tuple
(
I1
,
ConvStrideW
)),
make_embed_transform
(
make_tuple
(
I1
,
Wo
_
),
make_tuple
(
I1
,
ConvStrideW
_
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
...
@@ -836,8 +1098,8 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -836,8 +1098,8 @@ struct TransformConvBwdDataToGemm_v1
make_tuple
(
make_freeze_transform
(
I0
),
make_tuple
(
make_freeze_transform
(
I0
),
make_freeze_transform
(
I0
),
make_freeze_transform
(
I0
),
make_freeze_transform
(
I0
),
make_freeze_transform
(
I0
),
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
make_merge_transform
(
make_tuple
(
N
_
,
Do
_
,
Ho
_
,
Wo
_
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
1
>
{},
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
3
>
{},
Sequence
<
3
>
{},
Sequence
<
5
>
{},
Sequence
<
5
>
{},
...
@@ -861,36 +1123,21 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -861,36 +1123,21 @@ struct TransformConvBwdDataToGemm_v1
}
}
else
else
{
{
const
auto
GcdStrideDilationD
=
math
::
gcd
(
ConvStrideD
,
ConvDilationD
);
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
ZTilde
=
ConvStrideD
/
GcdStrideDilationD
;
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
DTilde
=
Do
+
math
::
integer_divide_ceil
(
ConvDilationD
*
(
Z
-
I1
),
ConvStrideD
);
const
auto
HTilde
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTilde
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on DTilde, HTilde and WTilde that contribute to
// only work on DTilde, HTilde and WTilde that contribute to
// non-padding area of input tensor
// non-padding area of input tensor
const
auto
IDTildeSliceBegin
=
math
::
integer_divide_floor
(
const
auto
IDTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadD
-
ConvDilationD
*
(
ZTilde
-
I1
)),
ConvStrideD
);
math
::
max
(
I0
,
InLeftPadD
_
-
ConvDilationD
_
*
(
ZTilde
_
-
I1
)),
ConvStrideD
_
);
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTilde
-
I1
)),
ConvStrideH
);
math
::
max
(
I0
,
InLeftPadH
_
-
ConvDilationH
_
*
(
YTilde
_
-
I1
)),
ConvStrideH
_
);
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilde
-
I1
)),
ConvStrideW
);
math
::
max
(
I0
,
InLeftPadW
_
-
ConvDilationW
_
*
(
XTilde
_
-
I1
)),
ConvStrideW
_
);
const
auto
IDTildeSliceEnd
=
math
::
min
(
const
auto
IDTildeSliceEnd
=
math
::
min
(
DTilde
,
math
::
integer_divide_ceil
(
InLeftPadD
+
Di
-
I1
,
ConvStrideD
)
+
I1
);
DTilde
_
,
math
::
integer_divide_ceil
(
InLeftPadD
_
+
Di
_
-
I1
,
ConvStrideD
_
)
+
I1
);
const
auto
IHTildeSliceEnd
=
math
::
min
(
const
auto
IHTildeSliceEnd
=
math
::
min
(
HTilde
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
HTilde
_
,
math
::
integer_divide_ceil
(
InLeftPadH
_
+
Hi
_
-
I1
,
ConvStrideH
_
)
+
I1
);
const
auto
IWTildeSliceEnd
=
math
::
min
(
const
auto
IWTildeSliceEnd
=
math
::
min
(
WTilde
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
WTilde
_
,
math
::
integer_divide_ceil
(
InLeftPadW
_
+
Wi
_
-
I1
,
ConvStrideW
_
)
+
I1
);
const
auto
DTildeSlice
=
IDTildeSliceEnd
-
IDTildeSliceBegin
;
const
auto
DTildeSlice
=
IDTildeSliceEnd
-
IDTildeSliceBegin
;
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
...
@@ -901,34 +1148,34 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -901,34 +1148,34 @@ struct TransformConvBwdDataToGemm_v1
{
{
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_grid_desc
,
in_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
_
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Hi
_
,
InLeftPadH
_
,
InRightPadH
_
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pad_transform
(
Wi
_
,
InLeftPadW
_
,
InRightPadW
_
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc
=
const
auto
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
in_n_hip_wip_c_grid_desc
,
in_n_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
_
),
make_embed_transform
(
make_tuple
(
YTilde
,
HTilde
),
make_embed_transform
(
make_tuple
(
YTilde
_
,
HTilde
_
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_tuple
(
ConvDilationH
_
,
ConvStrideH
_
)),
make_embed_transform
(
make_tuple
(
XTilde
,
WTilde
),
make_embed_transform
(
make_tuple
(
XTilde
_
,
WTilde
_
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_tuple
(
ConvDilationW
_
,
ConvStrideW
_
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_n_htildeslice_wtildeslice_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_htildeslice_wtildeslice_c_grid_desc
=
transform_tensor_descriptor
(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc
,
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
_
),
make_freeze_transform
(
i_yt
ilde
),
make_freeze_transform
(
IdxYT
ilde
_
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
HTilde
_
,
IHTildeSliceBegin
,
HTildeSlice
),
make_freeze_transform
(
i_xt
ilde
),
make_freeze_transform
(
IdxXT
ilde
_
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_slice_transform
(
WTilde
_
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -944,8 +1191,8 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -944,8 +1191,8 @@ struct TransformConvBwdDataToGemm_v1
const
auto
in_gemmmraw_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmmraw_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
in_n_htildeslice_wtildeslice_c_grid_desc
,
in_n_htildeslice_wtildeslice_c_grid_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
HTildeSlice
,
WTildeSlice
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
N
_
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -961,11 +1208,11 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -961,11 +1208,11 @@ struct TransformConvBwdDataToGemm_v1
{
{
const
auto
in_n_dip_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_n_dip_hip_wip_c_grid_desc
=
transform_tensor_descriptor
(
in_grid_desc
,
in_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
_
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Di
_
,
InLeftPadD
_
,
InRightPadD
_
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Hi
_
,
InLeftPadH
_
,
InRightPadH
_
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pad_transform
(
Wi
_
,
InLeftPadW
_
,
InRightPadW
_
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
make_tuple
(
...
@@ -974,14 +1221,14 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -974,14 +1221,14 @@ struct TransformConvBwdDataToGemm_v1
const
auto
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc
=
const
auto
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
in_n_dip_hip_wip_c_grid_desc
,
in_n_dip_hip_wip_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
_
),
make_embed_transform
(
make_tuple
(
ZTilde
,
DTilde
),
make_embed_transform
(
make_tuple
(
ZTilde
_
,
DTilde
_
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_tuple
(
ConvDilationD
_
,
ConvStrideD
_
)),
make_embed_transform
(
make_tuple
(
YTilde
,
HTilde
),
make_embed_transform
(
make_tuple
(
YTilde
_
,
HTilde
_
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_tuple
(
ConvDilationH
_
,
ConvStrideH
_
)),
make_embed_transform
(
make_tuple
(
XTilde
,
WTilde
),
make_embed_transform
(
make_tuple
(
XTilde
_
,
WTilde
_
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_tuple
(
ConvDilationW
_
,
ConvStrideW
_
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -996,14 +1243,14 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -996,14 +1243,14 @@ struct TransformConvBwdDataToGemm_v1
const
auto
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc
=
const
auto
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc
,
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_tuple
(
make_pass_through_transform
(
N
_
),
make_freeze_transform
(
i_zt
ilde
),
make_freeze_transform
(
IdxZT
ilde
_
),
make_slice_transform
(
DTilde
,
IDTildeSliceBegin
,
DTildeSlice
),
make_slice_transform
(
DTilde
_
,
IDTildeSliceBegin
,
DTildeSlice
),
make_freeze_transform
(
i_yt
ilde
),
make_freeze_transform
(
IdxYT
ilde
_
),
make_slice_transform
(
HTilde
,
IHTildeSliceBegin
,
HTildeSlice
),
make_slice_transform
(
HTilde
_
,
IHTildeSliceBegin
,
HTildeSlice
),
make_freeze_transform
(
i_xt
ilde
),
make_freeze_transform
(
IdxXT
ilde
_
),
make_slice_transform
(
WTilde
,
IWTildeSliceBegin
,
WTildeSlice
),
make_slice_transform
(
WTilde
_
,
IWTildeSliceBegin
,
WTildeSlice
),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
>
{},
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
2
>
{},
...
@@ -1024,8 +1271,8 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -1024,8 +1271,8 @@ struct TransformConvBwdDataToGemm_v1
const
auto
in_gemmmraw_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
const
auto
in_gemmmraw_gemmnraw_grid_desc
=
transform_tensor_descriptor
(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc
,
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc
,
make_tuple
(
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
)),
make_merge_transform
(
make_tuple
(
N
_
,
DTildeSlice
,
HTildeSlice
,
WTildeSlice
)),
make_pass_through_transform
(
C
)),
make_pass_through_transform
(
C
_
)),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
...
@@ -1044,84 +1291,41 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -1044,84 +1291,41 @@ struct TransformConvBwdDataToGemm_v1
}
}
// for input bias
// for input bias
template
<
typename
CLayout
,
template
<
typename
CLayout_
=
CLayout
,
typename
std
::
enable_if
<
NDimSpatial
==
2
&&
typename
std
::
enable_if
<
NDimSpatial
==
2
&&
(
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GC
>
||
(
is_same_v
<
CLayout
_
,
tensor_layout
::
convolution
::
GC
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_C
>
),
is_same_v
<
CLayout
_
,
tensor_layout
::
convolution
::
G_C
>
),
bool
>::
type
=
false
>
bool
>::
type
=
false
>
static
auto
__host__
__device__
auto
MakeCDescriptor_M_N
()
const
MakeCDescriptor_M_N
(
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
out_g_n_k_wos_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* out_g_n_k_wos_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
wei_g_k_c_xs_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* wei_g_k_c_xs_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
in_g_n_c_wis_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/* in_g_n_c_wis_strides */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
/* input_right_pads */
,
const
std
::
array
<
index_t
,
NDimSpatial
>&
/* tildes */
)
{
{
const
index_t
N
=
in_g_n_c_wis_lengths
[
1
];
const
index_t
C
=
wei_g_k_c_xs_lengths
[
2
];
const
index_t
Hi
=
in_g_n_c_wis_lengths
[
3
];
const
index_t
Wi
=
in_g_n_c_wis_lengths
[
4
];
const
index_t
Ho
=
out_g_n_k_wos_lengths
[
3
];
const
index_t
Wo
=
out_g_n_k_wos_lengths
[
4
];
const
index_t
Y
=
wei_g_k_c_xs_lengths
[
3
];
const
index_t
X
=
wei_g_k_c_xs_lengths
[
4
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
if
constexpr
(
ConvBwdDataSpecialization
==
if
constexpr
(
ConvBwdDataSpecialization
==
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
ck
::
tensor_operation
::
device
::
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
)
Filter1x1Stride1Pad0
)
{
{
const
auto
in_gemmm_gemmn_grid_desc
=
const
auto
in_gemmm_gemmn_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
*
Ho
*
Wo
,
C
),
make_tuple
(
I0
,
I1
));
make_naive_tensor_descriptor
(
make_tuple
(
N
_
*
Ho
_
*
Wo
_
,
C
_
),
make_tuple
(
I0
,
I1
));
return
in_gemmm_gemmn_grid_desc
;
return
in_gemmm_gemmn_grid_desc
;
}
}
else
else
{
{
const
auto
GcdStrideDilationH
=
math
::
gcd
(
ConvStrideH
,
ConvDilationH
);
const
auto
GcdStrideDilationW
=
math
::
gcd
(
ConvStrideW
,
ConvDilationW
);
const
auto
YTilde
=
ConvStrideH
/
GcdStrideDilationH
;
const
auto
XTilde
=
ConvStrideW
/
GcdStrideDilationW
;
const
auto
HTilde
=
Ho
+
math
::
integer_divide_ceil
(
ConvDilationH
*
(
Y
-
I1
),
ConvStrideH
);
const
auto
WTilde
=
Wo
+
math
::
integer_divide_ceil
(
ConvDilationW
*
(
X
-
I1
),
ConvStrideW
);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
const
auto
IHTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadH
-
ConvDilationH
*
(
YTilde
-
I1
)),
ConvStrideH
);
math
::
max
(
I0
,
InLeftPadH
_
-
ConvDilationH
_
*
(
YTilde
_
-
I1
)),
ConvStrideH
_
);
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
const
auto
IWTildeSliceBegin
=
math
::
integer_divide_floor
(
math
::
max
(
I0
,
InLeftPadW
-
ConvDilationW
*
(
XTilde
-
I1
)),
ConvStrideW
);
math
::
max
(
I0
,
InLeftPadW
_
-
ConvDilationW
_
*
(
XTilde
_
-
I1
)),
ConvStrideW
_
);
const
auto
IHTildeSliceEnd
=
math
::
min
(
const
auto
IHTildeSliceEnd
=
math
::
min
(
HTilde
,
math
::
integer_divide_ceil
(
InLeftPadH
+
Hi
-
I1
,
ConvStrideH
)
+
I1
);
HTilde
_
,
math
::
integer_divide_ceil
(
InLeftPadH
_
+
Hi
_
-
I1
,
ConvStrideH
_
)
+
I1
);
const
auto
IWTildeSliceEnd
=
math
::
min
(
const
auto
IWTildeSliceEnd
=
math
::
min
(
WTilde
,
math
::
integer_divide_ceil
(
InLeftPadW
+
Wi
-
I1
,
ConvStrideW
)
+
I1
);
WTilde
_
,
math
::
integer_divide_ceil
(
InLeftPadW
_
+
Wi
_
-
I1
,
ConvStrideW
_
)
+
I1
);
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
HTildeSlice
=
IHTildeSliceEnd
-
IHTildeSliceBegin
;
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
const
auto
WTildeSlice
=
IWTildeSliceEnd
-
IWTildeSliceBegin
;
// bias tensor
// bias tensor
const
auto
in_gemmmraw_gemmnraw_grid_desc
=
make_naive_tensor_descriptor
(
const
auto
in_gemmmraw_gemmnraw_grid_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
*
HTildeSlice
*
WTildeSlice
,
C
),
make_tuple
(
I0
,
I1
));
make_tuple
(
N
_
*
HTildeSlice
*
WTildeSlice
,
C
_
),
make_tuple
(
I0
,
I1
));
const
auto
in_gemmm_gemmn_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
const
auto
in_gemmm_gemmn_grid_desc
=
ck
::
tensor_operation
::
device
::
PadTensorDescriptor
(
in_gemmmraw_gemmnraw_grid_desc
,
in_gemmmraw_gemmnraw_grid_desc
,
...
@@ -1131,6 +1335,25 @@ struct TransformConvBwdDataToGemm_v1
...
@@ -1131,6 +1335,25 @@ struct TransformConvBwdDataToGemm_v1
return
in_gemmm_gemmn_grid_desc
;
return
in_gemmm_gemmn_grid_desc
;
}
}
}
}
IndexType
N_
;
IndexType
Di_
,
Hi_
,
Wi_
;
IndexType
Do_
,
Ho_
,
Wo_
;
IndexType
Z_
,
Y_
,
X_
;
IndexType
K_
,
C_
;
IndexType
DiStride_
,
HiStride_
,
WiStride_
;
IndexType
DoStride_
,
HoStride_
,
WoStride_
;
IndexType
CStrideTensorB_
,
CStrideTensorC_
,
KStrideTensorA_
,
KStrideTensorB_
;
IndexType
NStrideTensorA_
,
NStrideTensorC_
;
IndexType
ConvStrideD_
,
ConvStrideH_
,
ConvStrideW_
;
IndexType
ConvDilationD_
,
ConvDilationH_
,
ConvDilationW_
;
IndexType
InLeftPadD_
,
InLeftPadH_
,
InLeftPadW_
;
IndexType
InRightPadD_
,
InRightPadH_
,
InRightPadW_
;
IndexType
IdxZTilde_
,
IdxYTilde_
,
IdxXTilde_
;
IndexType
GcdStrideDilationD_
,
GcdStrideDilationH_
,
GcdStrideDilationW_
;
IndexType
ZTilde_
,
YTilde_
,
XTilde_
;
IndexType
DTilde_
,
HTilde_
,
WTilde_
;
IndexType
ZDot_
,
YDot_
,
XDot_
;
};
};
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
a5137505
...
@@ -429,7 +429,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
...
@@ -429,7 +429,8 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
f8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
bf8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
int8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
uint8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
(
is_same
<
T
,
uint8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
is_same
<
T
,
pk_i4_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
)),
"wrong! not implemented"
);
"wrong! not implemented"
);
using
r_t
=
typename
vector_type
<
T
,
N
>::
type
;
using
r_t
=
typename
vector_type
<
T
,
N
>::
type
;
...
...
include/ck/utility/amd_ck_fp8.hpp
View file @
a5137505
...
@@ -20,6 +20,20 @@
...
@@ -20,6 +20,20 @@
#define CK_USE_OCP_FP8 0
#define CK_USE_OCP_FP8 0
#endif
#endif
namespace {

// Local compile-time type selector with the same shape as std::conditional
// (see https://en.cppreference.com/w/cpp/types/conditional):
// conditional<Cond, Then, Else>::type is Then when Cond is true, Else otherwise.

// Primary template: handles Cond == true.
template <bool Cond, typename Then, typename Else>
struct conditional
{
    using type = Then;
};

// Partial specialization: handles Cond == false.
template <typename Then, typename Else>
struct conditional<false, Then, Else>
{
    using type = Else;
};

} // namespace
namespace
ck
{
namespace
ck
{
using
f8_fnuz_t
=
_BitInt
(
8
);
using
f8_fnuz_t
=
_BitInt
(
8
);
...
@@ -193,11 +207,10 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
...
@@ -193,11 +207,10 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
}
}
}
}
typename
__hip_internal
::
conditional
<
typename
conditional
<
sizeof
(
T
)
==
2
,
sizeof
(
T
)
==
2
,
unsigned
short
int
,
unsigned
short
int
,
typename
__hip_internal
::
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
typename
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
type
>::
type
retval
;
type
>::
type
retval
;
if
constexpr
(
we
==
5
&&
is_half
&&
!
is_fnuz
)
if
constexpr
(
we
==
5
&&
is_half
&&
!
is_fnuz
)
{
{
...
@@ -540,11 +553,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
...
@@ -540,11 +553,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
constexpr
int
mfmt
=
(
sizeof
(
T
)
==
8
)
?
52
:
((
sizeof
(
T
)
==
4
)
?
23
:
10
);
constexpr
int
mfmt
=
(
sizeof
(
T
)
==
8
)
?
52
:
((
sizeof
(
T
)
==
4
)
?
23
:
10
);
using
T_bitwise
=
typename
__hip_internal
::
conditional
<
using
T_bitwise
=
typename
conditional
<
sizeof
(
T
)
==
2
,
sizeof
(
T
)
==
2
,
unsigned
short
int
,
unsigned
short
int
,
typename
__hip_internal
::
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
typename
conditional
<
sizeof
(
T
)
==
4
,
unsigned
int
,
unsigned
long
long
>::
type
>::
type
;
type
>::
type
;
T_bitwise
x_bitwise
=
bit_cast
<
T_bitwise
>
(
_x
);
T_bitwise
x_bitwise
=
bit_cast
<
T_bitwise
>
(
_x
);
unsigned
long
long
x
{
x_bitwise
};
unsigned
long
long
x
{
x_bitwise
};
...
...
include/ck/utility/amd_inline_asm.hpp
View file @
a5137505
...
@@ -4,13 +4,34 @@
...
@@ -4,13 +4,34 @@
#ifndef CK_AMD_INLINE_ASM_HPP
#ifndef CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP
#define CK_AMD_INLINE_ASM_HPP
#include "data_type.hpp"
#include "c_style_pointer_cast.hpp"
#include "c_style_pointer_cast.hpp"
#include "data_type.hpp"
// TODO: deprecate all amd_assembly_outer_product_xxx
// TODO: deprecate all amd_assembly_outer_product_xxx
namespace
ck
{
namespace
ck
{
// Issues one `v_and_or_b32` instruction: operand %0 is the returned destination,
// operands %1, %2 and %3 are a, b and d, in that order.
// NOTE(review): per the instruction name this is expected to yield (a & b) | d --
// confirm against the AMD ISA reference for the target architecture.
inline __device__ int amd_assembly_and_or_b32(int a, int b, int d)
{
    int result;
    asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(result) : "v"(a), "v"(b), "v"(d));
    return result;
}
// Issues one `v_pk_fma_f16` instruction on half2_t operands: operand %0 is the returned
// destination, operands %1, %2 and %3 are a, b and c, in that order.
// NOTE(review): presumably a packed fused multiply-add (a * b + c on each half lane) --
// confirm against the AMD ISA reference.
inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c)
{
    half2_t result;
    asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(result) : "v"(a), "v"(b), "v"(c));
    return result;
}
// Issues one `v_pk_add_f16` instruction on half2_t operands: operand %0 is the returned
// destination, operands %1 and %2 are a and b.
// NOTE(review): presumably a packed add (a + b on each half lane) -- confirm against
// the AMD ISA reference.
inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b)
{
    half2_t result;
    asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
    return result;
}
// c0 += inner_product(a, b0)
// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
// c1 += inner_product(a, b1)
__device__
void
amd_assembly_outer_product_1x2
(
float
a
,
float
b0
,
float
b1
,
float
&
c0
,
float
&
c1
)
__device__
void
amd_assembly_outer_product_1x2
(
float
a
,
float
b0
,
float
b1
,
float
&
c0
,
float
&
c1
)
...
...
include/ck/utility/data_type.hpp
View file @
a5137505
...
@@ -24,6 +24,17 @@ using bhalf_t = ushort;
...
@@ -24,6 +24,17 @@ using bhalf_t = ushort;
using
half_t
=
_Float16
;
using
half_t
=
_Float16
;
using
int4_t
=
_BitInt
(
4
);
using
int4_t
=
_BitInt
(
4
);
// Custom data type for packed int4 data: a thin wrapper around a single int8_t byte.
struct pk_i4_t
{
    // Underlying storage type.
    using type = int8_t;

    // Raw storage byte.
    type data;

    // Default construction value-initializes the storage byte (i.e. zero).
    __host__ __device__ constexpr pk_i4_t() : data{} {}

    // Wraps an existing raw byte as-is.
    __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}

    // Converts the whole storage byte to float; no per-nibble unpacking happens here.
    __host__ __device__ constexpr operator float() const { return static_cast<type>(data); }
};
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
inline
constexpr
auto
next_pow2
(
uint32_t
x
)
{
{
// Precondition: x > 1.
// Precondition: x > 1.
...
@@ -177,6 +188,13 @@ struct scalar_type<int4_t>
...
@@ -177,6 +188,13 @@ struct scalar_type<int4_t>
};
};
#endif
#endif
// scalar_type trait for pk_i4_t: the type is reported as its own scalar type,
// with a vector size of one.
template <>
struct scalar_type<pk_i4_t>
{
    using type                           = pk_i4_t;
    static constexpr index_t vector_size = 1;
};
template
<
>
template
<
>
struct
scalar_type
<
f8_fnuz_t
>
struct
scalar_type
<
f8_fnuz_t
>
{
{
...
@@ -1056,6 +1074,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
...
@@ -1056,6 +1074,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using
type
=
bf8_ocp_t
::
data_type
;
using
type
=
bf8_ocp_t
::
data_type
;
};
};
// nnvb_data_t_selector for pk_i4_t: selects pk_i4_t's underlying storage alias (pk_i4_t::type).
template <>
struct nnvb_data_t_selector<pk_i4_t>
{
    using type = pk_i4_t::type;
};
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
non_native_vector_base
<
struct
non_native_vector_base
<
T
,
T
,
...
@@ -1175,6 +1199,14 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
...
@@ -1175,6 +1199,14 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
static
constexpr
index_t
vector_size
=
N
;
static
constexpr
index_t
vector_size
=
N
;
};
};
// scalar_type trait for an N-element non-native vector of pk_i4_t: the scalar type is the
// vector base's data_t and the vector size is N.
template <index_t N>
struct scalar_type<non_native_vector_base<pk_i4_t, N>>
{
    using type = typename non_native_vector_base<pk_i4_t, N>::data_t;

    static constexpr index_t vector_size = N;
};
// non-native vector_type implementation
// non-native vector_type implementation
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
struct
vector_type
<
T
,
1
,
typename
ck
::
enable_if_t
<!
is_native_type
<
T
>
()
>>
...
@@ -1882,6 +1914,11 @@ using uint8x8_t = typename vector_type<uint8_t, 8>::type;
...
@@ -1882,6 +1914,11 @@ using uint8x8_t = typename vector_type<uint8_t, 8>::type;
using
uint8x16_t
=
typename
vector_type
<
uint8_t
,
16
>::
type
;
using
uint8x16_t
=
typename
vector_type
<
uint8_t
,
16
>::
type
;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x32_t
=
typename
vector_type
<
uint8_t
,
32
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
using
uint8x64_t
=
typename
vector_type
<
uint8_t
,
64
>::
type
;
// Packed-int4 vector aliases: vector_type instantiations holding 2, 4 and 8 pk_i4_t elements.
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
#ifdef CK_CODE_GEN_RTC
#ifdef CK_CODE_GEN_RTC
template
<
typename
T
>
template
<
typename
T
>
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
a5137505
...
@@ -54,7 +54,8 @@ struct DynamicBuffer
...
@@ -54,7 +54,8 @@ struct DynamicBuffer
template
<
typename
X
,
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
||
!
is_native_type
<
X
>
(),
bool
>::
type
=
false
>
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
__host__
__device__
constexpr
auto
Get
(
index_t
i
,
bool
is_valid_element
)
const
{
{
...
@@ -195,7 +196,8 @@ struct DynamicBuffer
...
@@ -195,7 +196,8 @@ struct DynamicBuffer
template
<
typename
X
,
template
<
typename
X
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
enable_if
<
is_same
<
typename
scalar_type
<
remove_cvref_t
<
X
>
>::
type
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
,
typename
scalar_type
<
remove_cvref_t
<
T
>>::
type
>::
value
||
!
is_native_type
<
X
>
(),
bool
>::
type
=
false
>
bool
>::
type
=
false
>
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
__host__
__device__
void
Set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
{
{
...
...
include/ck/utility/math_v2.hpp
View file @
a5137505
...
@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
...
@@ -611,7 +611,7 @@ inline __device__ int8_t neg<int8_t>(int8_t x)
template
<
>
template
<
>
inline
__device__
half_t
neg
<
half_t
>
(
half_t
x
)
inline
__device__
half_t
neg
<
half_t
>
(
half_t
x
)
{
{
return
__hneg
(
x
);
return
__hneg
(
static_cast
<
__half
>
(
x
)
);
};
};
template
<
typename
T
>
template
<
typename
T
>
...
...
include/ck/utility/static_buffer.hpp
View file @
a5137505
...
@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector
...
@@ -116,7 +116,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X
// i is offset of S, not X. i should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
index_t
I
,
index_t
I
,
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
,
bool
>::
type
=
false
>
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
||
!
is_native_type
<
S
>
(),
bool
>::
type
=
false
>
__host__
__device__
constexpr
auto
GetAsType
(
Number
<
I
>
i
)
const
__host__
__device__
constexpr
auto
GetAsType
(
Number
<
I
>
i
)
const
{
{
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
...
@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector
...
@@ -134,7 +135,8 @@ struct StaticBufferTupleOfVector
// i is offset of S, not X. i should be aligned to X
// i is offset of S, not X. i should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
index_t
I
,
index_t
I
,
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
,
bool
>::
type
=
false
>
typename
enable_if
<
has_same_scalar_type
<
S
,
X
>
::
value
||
!
is_native_type
<
S
>
(),
bool
>::
type
=
false
>
__host__
__device__
constexpr
void
SetAsType
(
Number
<
I
>
i
,
X
x
)
__host__
__device__
constexpr
void
SetAsType
(
Number
<
I
>
i
,
X
x
)
{
{
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
constexpr
auto
s_per_x
=
Number
<
scalar_type
<
remove_cvref_t
<
X
>>::
vector_size
>
{};
...
...
include/ck/utility/type_convert.hpp
View file @
a5137505
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -467,6 +467,19 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
...
@@ -467,6 +467,19 @@ inline __host__ __device__ float2_t type_convert<float2_t, f8x2_ocp_t>(f8x2_ocp_
#endif
#endif
}
}
// Convert a packed pair of 4-bit values (pk_i4_t, one byte) to float2_t.
// The low nibble of the byte becomes the first element of the result and the
// high nibble becomes the second. Each nibble is handed to
// ck::type_convert<float> as a uint8_t in the range [0, 15]; no sign extension
// or bias is applied in this function.
template <>
inline __host__ __device__ float2_t type_convert<float2_t, pk_i4_t>(pk_i4_t x)
{
    // View the packed value as a raw byte so the two nibbles can be split out.
    const uint8_t packed = ck::bit_cast<uint8_t>(x);

    const uint8_t lo_nibble = packed & 0x0f;
    const uint8_t hi_nibble = (packed >> 4) & 0x0f;

    const auto lo_f32 = ck::type_convert<float>(lo_nibble);
    const auto hi_f32 = ck::type_convert<float>(hi_nibble);

    return {lo_f32, hi_f32};
}
template
<
>
template
<
>
inline
__host__
__device__
half2_t
type_convert
<
half2_t
,
float2_t
>
(
float2_t
x
)
inline
__host__
__device__
half2_t
type_convert
<
half2_t
,
float2_t
>
(
float2_t
x
)
{
{
...
...
include/ck_tile/README.md
View file @
a5137505
...
@@ -45,5 +45,8 @@ our implementation of different device operators.
...
@@ -45,5 +45,8 @@ our implementation of different device operators.
**[ops/epilogue]**
**[ops/epilogue]**
epilogue part of our kernel. We may extend this epilogue part to let users build their own customized epilogues.
epilogue part of our kernel. We may extend this epilogue part to let users build their own customized epilogues.
**[ref]**
reference implementation of cpu or gpu. This folder is supposed to include a specific header on demand.
## examples
## examples
currently we put all ck_tile related example under
[
/example/ck_tile
](
/example/ck_tile/
)
folder. Please check each example's subfolder.
currently we put all ck_tile related example under
[
/example/ck_tile
](
/example/ck_tile/
)
folder. Please check each example's subfolder.
include/ck_tile/core.hpp
View file @
a5137505
...
@@ -54,6 +54,7 @@
...
@@ -54,6 +54,7 @@
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/amd_address_space.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
...
...
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
a5137505
...
@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
...
@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
static_assert
(
static_assert
(
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
double
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
float
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
bf16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
bf16_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
))
||
(
std
::
is_same
<
T
,
int32_t
>::
value
&&
(
std
::
is_same
<
T
,
int32_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
(
std
::
is_same
<
T
,
fp8_t
>::
value
&&
(
N
==
1
||
N
==
2
||
N
==
4
||
N
==
8
||
N
==
16
))
||
...
...
include/ck_tile/core/container/meta_data_buffer.hpp
View file @
a5137505
...
@@ -30,7 +30,7 @@ struct meta_data_buffer
...
@@ -30,7 +30,7 @@ struct meta_data_buffer
{
{
constexpr
index_t
size
=
sizeof
(
T
);
constexpr
index_t
size
=
sizeof
(
T
);
auto
tmp
=
bit_cast
<
array
<
std
::
byte
,
size
>>
(
data
);
auto
tmp
=
ck_tile
::
bit_cast
<
array
<
std
::
byte
,
size
>>
(
data
);
for
(
int
i
=
0
;
i
<
size
;
i
++
)
for
(
int
i
=
0
;
i
<
size
;
i
++
)
{
{
...
@@ -66,7 +66,7 @@ struct meta_data_buffer
...
@@ -66,7 +66,7 @@ struct meta_data_buffer
pos
++
;
pos
++
;
}
}
data
=
bit_cast
<
T
>
(
tmp
);
data
=
ck_tile
::
bit_cast
<
T
>
(
tmp
);
}
}
return
data
;
return
data
;
...
@@ -86,7 +86,7 @@ struct meta_data_buffer
...
@@ -86,7 +86,7 @@ struct meta_data_buffer
pos
++
;
pos
++
;
}
}
auto
data
=
bit_cast
<
T
>
(
tmp
);
auto
data
=
ck_tile
::
bit_cast
<
T
>
(
tmp
);
return
data
;
return
data
;
}
}
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
a5137505
...
@@ -29,6 +29,7 @@ struct static_distributed_tensor
...
@@ -29,6 +29,7 @@ struct static_distributed_tensor
remove_cvref_t
<
decltype
(
StaticTileDistribution
{}.
get_ys_to_d_descriptor
())
>
;
remove_cvref_t
<
decltype
(
StaticTileDistribution
{}.
get_ys_to_d_descriptor
())
>
;
static
constexpr
index_t
kThreadElementSpaceSize
=
ThreadTensorDesc
{}.
get_element_space_size
();
static
constexpr
index_t
kThreadElementSpaceSize
=
ThreadTensorDesc
{}.
get_element_space_size
();
static_assert
(
0
<
kThreadElementSpaceSize
,
"Make sure tile distribution is valid"
);
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_dimension
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_dimension
()
{
{
...
...
include/ck_tile/core/utility/amd_address_space.hpp
0 → 100644
View file @
a5137505
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
namespace
ck_tile
{
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
// Convert a pointer qualified with CK_CONSTANT_ADDRESS_SPACE into an
// unqualified pointer to the same element type T.
template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
    // "Constant" address space (4) -> "Generic" address space (0).
    // A C-style cast is used on purpose: it appears to be the only pointer cast
    // form that compiles here, so -Wold-style-cast is silenced around it.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    T* generic_ptr = (T*)p; // NOLINT(old-style-cast)
    return generic_ptr;
#pragma clang diagnostic pop
}
// Convert an unqualified pointer to T into a pointer qualified with
// CK_CONSTANT_ADDRESS_SPACE, keeping the same element type.
template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
    // "Generic" address space (0) -> "Constant" address space (4).
    // A C-style cast is used on purpose: it appears to be the only pointer cast
    // form that compiles here, so -Wold-style-cast is silenced around it.
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    T CK_CONSTANT_ADDRESS_SPACE* constant_ptr = (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
    return constant_ptr;
#pragma clang diagnostic pop
}
}
// namespace ck_tile
include/ck_tile/host/arg_parser.hpp
View file @
a5137505
...
@@ -15,11 +15,14 @@
...
@@ -15,11 +15,14 @@
namespace
ck_tile
{
namespace
ck_tile
{
/*
/*
* a host side utility, arg parser for
* a host side utility, arg parser for, either
* -[key0]=[value0] -[key1]=[value1] ...
* -[key0] = [value0, value1, value2]
* or
* -[key0]=[value0] -[key1]=[value1] ...
*/
*/
class
ArgParser
class
ArgParser
{
{
public:
public:
class
Arg
class
Arg
{
{
...
@@ -187,6 +190,45 @@ class ArgParser
...
@@ -187,6 +190,45 @@ class ArgParser
return
value
;
return
value
;
}
}
// Look up the string value stored under `name` (via get_str) and split it into
// tokens on every occurrence of `delimiter`.
//
// name      : key of the argument to read.
// delimiter : separator between tokens, "," by default.
//
// Returns an empty vector when the stored value is empty. Otherwise returns the
// pieces in order; adjacent delimiters produce empty tokens and a trailing
// delimiter produces a trailing empty token. An empty `delimiter` cannot split
// anything, so the whole value is returned as a single token.
std::vector<std::string> get_string_vec(const std::string& name,
                                        const std::string& delimiter = ",") const
{
    // Fetch the value once instead of once for the check and once for the split.
    const std::string s = get_str(name);
    if(s.empty())
    {
        return {};
    }

    std::vector<std::string> tokens;

    // find("") matches at every position without ever advancing, which would
    // loop forever; treat an empty delimiter as "no splitting".
    if(delimiter.empty())
    {
        tokens.push_back(s);
        return tokens;
    }

    // Walk the string with a moving start offset rather than erasing the
    // consumed prefix on every iteration.
    size_t begin = 0;
    size_t pos   = 0;
    while((pos = s.find(delimiter, begin)) != std::string::npos)
    {
        tokens.push_back(s.substr(begin, pos - begin));
        begin = pos + delimiter.length();
    }

    // Whatever follows the last delimiter (possibly empty) is the final token.
    tokens.push_back(s.substr(begin));
    return tokens;
}
// Look up the string value stored under `name`, split it on `delimiter` with
// get_string_vec, and convert every token to int with atoi.
//
// name      : key of the argument to read.
// delimiter : separator between tokens, "," by default.
//
// Returns an empty vector when the stored value is empty (get_string_vec yields
// no tokens in that case). atoi reports no errors, so a token that does not
// start with a number converts to 0.
std::vector<int> get_int_vec(const std::string& name, const std::string& delimiter = ",") const
{
    const std::vector<std::string> args = get_string_vec(name, delimiter);

    std::vector<int> values;
    values.reserve(args.size());
    for(const std::string& token : args)
    {
        values.push_back(atoi(token.c_str()));
    }
    return values;
}
private:
private:
std
::
unordered_map
<
std
::
string
,
Arg
>
input_map
;
std
::
unordered_map
<
std
::
string
,
Arg
>
input_map
;
std
::
vector
<
std
::
string
>
keys
;
std
::
vector
<
std
::
string
>
keys
;
...
...
include/ck_tile/host/reference/reference_gemm.hpp
View file @
a5137505
...
@@ -97,9 +97,9 @@ template <typename ADataType,
...
@@ -97,9 +97,9 @@ template <typename ADataType,
typename
LayoutA
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutB
,
typename
LayoutC
>
typename
LayoutC
>
void
reference_gemm_gpu
(
DeviceMem
&
a_device
,
void
reference_gemm_gpu
(
ADataType
*
a_ptr
,
DeviceMem
&
b_device
,
BDataType
*
b_ptr
,
DeviceMem
&
c_device
,
CDataType
*
c_ptr
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
...
@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
...
@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
index_t
stride_b
,
index_t
stride_b
,
index_t
stride_c
)
index_t
stride_c
)
{
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_A
,
d_B
,
d_C
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
errC
=
hipMemcpy
(
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
c_device
.
GetDeviceBuffer
(),
d_C
,
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
return
;
return
;
}
}
...
@@ -191,9 +125,9 @@ template <typename ADataType,
...
@@ -191,9 +125,9 @@ template <typename ADataType,
typename
LayoutA
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutB
,
typename
LayoutC
>
typename
LayoutC
>
void
reference_batched_gemm_gpu
(
DeviceMem
&
a_device
,
void
reference_batched_gemm_gpu
(
ADataType
*
a_ptr
,
DeviceMem
&
b_device
,
BDataType
*
b_ptr
,
DeviceMem
&
c_device
,
CDataType
*
c_ptr
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
...
@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
...
@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
index_t
batch_stride_C
,
index_t
batch_stride_C
,
index_t
batch_count
)
index_t
batch_count
)
{
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
batch_count
*
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
batch_count
*
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
batch_count
*
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
batch_count
*
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
for
(
index_t
batch_id
=
0
;
batch_id
<
batch_count
;
++
batch_id
)
for
(
index_t
batch_id
=
0
;
batch_id
<
batch_count
;
++
batch_id
)
{
{
ADataType
*
d_ATemp
=
d_A
+
batch_id
*
batch_stride_A
;
ADataType
*
d_ATemp
=
a_ptr
+
batch_id
*
batch_stride_A
;
BDataType
*
d_BTemp
=
d_B
+
batch_id
*
batch_stride_B
;
BDataType
*
d_BTemp
=
b_ptr
+
batch_id
*
batch_stride_B
;
CDataType
*
d_CTemp
=
d_C
+
batch_id
*
batch_stride_C
;
CDataType
*
d_CTemp
=
c_ptr
+
batch_id
*
batch_stride_C
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_ATemp
,
d_BTemp
,
d_CTemp
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
d_ATemp
,
d_BTemp
,
d_CTemp
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
}
}
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
return
;
return
;
}
}
}
// namespace ck_tile
}
// namespace ck_tile
Prev
1
2
3
4
5
6
7
8
9
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment