gaoqiong / MIGraphX

Commit 78a300ff, authored Oct 07, 2022 by Alan Turner
Update tuning method
Parent: dea0555f

Changes: 159 of 159+ changed files are displayed; the 20 files listed below account for 7219 additions and 0 deletions (+7219 -0).
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/multi_index_transform.hpp  +1954  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/multi_index_transform_helper.hpp  +130  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/tensor_adaptor.hpp  +482  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/tensor_descriptor.hpp  +615  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/tensor_descriptor_helper.hpp  +165  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/tensor_space_filling_curve.hpp  +160  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp  +412  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp  +397  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp  +178  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp  +962  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp  +321  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp  +115  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp  +156  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp  +108  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp  +244  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp  +173  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp  +134  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp  +158  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp  +183  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp  +172  -0
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/multi_index_transform.hpp (new file, mode 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/utility/multi_index.hpp"

namespace ck {

template <typename LowLength>
struct PassThrough
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(LowLength{}));

    UpLengths up_lengths_;

    __host__ __device__ constexpr PassThrough() = default;

    __host__ __device__ constexpr PassThrough(const LowLength& low_length)
        : up_lengths_{make_tuple(low_length)}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                                  const UpIdx& idx_up)
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}];
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
                                                     const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("PassThrough, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
template <typename LowLength,
          typename LeftPadLength,
          typename RightPadLength,
          bool SkipIsValidCheck = false>
struct Pad
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));

    UpLengths up_lengths_;
    LeftPadLength left_pad_length_;
    RightPadLength right_pad_length_;

    __host__ __device__ constexpr Pad() = default;

    __host__ __device__ constexpr Pad(const LowLength& low_length,
                                      const LeftPadLength& left_pad_length,
                                      const RightPadLength& right_pad_length)
        : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
          left_pad_length_{left_pad_length},
          right_pad_length_{right_pad_length}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
                                                     const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return SkipIsValidCheck;
    }

    template <typename UpIdx>
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
    {
        return SkipIsValidCheck ||
               ((idx_up[Number<0>{}] >= left_pad_length_) &&
                (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_length_));
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<LeftPadLength>::value &&
               is_known_at_compile_time<RightPadLength>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Pad, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("left_pad_length %d", index_t{left_pad_length_});
        printf("right_pad_length %d", index_t{right_pad_length_});
        printf("}");
    }
};
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct LeftPad
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));

    UpLengths up_lengths_;
    LeftPadLength left_pad_length_;

    __host__ __device__ constexpr LeftPad() = default;

    __host__ __device__ constexpr LeftPad(const LowLength& low_length,
                                          const LeftPadLength& left_pad_length)
        : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
                                                     const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return SkipIsValidCheck;
    }

    template <typename UpIdx>
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
    {
        return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_);
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<LeftPadLength>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("LeftPad, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("left_pad_length_ %d", index_t{left_pad_length_});
        printf("}");
    }
};
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct RightPad
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));

    UpLengths up_lengths_;
    LowLength low_length_;
    RightPadLength right_pad_length_;

    __host__ __device__ constexpr RightPad() = default;

    __host__ __device__ constexpr RightPad(const LowLength& low_length,
                                           const RightPadLength& right_pad_length)
        : up_lengths_{make_tuple(low_length + right_pad_length)},
          low_length_{low_length},
          right_pad_length_{right_pad_length}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                                  const UpIdx& idx_up)
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}];
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
                                                     const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return SkipIsValidCheck;
    }

    template <typename UpIdx>
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
    {
        return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_);
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<LowLength>::value &&
               is_known_at_compile_time<RightPadLength>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("RightPad, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("low_length_ %d", index_t{low_length_});
        printf("right_pad_length_ %d", index_t{right_pad_length_});
        printf("}");
    }
};
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
// UpLengths and Coefficients can be any of the following:
//   1) Tuple of index_t, which is known at run-time, or
//   2) Tuple of Number, which is known at compile-time, or
//   3) Tuple of a mixture of index_t and Number, which is known partly at run-time and partly
//      at compile-time
template <typename UpLengths,
          typename Coefficients,
          typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
struct Embed
{
    static constexpr index_t NDimUp = UpLengths::Size();

    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<NDimUp>;

    UpLengths up_lengths_;
    Coefficients coefficients_;

    __host__ __device__ constexpr Embed() = default;

    __host__ __device__ constexpr Embed(const UpLengths& up_lengths,
                                        const Coefficients& coefficients)
        : up_lengths_{up_lengths}, coefficients_{coefficients}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = 0;

        static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
            idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i];
        });
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
                                              const UpIdx&,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
                          LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        idx_diff_low(Number<0>{}) = 0;

        static_for<0, NDimUp, 1>{}(
            [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<Coefficients>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Embed, ");
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("coefficients_ ");
        print_multi_index(coefficients_);
        printf("}");
    }
};
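
// Illustrative sketch (not part of the original header; the function and namespace names below
// are made up for the example): Embed is the linear combination
// idx_low = sum_i coefficients[i] * idx_up[i]. With upper lengths {4, 8} and row-major
// coefficients {8, 1}, upper index (2, 3) maps to lower index 2*8 + 3*1 = 19.
namespace embed_example_detail {
__host__ __device__ inline index_t embed_2d_offset(index_t i0, index_t i1, index_t coeff0,
                                                   index_t coeff1)
{
    // same arithmetic as Embed<...>::CalculateLowerIndex for NDimUp == 2
    return i0 * coeff0 + i1 * coeff1;
}
} // namespace embed_example_detail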
// Implementation of the "Merge" transformation primitive that uses regular division to lower the
// multi-index, and a carry-and-borrow check to lower the multi-index delta
template <typename LowLengths>
struct Merge_v1_carry_check
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using LowLengthsScan =
        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Merge_v1_carry_check() = default;

    __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_scan_{
              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[Number<0>{}];

        // normal division
        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_low(i) = tmp / this->low_lengths_scan_[i];
            tmp -= idx_low[i] * this->low_lengths_scan_[i];
        });

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }
    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low,
                                                 const UpIdxDiff& idx_diff_up,
                                                 LowIdx& idx_low,
                                                 const UpIdx& /* idx_up_new */,
                                                 Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
        // However,
        //   1) If idx_diff_up is known at compile-time, then idx_diff_low_const
        //      can be calculated at compile-time.
        //   2) If idx_diff_up is not known at compile-time, but its value doesn't change
        //      during the whole kernel execution, then idx_diff_low_const also doesn't change
        //      during the whole kernel execution. The compiler-generated ISA should calculate
        //      idx_diff_low_const only once and keep it for the whole kernel execution.
        // If neither 1) nor 2) is satisfied, then the calculation is done at run-time each time
        // this function is called, and can be very expensive.
        LowerIndex idx_diff_low_const;
        LowerIndex idx_low_length_minus_idx_diff_low_const;
        LowerIndex idx_low_length_plus_idx_diff_low_const;

#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;

        static_for<0, NDimLow, 1>{}([&](auto i) {
            idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];

            idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
        });
#else
        // Hack: this forces the result into an SGPR. Need to make sure the result is
        // thread-invariant.
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);

        static_for<0, NDimLow, 1>{}([&](auto i) {
            idx_low_length_minus_idx_diff_low_const(i) =
                __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);

            idx_low_length_plus_idx_diff_low_const(i) =
                __builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]);
        });
#endif

        if constexpr(Hack == 1)
        {
            // do carry check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t carry = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t idx_low_tmp = idx_low[i] + carry;

                bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];

                idx_diff_low(i) += carry;

                carry = do_carry ? 1 : 0;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;

            idx_low += idx_diff_low;
        }
        else if constexpr(Hack == 2)
        {
            // do borrow check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t borrow = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t idx_low_tmp = idx_low[i] - borrow;

                bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];

                idx_diff_low(i) -= borrow;

                borrow = do_borrow ? 1 : 0;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;

            idx_low += idx_diff_low;
        }
        else
        {
            // do carry-and-borrow check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t carry = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t idx_low_tmp = idx_low[i] + carry;

                bool do_carry  = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
                bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];

                idx_diff_low(i) += carry;

                carry = do_carry ? 1 : 0;
                carry = do_borrow ? -1 : carry;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;

            idx_low += idx_diff_low;
        }
    }
    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low,
                                                 const UpIdxDiff& idx_diff_up,
                                                 LowIdx& idx_low,
                                                 const UpIdx& /* idx_up_new */,
                                                 Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
        // However,
        //   1) If idx_diff_up is known at compile-time, then idx_diff_low_const
        //      can be calculated at compile-time.
        //   2) If idx_diff_up is not known at compile-time, but its value doesn't change
        //      during the whole kernel execution, then idx_diff_low_const also doesn't change
        //      during the whole kernel execution. The compiler-generated ISA should calculate
        //      idx_diff_low_const only once and keep it for the whole kernel execution.
        // If neither 1) nor 2) is satisfied, then the calculation is done at run-time each time
        // this function is called, and can be very expensive.
        LowerIndex idx_diff_low_const;
        LowerIndex idx_low_length_minus_idx_diff_low_const;
        LowerIndex idx_low_length_plus_idx_diff_low_const;

#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;

        static_for<0, NDimLow, 1>{}([&](auto i) {
            idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];

            idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
        });
#else
        // Hack: this forces the result into an SGPR. Need to make sure the result is
        // thread-invariant.
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);

        static_for<0, NDimLow, 1>{}([&](auto i) {
            idx_low_length_minus_idx_diff_low_const(i) =
                __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);

            idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
        });
#endif

        if constexpr(Hack == 1)
        {
            // do carry check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t carry = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t idx_low_tmp = idx_low[i] + carry;

                bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];

                idx_diff_low(i) += carry;

                carry = do_carry ? 1 : 0;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;

            idx_low += idx_diff_low;
        }
        else if constexpr(Hack == 2)
        {
            // do borrow check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t borrow = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t negative_idx_low_tmp = borrow - idx_low[i];

                bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];

                idx_diff_low(i) -= borrow;

                borrow = do_borrow ? 1 : 0;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;

            idx_low += idx_diff_low;
        }
        else
        {
            // do carry-and-borrow check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t carry = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t idx_low_tmp = idx_low[i] + carry;

                bool do_carry  = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
                bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];

                idx_diff_low(i) += carry;

                carry = do_carry ? 1 : 0;
                carry = do_borrow ? -1 : carry;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;

            idx_low += idx_diff_low;
        }
    }
    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& idx_low,
                                                const UpIdx& /* idx_up_new */,
                                                Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
        // However,
        //   1) If idx_diff_up is known at compile-time, then idx_diff_low_const
        //      can be calculated at compile-time.
        //   2) If idx_diff_up is not known at compile-time, but its value doesn't change
        //      during the whole kernel execution, then idx_diff_low_const also doesn't change
        //      during the whole kernel execution. The compiler-generated ISA should calculate
        //      idx_diff_low_const only once and keep it for the whole kernel execution.
        // If neither 1) nor 2) is satisfied, then the calculation is done at run-time each time
        // this function is called, and can be very expensive.
        LowerIndex idx_diff_low_const;

#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
#else
        // Hack: this forces the result into an SGPR. Need to make sure the result is
        // thread-invariant.
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
#endif

        if constexpr(Hack == 1)
        {
            // do carry check on each low dimension in reversed order
            // do not need to check the first dimension
            bool do_carry = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                idx_diff_low(i) = idx_diff_low_const[i] + do_carry;

                index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];

                do_carry = idx_low_tmp >= low_lengths_[i];

#if 0
                // TODO: use exec-mask inline asm, which uses 1 VALU
                if(do_carry)
                {
                    idx_diff_low(i) -= low_lengths_[i];
                }
#elif 1
                // this uses 2 VALU
                idx_diff_low(i) = do_carry ? idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i];
#elif 1
                // this uses 2 VALU
                index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i];
                idx_diff_low(i)          = do_carry ? idx_diff_low_tmp : idx_diff_low[i];
#endif

                idx_low(i) += idx_diff_low[i];
            });

            constexpr auto I0 = Number<0>{};

            idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry;

            idx_low(I0) += idx_diff_low[I0];
        }
        else if constexpr(Hack == 2)
        {
            // do borrow check on each low dimension in reversed order
            // do not need to check the first dimension
            bool do_borrow = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                idx_diff_low(i) = idx_diff_low_const[i] - do_borrow;

                index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];

                do_borrow = idx_low_tmp < 0;

#if 0
                // TODO: use exec-mask inline asm
                if(do_borrow)
                {
                    idx_diff_low(i) += low_lengths_[i];
                }
#elif 1
                idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i];
#elif 1
                index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i];
                idx_diff_low(i)          = do_borrow ? idx_diff_low_tmp : idx_diff_low[i];
#endif

                idx_low(i) += idx_diff_low[i];
            });

            constexpr auto I0 = Number<0>{};

            idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow;

            idx_low(I0) += idx_diff_low[I0];
        }
        else
        {
            // not implemented
        }
    }
    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const
    {
#if 1
        UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
#elif 0
        UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
#else
        UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
#endif
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowLengths>::value &&
               is_known_at_compile_time<LowLengthsScan>::value &&
               is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Merge_v1_carry_check, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_scan_ ");
        print_multi_index(low_lengths_scan_);
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
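
// Illustrative sketch (not part of the original header; the function and namespace names below
// are made up for the example): the decomposition Merge_v1_carry_check performs, written with
// plain integers so it can be checked by hand. For lower lengths {l0, l1, l2} the reverse
// exclusive scan is {l1*l2, l2, 1}; repeatedly dividing the upper index by the scan entries
// recovers the lower indices. E.g. lengths {2, 3, 4} and upper index 17 give (1, 1, 1), since
// 17 = 1*12 + 1*4 + 1.
namespace merge_example_detail {
__host__ __device__ inline void merge_decompose_3d(
    index_t idx_up, index_t l1, index_t l2, index_t& i0, index_t& i1, index_t& i2)
{
    i0 = idx_up / (l1 * l2); // divide by scan entry l1*l2
    idx_up -= i0 * (l1 * l2);
    i1 = idx_up / l2;        // divide by scan entry l2
    i2 = idx_up - i1 * l2;   // remainder (scan entry 1)
}
} // namespace merge_example_detail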
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_multiplier
{
    template <index_t I>
    __host__ __device__ constexpr auto operator()(Number<I> i) const
    {
        return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]);
    }
};

template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_shift
{
    template <index_t I>
    __host__ __device__ constexpr auto operator()(Number<I> i) const
    {
        return MagicDivision::CalculateMagicShift(LowLengths{}[i]);
    }
};
// Implementation of the "Merge" transformation primitive that uses magic-number division to lower
// both the multi-index and the delta of the multi-index.
// Caution:
//   1. The magic-number division implementation being used produces a correct result only if the
//      dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
//   2. Magic-number division for an int32_t dividend has not been implemented; an int32_t
//      dividend is bit-wise reinterpreted as uint32_t and the uint32_t implementation is used.
//   3. For the Merge primitive, the upper index is the dividend.
//   4. When the upper index is uint32_t, its value needs to be within the 31-bit range.
//   5. When the upper index is of type int32_t (when index_t is int32_t), its value needs to be
//      non-negative.
template <typename LowLengths>
struct Merge_v2_magic_division
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
        Number<NDimLow>{}));

    using LowLengthsMagicDivisorShift = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{},
        Number<NDimLow>{}));

    LowLengths low_lengths_;
    LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_;
    LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Merge_v2_magic_division() = default;

    __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_magic_divisor_multiplier_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
              Number<NDimLow>{})},
          low_lengths_magic_divisor_shift_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); },
              Number<NDimLow>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[Number<0>{}];

        static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
            index_t tmp2 =
                MagicDivision::DoMagicDivision(tmp,
                                               this->low_lengths_magic_divisor_multiplier_[i],
                                               this->low_lengths_magic_divisor_shift_[i]);
            idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
            tmp        = tmp2;
        });

        idx_low(Number<0>{}) = tmp;
    }
    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff&,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up_new[Number<0>{}];

        static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
            index_t tmp2 =
                MagicDivision::DoMagicDivision(tmp,
                                               this->low_lengths_magic_divisor_multiplier_[i],
                                               this->low_lengths_magic_divisor_shift_[i]);

            index_t idx_low_old = idx_low[i];

            idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
            tmp        = tmp2;

            idx_diff_low(i) = idx_low[i] - idx_low_old;
        });

        idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{});

        idx_low(Number<0>{}) = tmp;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowLengths>::value &&
               is_known_at_compile_time<LowLengthsMagicDivisorMultipiler>::value &&
               is_known_at_compile_time<LowLengthsMagicDivisorShift>::value &&
               is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Merge_v2_magic_division, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_magic_divisor_multiplier_ ");
        print_multi_index(low_lengths_magic_divisor_multiplier_);
        printf("low_lengths_magic_divisor_shift_ ");
        print_multi_index(low_lengths_magic_divisor_shift_);
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
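
// Illustrative sketch (not part of the original header; the constants and names below are generic
// magic-number-division values chosen for the example, not necessarily what
// MagicDivision::CalculateMagicMultiplier / CalculateMagicShift produce): integer division n / d
// can be replaced by a multiply and a shift, (n * multiplier) >> shift, when multiplier is a
// sufficiently precise fixed-point approximation of 2^shift / d. For example, dividing a 32-bit
// unsigned value by 3 works with multiplier 0xAAAAAAAB and shift 33.
namespace magic_division_example_detail {
__host__ __device__ inline unsigned int divide_by_3_via_magic_number(unsigned int n)
{
    const unsigned long long multiplier = 0xAAAAAAABull; // ~= 2^33 / 3, rounded up
    return static_cast<unsigned int>((n * multiplier) >> 33);
}
} // namespace magic_division_example_detail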
// Implementation of the "Merge" transformation primitive that uses magic-number division to lower
// both the multi-index and the delta of the multi-index.
// Caution:
//   1. The magic-number division implementation being used produces a correct result only if the
//      dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
//   2. Magic-number division for an int32_t dividend has not been implemented; an int32_t
//      dividend is bit-wise reinterpreted as uint32_t and the uint32_t implementation is used.
//   3. For the Merge primitive, the upper index is the dividend.
//   4. When the upper index is uint32_t, its value needs to be within the 31-bit range.
//   5. When the upper index is of type int32_t (when index_t is int32_t), its value needs to be
//      non-negative.
template <typename LowLengths>
struct Merge_v2r2_magic_division
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using LowLengthsScan =
        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{},
        Number<NDimLow>{}));

    using LowLengthsScanMagicDivisorShift = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengthsScan>{},
        Number<NDimLow>{}));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
    LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_;
    LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Merge_v2r2_magic_division() = default;

    __host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_scan_{
              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
          low_lengths_scan_magic_divisor_multiplier_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); },
              Number<NDimLow>{})},
          low_lengths_scan_magic_divisor_shift_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); },
              Number<NDimLow>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
            idx_low(i) =
                MagicDivision::DoMagicDivision(tmp,
                                               this->low_lengths_scan_magic_divisor_multiplier_[i],
                                               this->low_lengths_scan_magic_divisor_shift_[i]);

            tmp -= idx_low[i] * this->low_lengths_scan_[i];
        });

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }
    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff&,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up_new[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
            index_t idx_low_old = idx_low[i];

            idx_low(i) =
                MagicDivision::DoMagicDivision(tmp,
                                               this->low_lengths_scan_magic_divisor_multiplier_[i],
                                               this->low_lengths_scan_magic_divisor_shift_[i]);

            idx_diff_low(i) = idx_low[i] - idx_low_old;

            tmp -= idx_low[i] * this->low_lengths_scan_[i];
        });

        idx_diff_low(Number<NDimLow - 1>{}) = tmp - idx_low[Number<NDimLow - 1>{}];

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowLengths>::value &&
               is_known_at_compile_time<LowLengthsScanMagicDivisorMultipiler>::value &&
               is_known_at_compile_time<LowLengthsScanMagicDivisorShift>::value &&
               is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Merge_v2r2_magic_division, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_scan ");
        print_multi_index(low_lengths_scan_);
        printf("low_lengths_scan_magic_divisor_multiplier_ ");
        print_multi_index(low_lengths_scan_magic_divisor_multiplier_);
        printf("low_lengths_scan_magic_divisor_shift_ ");
        print_multi_index(low_lengths_scan_magic_divisor_shift_);
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
// Implementation of the "Merge" transformation primitive that uses division and mod. It is
// supposed to be used for low_lengths that are known at compile time and are powers of 2;
// otherwise performance will be very bad.
template <typename LowLengths>
struct Merge_v3_division_mod
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using LowLengthsScan =
        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Merge_v3_division_mod() = default;

    __host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_scan_{
              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[Number<0>{}];

        // division and mod
        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_low(i) = tmp / this->low_lengths_scan_[i];
            tmp %= this->low_lengths_scan_[i];
        });

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }
    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff&,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0   = Number<0>{};
        constexpr auto INm1 = Number<NDimLow - 1>{};

        index_t tmp = idx_up_new[I0];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            const index_t tmp2 = idx_low[i];
            idx_low(i)         = tmp / this->low_lengths_scan_[i];
            idx_diff_low(i)    = idx_low[i] - tmp2;
            tmp %= this->low_lengths_scan_[i];
        });

        const index_t tmp2 = idx_low[INm1];
        idx_low(INm1)      = tmp;
        idx_diff_low(INm1) = idx_low[INm1] - tmp2;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowLengths>::value &&
               is_known_at_compile_time<LowLengthsScan>::value &&
               is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Merge_v3_direct_division_mod, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_scan_ ");
        print_multi_index(low_lengths_scan_);
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
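
// Illustrative sketch (not part of the original header; the function and namespace names below
// are made up for the example): when the lower lengths are compile-time powers of 2, the division
// and mod used by Merge_v3_division_mod reduce to shifts and masks, which is why the primitive is
// cheap only in that case. E.g. with a scan entry of 4, u / 4 == u >> 2 and u % 4 == u & 3 for
// non-negative u.
namespace merge_v3_example_detail {
__host__ __device__ inline void divmod_by_power_of_two(index_t u,
                                                       index_t log2_len,
                                                       index_t& quotient,
                                                       index_t& remainder)
{
    quotient  = u >> log2_len;             // u / (1 << log2_len)
    remainder = u & ((1 << log2_len) - 1); // u % (1 << log2_len)
}
} // namespace merge_v3_example_detail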
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct UnMerge
{
    static constexpr index_t NDimUp = UpLengths::Size();

    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<NDimUp>;

    using UpLengthsScan =
        decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number<1>{}));

    UpLengths up_lengths_;
    UpLengthsScan up_lengths_scan_;

    __host__ __device__ constexpr UnMerge() = default;

    __host__ __device__ constexpr UnMerge(const UpLengths& up_lengths)
        : up_lengths_{up_lengths},
          up_lengths_scan_{
              container_reverse_exclusive_scan(up_lengths, math::multiplies{}, Number<1>{})}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        if constexpr(!Use24BitIntegerCalculation)
        {
            idx_low(Number<0>{}) = idx_up[Number<NDimUp - 1>{}];

            static_for<0, NDimUp - 1, 1>{}(
                [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
        }
        else
        {
            idx_low(Number<0>{}) = idx_up[Number<NDimUp - 1>{}];

            static_for<0, NDimUp - 1, 1>{}([&](auto i) {
                idx_low(Number<0>{}) = (0x00ffffff & idx_low[Number<0>{}]) +
                                       (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
            });
        }
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
                                              const UpIdx&,
                                              Number<Hack>) const
    {
        CalculateLowerIndex(idx_diff_low, idx_diff_up);

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<UpLengthsScan>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("UnMerge, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("up_lengths_scan_");
        print_multi_index(up_lengths_scan_);
        printf("}");
    }
};
template <typename LowerIndex>
struct Freeze
{
    LowerIndex low_idx_;

    __host__ __device__ constexpr Freeze() = default;

    __host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& /* idx_up */) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = low_idx_;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& /* idx_diff_up */,
                                                     LowIdx& /* idx_low */,
                                                     const UpIdx& /* idx_up_new */,
                                                     Number<Hack>)
    {
        idx_diff_low(Number<0>{}) = 0;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowerIndex>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("Freeze");
        printf("low_idx_ %d", index_t{low_idx_});
    }
};
// Insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct Insert
{
    using UpLengths = decltype(make_tuple(UpperLength{}));

    UpLengths up_lengths_;

    __host__ __device__ constexpr Insert() = default;

    __host__ __device__ constexpr Insert(const UpperLength& up_length)
        : up_lengths_{make_tuple(up_length)}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const
    {
        static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void
    UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpperLength>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("Insert");
        print_multi_index(up_lengths_);
    }
};
template <typename VectorSize, typename UpLength>
struct Vectorize
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(UpLength{}));

    UpLengths up_lengths_;
    VectorSize vector_size_;

    __host__ __device__ constexpr Vectorize() = default;

    __host__ __device__ constexpr Vectorize(const VectorSize& vector_size,
                                            const UpLength& up_length)
        : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}];
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
                                              const UpIdx&,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Vectorize, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct Slice
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));

    UpLengths up_lengths_;
    SliceBegin slice_begin_;
    SliceEnd slice_end_;

    __host__ __device__ constexpr Slice() = default;

    __host__ __device__ constexpr Slice(const LowLength&,
                                        const SliceBegin& slice_begin,
                                        const SliceEnd& slice_end)
        : up_lengths_{make_tuple(slice_end - slice_begin)},
          slice_begin_{slice_begin},
          slice_end_{slice_end}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx, index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
                                                     const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<SliceBegin>::value &&
               is_known_at_compile_time<SliceEnd>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Slice, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("slice_begin_ %d", index_t{slice_begin_});
        printf("slice_end %d", index_t{slice_end_});
        printf("}");
    }
};
/*
 * \brief lower_idx = upper_idx % modulus.
 * TODO: Need an improved implementation since the modulo operation is expensive.
 */
template <typename Modulus, typename UpLength>
struct Modulo
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;
    using UpLengths  = decltype(make_tuple(UpLength{}));

    Modulus modulus_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Modulo() = default;

    __host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length)
        : modulus_{modulus}, up_lengths_{make_tuple(up_length)}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_;
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx, index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
                                              const UpIdx& up_idx,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        const auto idx_low_old = idx_low;

        idx_low(I0)      = (up_idx(I0) + idx_diff_up(I0)) % modulus_;
        idx_diff_low(I0) = idx_low - idx_low_old;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Modulus, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("}");
    }
};

} // namespace ck
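For orientation, here is a minimal host-side sketch of the index maps these transforms implement. The plain-integer helpers below are hypothetical stand-ins for the MultiIndex-based CalculateLowerIndex methods above, not part of the header:

#include <cassert>

// Hypothetical scalar versions of the index maps above (illustration only).
int slice_lower(int idx_up, int slice_begin) { return idx_up + slice_begin; }     // Slice
int vectorize_lower(int idx_up, int vector_size) { return vector_size * idx_up; } // Vectorize
int modulo_lower(int idx_up, int modulus) { return idx_up % modulus; }            // Modulo

int main()
{
    assert(slice_lower(3, 4) == 7);      // element 3 of a slice starting at 4 sits at lower index 7
    assert(vectorize_lower(2, 8) == 16); // vector index 2 of 8-wide vectors starts at scalar 16
    assert(modulo_lower(10, 4) == 2);    // a wrapped upper index of 10 maps to lower index 2
    return 0;
}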
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/multi_index_transform_helper.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform.hpp"
namespace ck {
template <typename LowLength>
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
{
    return PassThrough<LowLength>{low_length};
}

template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto
make_pad_transform(const LowLength& low_length,
                   const LeftPad& left_pad,
                   const RightPad& right_pad,
                   integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
    return Pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
}

template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto
make_left_pad_transform(const LowLength& low_length,
                        const LeftPadLength& left_pad,
                        integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
    return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
}

template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto
make_right_pad_transform(const LowLength& low_length,
                         const RightPadLength& right_pad,
                         integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
    return RightPad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
}

template <typename UpLengths,
          typename Coefficients,
          typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
                                                        const Coefficients& coefficients)
{
    return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
}

template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
    return make_merge_transform_v2_magic_division(low_lengths);
#else
    return make_merge_transform_v1_carry_check(low_lengths);
#endif
}

template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
{
    return Merge_v1_carry_check<LowLengths>{low_lengths};
}

template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
#if 1
    return Merge_v2_magic_division<LowLengths>{low_lengths};
#else
    return Merge_v2r2_magic_division<LowLengths>{low_lengths};
#endif
}

template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
{
    return Merge_v3_division_mod<LowLengths>{low_lengths};
}

template <typename UpLengths, bool Use24BitIntegerCalculation = false>
__host__ __device__ constexpr auto
make_unmerge_transform(const UpLengths& up_lengths,
                       integral_constant<bool, Use24BitIntegerCalculation> =
                           integral_constant<bool, false>{})
{
    return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}

template <typename LowerIndex>
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{
    return Freeze<LowerIndex>{low_idx};
}

template <typename UpperIndex>
__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx)
{
    return Insert<UpperIndex>{up_idx};
}

template <typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length,
                                                        const SliceBegin& slice_begin,
                                                        const SliceEnd& slice_end)
{
    return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
}

template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
                                                            const UpLength& up_length)
{
    return Vectorize<VectorSize, UpLength>{vector_size, up_length};
}

template <typename Modulus, typename UpLength>
__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
                                                         const UpLength& up_length)
{
    return Modulo<Modulus, UpLength>{modulus, up_length};
}

} // namespace ck
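A minimal usage sketch of these factory functions, assuming the ck headers above are on the include path; the lengths and the wrapper function name are made up for illustration. Note that make_merge_transform only dispatches to the magic-division variant when CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION is set:

#include "ck/tensor_description/multi_index_transform_helper.hpp"

void build_some_transforms() // hypothetical wrapper, illustration only
{
    // Compile-time lengths (Number<>) keep the resulting transforms constexpr-friendly.
    const auto pass  = ck::make_pass_through_transform(ck::Number<16>{});
    const auto pad   = ck::make_right_pad_transform(ck::Number<30>{}, ck::Number<2>{});
    const auto merge = ck::make_merge_transform(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}));

    (void)pass;
    (void)pad;
    (void)merge;
}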
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/tensor_adaptor.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
namespace ck {
// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename BottomDimensionHiddenIds,
          typename TopDimensionHiddenIds>
struct TensorAdaptor
{
    __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }

    __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }

    __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss()
    {
        return LowerDimensionHiddenIdss{};
    }

    __host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss()
    {
        return UpperDimensionHiddenIdss{};
    }

    __host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
    {
        return TopDimensionHiddenIds{};
    }

    __host__ __device__ static constexpr auto GetBottomDimensionHiddenIds()
    {
        return BottomDimensionHiddenIds{};
    }

    __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
    {
        const auto lengths = generate_tuple(
            [&](auto idim_top) {
                constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top);

                constexpr index_t itran   = tmp[Number<0>{}];
                constexpr index_t idim_up = tmp[Number<1>{}];
                constexpr bool found      = tmp[Number<2>{}];

                static_assert(found == true,
                              "wrong! not found matching transformation and upper-dimension");

                const auto length =
                    transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];

                return length;
            },
            Number<ndim_top_>{});

        // TODO: make container_reduce support tuple of Number and index_t
        return container_reduce(lengths, math::multiplies{}, Number<1>{});
    }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
    {
        constexpr auto idim_top = Number<IDim>{};

        constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);

        index_t itran_found   = 0;
        index_t idim_up_found = 0;
        bool found            = false;

        static_for<0, ntransform_, 1>{}([&](auto itran) {
            constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];

            static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
                if constexpr(up_dim_ids[idim_up] == idim_hidden)
                {
                    itran_found   = itran;
                    idim_up_found = idim_up;
                    found         = true;
                }
            });
        });

        return make_tuple(itran_found, idim_up_found, found);
    }

    __host__ __device__ static constexpr index_t GetNumOfBottomDimension()
    {
        return BottomDimensionHiddenIds::Size();
    }

    __host__ __device__ static constexpr index_t GetNumOfTopDimension()
    {
        return TopDimensionHiddenIds::Size();
    }

    __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            LowerDimensionHiddenIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            UpperDimensionHiddenIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

        using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
                                                                      math::less<index_t>,
                                                                      math::equal<index_t>>::type;

        return unique_sort_all_dim_ids::Size();
    }

    constexpr static index_t ntransform_  = GetNumOfTransform();
    constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension();
    constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension();
    constexpr static index_t ndim_top_    = GetNumOfTopDimension();

    using HiddenIndex = MultiIndex<ndim_hidden_>;
    using BottomIndex = MultiIndex<ndim_bottom_>;
    using TopIndex    = MultiIndex<ndim_top_>;

    // may be index_t or Number<>
    using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;

    public:
#if 0 // workaround compiler complaint about constexpr
    __host__ __device__ constexpr TensorAdaptor() = default;
#else
    __host__ __device__ constexpr TensorAdaptor() : transforms_{}, element_size_{} {}
#endif

    __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms)
        : transforms_{transforms}, element_size_{InitializeElementSize(transforms)}
    {
        static_assert(Transforms::Size() == ntransform_ &&
                          LowerDimensionHiddenIdss::Size() == ntransform_ &&
                          UpperDimensionHiddenIdss::Size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    __host__ __device__ constexpr auto GetElementSize() const { return element_size_; }

#if 0 // debug
    template <index_t I>
    __host__ __device__ constexpr index_t GetTopDimensionLength(Number<I> idim) const
    {
        // TODO: not implemented
    }

    template <index_t I>
    __host__ __device__ constexpr index_t GetBottomDimensionLength(Number<I> idim) const
    {
        // TODO: not implemented
    }
#endif

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(),
                      "wrong! # of dimension inconsistent");

        constexpr index_t ntransform  = GetNumOfTransform();
        constexpr index_t ndim_hidden = GetNumOfHiddenDimension();

        MultiIndex<ndim_hidden> idx_hidden;

        // initialize topmost index
        set_container_subset(idx_hidden, GetTopDimensionHiddenIds(), idx_top);

        // calculate hidden index
        static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
            auto itran              = itran_p1 - Number<1>{};
            const auto& tran        = GetTransforms().At(itran);
            constexpr auto dims_low = GetLowerDimensionHiddenIdss().At(itran);
            constexpr auto dims_up  = GetUpperDimensionHiddenIdss().At(itran);

            const auto idx_up = get_container_subset(idx_hidden, dims_up);

            MultiIndex<dims_low.Size()> idx_low;

            tran.CalculateLowerIndex(idx_low, idx_up);

            set_container_subset(idx_hidden, dims_low, idx_low);
        });

        return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        bool is_known = true;

        static_for<0, Transforms::Size(), 1>{}([&](auto i) {
            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
        });

        return is_known && is_known_at_compile_time<ElementSize>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("TensorAdaptor, ");
        static_for<0, ntransform_, 1>{}([&](auto i) {
            printf("transforms: ");
            transforms_[i].Print();
            printf("LowerDimensionHiddenIds:");
            LowerDimensionHiddenIdss{}.At(i).Print();
            printf("UpperDimensionHiddenIds:");
            UpperDimensionHiddenIdss{}.At(i).Print();
        });

        printf("BottomDimensionHiddenIds:");
        BottomDimensionHiddenIds::Print();
        printf("TopDimensionHiddenIds:");
        TopDimensionHiddenIds::Print();

        printf("}");
    }

    private:
    Transforms transforms_;
    ElementSize element_size_;
};
template <typename TensorAdaptor0, typename TensorAdaptor1>
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
                                                         const TensorAdaptor1& adaptor1)
{
    static_assert(TensorAdaptor0::GetNumOfTopDimension() ==
                      TensorAdaptor1::GetNumOfBottomDimension(),
                  "wrong!");

    // all_transforms = transform0 + transform1
    const auto all_transforms =
        container_concat(adaptor0.GetTransforms(), adaptor1.GetTransforms());

    // shift
    constexpr index_t adaptor0_max_hidden_id = [&]() {
        index_t adaptor0_max_hidden_id_ = NumericLimits<index_t>::Min();

        static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension();

            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                adaptor0_max_hidden_id_ =
                    math::max(adaptor0_max_hidden_id_,
                              TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
            });

            constexpr index_t ndim_up =
                TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension();

            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor0_max_hidden_id_ =
                    math::max(adaptor0_max_hidden_id_,
                              TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
            });
        });

        return adaptor0_max_hidden_id_;
    }();

    constexpr index_t adaptor1_min_hidden_id = [&]() {
        index_t adaptor1_min_hidden_id_ = NumericLimits<index_t>::Max();

        static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();

            // get the min of all lower dimensions, but not bottom dimension (because their id will
            // be matched with top id from adaptor0)
            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                constexpr index_t low_dim_hidden_id =
                    TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value;

                bool is_bottom_dim = false;
                static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) {
                    if constexpr(low_dim_hidden_id ==
                                 TensorAdaptor1::GetBottomDimensionHiddenIds()[i])
                    {
                        is_bottom_dim = true;
                    }
                });

                if(!is_bottom_dim)
                {
                    adaptor1_min_hidden_id_ =
                        math::min(adaptor1_min_hidden_id_, low_dim_hidden_id);
                }
            });

            constexpr index_t ndim_up =
                TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();

            // get the min of all upper dimensions
            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor1_min_hidden_id_ =
                    math::min(adaptor1_min_hidden_id_,
                              TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
            });
        });

        return adaptor1_min_hidden_id_;
    }();

    constexpr index_t adaptor1_hidden_id_shift =
        adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;

    constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();

    // all_low_dim_hidden_idss =
    // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hidden_idss_1))
    constexpr auto low_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_low_1 =
                TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size();

            constexpr auto low_dim_hidden_ids_1 =
                TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];

            // sequence in, sequence out
            constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
                auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);

                // shift hidden id so every dim id is unique
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
                });

                // match hidden id
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
                        // if this low dim is bottom dim, then do id matching
                        if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
                                     TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1])
                        {
                            low_dim_hidden_ids_1_mod_(idim_low_1) =
                                TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1];
                        }
                    });
                });

                return low_dim_hidden_ids_1_mod_;
            }();

            return generate_sequence_v2(
                [&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },
                Number<ndim_low_1>{});
        },
        Number<TensorAdaptor1::GetNumOfTransform()>{});

    constexpr auto all_low_dim_hidden_idss =
        container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1);

    // all_up_dim_hidden_idss =
    // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hidden_idss_1)
    constexpr auto up_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_up_1 =
                TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size();

            constexpr auto up_dim_hidden_ids_1 =
                TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];

            // sequence in, constexpr tuple out
            constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
                auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);

                // shift hidden id
                static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
                    up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
                });

                return up_dim_hidden_ids_1_mod_;
            }();

            // constexpr tuple to sequence
            return generate_sequence_v2(
                [&](auto i) constexpr { return Number<up_dim_hidden_ids_1_mod[i]>{}; },
                Number<ndim_up_1>{});
        },
        Number<TensorAdaptor1::GetNumOfTransform()>{});

    constexpr auto all_up_dim_hidden_idss =
        container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1);

    // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
    constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds();

    // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
    constexpr auto top_dim_hidden_ids =
        TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};

    // put everything together
    return TensorAdaptor<remove_cv_t<decltype(all_transforms)>,
                         remove_cv_t<decltype(all_low_dim_hidden_idss)>,
                         remove_cv_t<decltype(all_up_dim_hidden_idss)>,
                         remove_cv_t<decltype(bottom_dim_hidden_ids)>,
                         remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
                                                                    LowerDimensionOldTopIdss,
                                                                    UpperDimensionNewTopIdss)
{
    constexpr index_t ntransform = Transforms::Size();

    static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
                      UpperDimensionNewTopIdss::Size() == ntransform,
                  "wrong!");

    // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
    constexpr auto all_low_dim_old_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});

    constexpr auto all_up_dim_new_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});

    static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
                      is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
                  "wrong!");

    constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
    constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();

    // low_dim_hidden_idss
    constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};

    // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
    constexpr auto up_dim_hidden_idss = generate_tuple(
        [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number<ndim_old_top>{}; },
        Number<ntransform>{});

    // bottom_dim_hidden_ids
    constexpr auto bottom_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};

    // top_dim_hidden_ids
    constexpr auto top_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};

    return TensorAdaptor<remove_cv_t<Transforms>,
                         remove_cv_t<decltype(low_dim_hidden_idss)>,
                         remove_cv_t<decltype(up_dim_hidden_idss)>,
                         remove_cv_t<decltype(bottom_dim_hidden_ids)>,
                         remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
}
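Illustrative note (not part of this header): a hedged sketch of how a single-stage adaptor is typically built and queried, mirroring the way the space-filling-curve code later in this commit uses a merge transform; the concrete lengths are made up.

// One top dimension of length 32 is mapped onto a (4, 8) bottom index space.
const auto adaptor = ck::make_single_stage_tensor_adaptor(
    ck::make_tuple(ck::make_merge_transform(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}))),
    ck::make_tuple(ck::Sequence<0, 1>{}), // lower (old top / bottom) dimension ids
    ck::make_tuple(ck::Sequence<0>{}));   // upper (new top) dimension ids

const auto idx_bottom = adaptor.CalculateBottomIndex(ck::make_multi_index(13));
// idx_bottom is expected to be (1, 5), since 13 == 1 * 8 + 5.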
template <typename X,
          typename... Xs,
          typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
    return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}

} // namespace ck
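Chaining composes two adaptors so that the top of the second feeds the bottom of the first. A minimal sketch under the same assumptions as above (illustrative lengths, ck headers available):

// adaptor0: one top dimension of length 32 -> bottom (4, 8), via a merge transform.
const auto adaptor0 = ck::make_single_stage_tensor_adaptor(
    ck::make_tuple(ck::make_merge_transform(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}))),
    ck::make_tuple(ck::Sequence<0, 1>{}),
    ck::make_tuple(ck::Sequence<0>{}));

// adaptor1: top (2, 16) -> one bottom dimension of length 32, via an unmerge transform.
const auto adaptor1 = ck::make_single_stage_tensor_adaptor(
    ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(ck::Number<2>{}, ck::Number<16>{}))),
    ck::make_tuple(ck::Sequence<0>{}),
    ck::make_tuple(ck::Sequence<0, 1>{}));

// The chained adaptor maps a (2, 16) top index all the way down to a (4, 8) bottom index.
const auto chained = ck::chain_tensor_adaptors(adaptor0, adaptor1);
const auto bottom  = chained.CalculateBottomIndex(ck::make_multi_index(1, 5));
// bottom is expected to be (2, 5): (1, 5) -> 1 * 16 + 5 = 21 -> (21 / 8, 21 % 8).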
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/tensor_descriptor.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/tensor_description/multi_index_transform.hpp"
namespace ck {
template <index_t NDimHidden, typename VisibleDimensionIds>
struct TensorCoordinate;

template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct TensorCoordinateStep;
// Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
// UpperDimensionIdss : Tuple<Sequence<...>, ...>
// VisibleDimensionIds : Sequence<...>
template <typename Transforms,
          typename LowerDimensionIdss,
          typename UpperDimensionIdss,
          typename VisibleDimensionIds,
          typename ElementSpaceSize>
struct TensorDescriptor
{
    // TODO make these private
    __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }

    __host__ __device__ static constexpr index_t GetNumOfVisibleDimension()
    {
        return VisibleDimensionIds::Size();
    }

    __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

        using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
                                                                      math::less<index_t>,
                                                                      math::equal<index_t>>::type;

        return unique_sort_all_dim_ids::Size();
    }

    __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
    {
        const auto lengths = generate_tuple(
            [&](auto idim_visible) {
                constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible);

                constexpr index_t itran   = tmp[Number<0>{}];
                constexpr index_t idim_up = tmp[Number<1>{}];
                constexpr bool found      = tmp[Number<2>{}];

                static_assert(found == true,
                              "wrong! not found matching transformation and upper-dimension");

                const auto length =
                    transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];

                return length;
            },
            Number<ndim_visible_>{});

        // TODO: make container_reduce support tuple of Number and index_t
        return container_reduce(lengths, math::multiplies{}, Number<1>{});
    }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
    {
        constexpr auto idim_visible = Number<IDim>{};

        constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible);

        index_t itran_found   = 0;
        index_t idim_up_found = 0;
        bool found            = false;

        static_for<0, ntransform_, 1>{}([&](auto itran) {
            constexpr auto up_dim_ids = UpperDimensionIdss{}[itran];

            static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
                if constexpr(up_dim_ids[idim_up] == idim_hidden)
                {
                    itran_found   = itran;
                    idim_up_found = idim_up;
                    found         = true;
                }
            });
        });

        return make_tuple(itran_found, idim_up_found, found);
    }

    constexpr static index_t ntransform_   = GetNumOfTransform();
    constexpr static index_t ndim_visible_ = GetNumOfVisibleDimension();
    constexpr static index_t ndim_hidden_  = GetNumOfHiddenDimension();

    using VisibleIndex = MultiIndex<ndim_visible_>;
    using HiddenIndex  = MultiIndex<ndim_hidden_>;
    using Coordinate   = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;

    // may be index_t or Number<>
    using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;

    public:
#if 0 // workaround compiler complaint about constexpr
    __host__ __device__ constexpr TensorDescriptor() = default;
#else
    __host__ __device__ constexpr TensorDescriptor()
        : transforms_{}, element_size_{}, element_space_size_{}
    {
    }
#endif

    __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
                                                   ElementSpaceSize element_space_size)
        : transforms_{transforms},
          element_size_{InitializeElementSize(transforms)},
          element_space_size_{element_space_size}
    {
        static_assert(Transforms::Size() == ntransform_ &&
                          LowerDimensionIdss::Size() == ntransform_ &&
                          UpperDimensionIdss::Size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    __host__ __device__ static constexpr index_t GetNumOfDimension()
    {
        return GetNumOfVisibleDimension();
    }

    template <index_t IDim>
    __host__ __device__ constexpr auto GetLength(Number<IDim>) const
    {
        static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range");

        constexpr auto tmp = GetTransformAndItsUpperDimension(Number<IDim>{});

        constexpr index_t itran   = tmp[Number<0>{}];
        constexpr index_t idim_up = tmp[Number<1>{}];
        constexpr bool found      = tmp[Number<2>{}];

        static_assert(found == true,
                      "wrong! not found matching transformation and upper-dimension");

        return transforms_[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
    }

    __host__ __device__ constexpr auto GetLengths() const
    {
        // FIXME: use Tuple of reference instead
        return generate_sequence_v2([&](auto I) { return GetLength(I); },
                                    Number<ndim_visible_>{});
    }

    __host__ __device__ constexpr auto GetElementSize() const { return element_size_; }

    __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }

    template <typename Idx>
    __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const
    {
        static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");

        return make_tensor_coordinate(*this, idx).GetOffset();
    }

    // TODO make these private
    __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }

    __host__ __device__ static constexpr auto GetLowerDimensionIdss()
    {
        return LowerDimensionIdss{};
    }

    __host__ __device__ static constexpr auto GetUpperDimensionIdss()
    {
        return UpperDimensionIdss{};
    }

    __host__ __device__ static constexpr auto GetVisibleDimensionIds()
    {
        return VisibleDimensionIds{};
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        bool is_known = true;

        static_for<0, Transforms::Size(), 1>{}([&](auto i) {
            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
        });

        return is_known && is_known_at_compile_time<ElementSize>::value &&
               is_known_at_compile_time<ElementSpaceSize>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("TensorDescriptor, ");
        static_for<0, ntransform_, 1>{}([&](auto i) {
            printf("transforms: ");
            transforms_[i].Print();
            printf("LowerDimensionIds:");
            LowerDimensionIdss{}.At(i).Print();
            printf("UpperDimensionIds:");
            UpperDimensionIdss{}.At(i).Print();
        });
        printf("}");

        VisibleDimensionIds::Print();
    }

    // TODO make these private
    Transforms transforms_;
    ElementSize element_size_;
    ElementSpaceSize element_space_size_;
};
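Illustrative note (not part of this header): a hedged sketch of how a descriptor is usually created and queried, assuming the naive-descriptor helpers from tensor_descriptor_helper.hpp added later in this commit; the 4x6 row-major layout is made up.

// 4x6 row-major matrix: lengths (4, 6), strides (6, 1).
const auto desc = ck::make_naive_tensor_descriptor(ck::make_tuple(4, 6), ck::make_tuple(6, 1));

const auto len0   = desc.GetLength(ck::Number<0>{});                  // 4
const auto offset = desc.CalculateOffset(ck::make_multi_index(2, 3)); // 2 * 6 + 3 = 15
const auto space  = desc.GetElementSpaceSize();                       // (4-1)*6 + (6-1)*1 + 1 = 24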
template <index_t NDimHidden, typename VisibleDimensionIds>
struct TensorCoordinate
{
    // TODO make these private
    static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();

    using HiddenIndex  = MultiIndex<NDimHidden>;
    using VisibleIndex = MultiIndex<ndim_visible_>;

    public:
    __host__ __device__ constexpr TensorCoordinate() = default;

    __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
        : idx_hidden_{idx_hidden}
    {
    }

    __host__ __device__ constexpr auto GetIndex() const { return GetVisibleIndex(); }

    __host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; }

    // TODO make these private
    __host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; }

    __host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; }

    __host__ __device__ constexpr auto GetVisibleIndex() const
    {
        return get_container_subset(idx_hidden_, VisibleDimensionIds{});
    }

    // TODO make these private
    HiddenIndex idx_hidden_;
};
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct TensorCoordinateStep
{
    // TODO make these private
    using VisibleIndex = MultiIndex<NDimVisible>;

    public:
    __host__ __device__ constexpr TensorCoordinateStep() = default;

    __host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
                                                       const MultiIndex<NTransform>& do_transforms)
        : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
    {
    }

    __host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); }

    // TODO make these private
    __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const
    {
        return idx_diff_visible_;
    }

    VisibleIndex idx_diff_visible_;
    MultiIndex<NTransform> do_transforms_;

    // HACK: control UpdateLowerIndex()
    static constexpr UpdateLowerIndexHack update_lower_index_hack_;
};
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda does not have a
// constructor. It is placed outside the scope where it is used (transform_tensor_descriptor)
// because a template cannot be defined inside a function template.
template <typename NewTransforms>
struct lambda_get_up_dim_num
{
    template <typename I>
    __host__ __device__ constexpr auto operator()(I) const
    {
        using Tran = remove_reference_t<decltype(NewTransforms{}.At(I{}))>;
        return Number<Tran::GetNumOfUpperDimension()>{};
    }
};
template <typename OldTensorDescriptor,
          typename NewTransforms,
          typename NewLowerDimensionOldVisibleIdss,
          typename NewUpperDimensionNewVisibleIdss>
__host__ __device__ constexpr auto
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
                            const NewTransforms& new_transforms,
                            NewLowerDimensionOldVisibleIdss,
                            NewUpperDimensionNewVisibleIdss)
{
    // sanity check
    {
        static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() &&
                          NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
                      "wrong! inconsistent number of transform");

        constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewLowerDimensionOldVisibleIdss{});

        constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewUpperDimensionNewVisibleIdss{});

        static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
                          is_valid_sequence_map<decltype(all_new_top_ids)>::value,
                      "wrong!");
    }

    // lower dimension's hidden idss
    // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
    // sequences)
    constexpr auto low_dim_hidden_idss = transform_tuples(
        // convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
        [](auto low_dim_visible_ids) constexpr {
            return transform_sequences(
                // convert lower dimension visible id to hidden id
                [](auto low_dim_visible_id) constexpr {
                    return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
                },
                low_dim_visible_ids);
        },
        NewLowerDimensionOldVisibleIdss{});

    constexpr index_t num_new_transform = NewTransforms::Size();

    // upper dimension's hidden idss
    constexpr index_t old_hidden_dim_number = OldTensorDescriptor::GetNumOfHiddenDimension();

    constexpr auto up_dim_numbers =
        generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, Number<num_new_transform>{});

    constexpr auto up_dim_numbers_scan = merge_sequences(
        Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));

    constexpr auto up_dim_hidden_idss = generate_tuple(
        [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
            return typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
                                                    old_hidden_dim_number +
                                                        up_dim_numbers_scan[i + 1],
                                                    1>::type{};
        },
        Number<num_new_transform>{});

    // new visible dimension's hidden ids
    constexpr auto unordered_new_visible_dim_hidden_ids = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);

    constexpr auto new_visible_dim_unordered2ordered = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); },
        NewUpperDimensionNewVisibleIdss{});

    constexpr auto new_visible_dim_hidden_ids =
        unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);

    // put everything together
    const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms);

    constexpr auto all_low_dim_hidden_idss =
        container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);

    constexpr auto all_up_dim_hidden_idss =
        container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);

    const auto element_space_size = old_tensor_desc.GetElementSpaceSize();

    return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
                            remove_cv_t<decltype(all_low_dim_hidden_idss)>,
                            remove_cv_t<decltype(all_up_dim_hidden_idss)>,
                            remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
                            remove_cv_t<decltype(element_space_size)>>{all_transforms,
                                                                       element_space_size};
}
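Illustrative note (not part of this header): a hedged sketch of transform_tensor_descriptor, again assuming the naive-descriptor helpers from the next header in this commit; it pads the second dimension of a 4x6 descriptor out to 8.

const auto desc_4x6 = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(4, 6));

const auto desc_4x8 = ck::transform_tensor_descriptor(
    desc_4x6,
    ck::make_tuple(ck::make_pass_through_transform(4), ck::make_right_pad_transform(6, 2)),
    ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),  // old (visible) dims each transform reads
    ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{})); // new (visible) dims each transform writes
// desc_4x8 reports lengths (4, 8); indices that land in the padded tail map outside the original
// 4x6 element space, which coordinate_has_valid_offset() below is meant to detect.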
template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
                                                          const VisibleIndex& idx_visible)
{
    static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
                  "wrong! # of dimension inconsistent");

    constexpr index_t ntransform   = TensorDesc::GetNumOfTransform();
    constexpr index_t ndim_hidden  = TensorDesc::GetNumOfHiddenDimension();
    constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();

    MultiIndex<ndim_hidden> idx_hidden;

    // initialize visible index
    set_container_subset(idx_hidden, visible_dim_ids, idx_visible);

    // calculate hidden index
    static_for<ntransform, 0, -1>{}([&tensor_desc, &idx_hidden](auto itran_p1) {
        auto itran              = itran_p1 - Number<1>{};
        const auto& tran        = tensor_desc.GetTransforms().At(itran);
        constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
        constexpr auto dims_up  = TensorDesc::GetUpperDimensionIdss().At(itran);

        const auto idx_up = get_container_subset(idx_hidden, dims_up);

        MultiIndex<dims_low.Size()> idx_low;

        tran.CalculateLowerIndex(idx_low, idx_up);

        set_container_subset(idx_hidden, dims_low, idx_low);
    });

    return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
}
// UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
                                                               const VisibleIndex& idx_diff_visible,
                                                               UpdateLowerIndexHack)
{
    static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
                  "wrong! # of dimension inconsistent");

    constexpr index_t ntransform   = TensorDesc::GetNumOfTransform();
    constexpr index_t ndim_hidden  = TensorDesc::GetNumOfHiddenDimension();
    constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
    constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();

    static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!");

    // use index_t for boolean type
    auto do_transforms    = make_zero_multi_index<ntransform>();
    auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();

    // decide do_transform by checking for non-zero index diff components
    MultiIndex<VisibleIndex::Size()> non_zero_diff_pick_visible;

    static_for<0, ndim_visible, 1>{}(
        [&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); });

    set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible);

    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
        constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
        constexpr auto dims_up  = TensorDesc::GetUpperDimensionIdss().At(itran);

        const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);

        MultiIndex<dims_low.Size()> non_zero_diff_pick_low;

        // if any of upper index diff components is non-zero, then
        // 1) Need to do this transform
        // 2) all components of lower index diff will assume to be non-zero and need to be
        // computed
        const bool idx_diff_up_has_non_zero = container_reduce(
            non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);

        do_transforms(itran) = idx_diff_up_has_non_zero;

        static_for<0, dims_low.Size(), 1>{}(
            [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });

        set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
    });

    return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
                                                                                do_transforms};
}

template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
                                                               const VisibleIndex& idx_diff_visible)
{
    constexpr index_t ntransform = TensorDesc::GetNumOfTransform();

    return make_tensor_coordinate_step(
        TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
}
template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
                                                          TensorCoord& coord,
                                                          const TensorCoordStep& coord_step)
{
    constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
    constexpr index_t ntransform  = TensorDesc::GetNumOfTransform();

    // this is what needs to be calculated
    auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();

    // initialize visible index diff
    set_container_subset(
        idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());

    // this is what needs to be updated
    auto& idx_hidden = coord.GetHiddenIndex();

    // update visible index
    auto idx_hidden_pick_visible =
        get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());

    idx_hidden_pick_visible += coord_step.GetIndexDiff();

    set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);

    // update rest of hidden index
    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
        if(coord_step.do_transforms_[itran])
        {
            const auto& tran        = tensor_desc.GetTransforms().At(itran);
            constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
            constexpr auto dims_up  = TensorDesc::GetUpperDimensionIdss().At(itran);

            const auto idx_up_new  = get_container_subset(idx_hidden, dims_up);
            auto idx_low           = get_container_subset(idx_hidden, dims_low);
            const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);

            MultiIndex<dims_low.Size()> idx_diff_low;

            // HACK: control UpdateLowerIndex for Merge using hack
            constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);

            tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});

            set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
            set_container_subset(idx_hidden, dims_low, idx_low);
        }
    });
}
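Illustrative note (not part of this header): a hedged sketch of the coordinate workflow defined above (make a coordinate, make a step, move the coordinate), using a made-up packed 4x6 descriptor; the wrapper function name is hypothetical.

void coordinate_walkthrough() // illustration only
{
    const auto desc = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(4, 6));

    auto coord      = ck::make_tensor_coordinate(desc, ck::make_multi_index(1, 2)); // offset 1*6+2 = 8
    const auto step = ck::make_tensor_coordinate_step(desc, ck::make_multi_index(0, 3));

    ck::move_tensor_coordinate(desc, coord, step);
    // coord now refers to element (1, 5); coord.GetOffset() is expected to be 11.
}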
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(
    const TensorDesc& tensor_desc, const TensorCoord& coord)
{
    bool valid = true;

    constexpr index_t ntransform = TensorDesc::GetNumOfTransform();

    const auto& idx_hidden = coord.GetHiddenIndex();

    static_for<ntransform - 1, -1, -1>{}([&tensor_desc, &idx_hidden, &valid](auto itran) {
        const auto tran = tensor_desc.GetTransforms().At(itran);

        // check validity, only if the current transformation does not always have a valid mapping
        if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex())
        {
            const auto idx_up =
                get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran));

            // Comment: using valid = valid && .. will result in weird control flow in ISA
            valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up);
        }
    });

    return valid;
}
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
                                                               const TensorCoord& coord)
{
    // check visible index
    const auto& idx_visible = coord.GetVisibleIndex();

    bool is_visible_index_valid = true;

    static_for<0, TensorDesc::GetNumOfDimension(), 1>{}(
        [&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) {
            is_visible_index_valid =
                is_visible_index_valid &&
                (idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i));
        });

    // check other hidden index
    return is_visible_index_valid &&
           coordinate_has_valid_offset_assuming_visible_index_is_valid(tensor_desc, coord);
}
template <typename TensorDesc>
using TensorCoordinate_t = decltype(make_tensor_coordinate(
    TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));

template <typename TensorDesc>
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
    TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));

} // namespace ck
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/tensor_descriptor_helper.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
namespace ck {
/*
 * These functions create tensor descriptors at runtime. If they are not constexpr, you will
 * likely see scratch-memory usage during construction of these tensor descriptors. So it is
 * better to call these functions on the host and then pass the constructed tensor descriptors
 * to the GPU. If the tensor descriptors being constructed are constexpr, then you can call these
 * functions on the GPU without worrying about scratch memory usage.
 */
#if CK_WORKAROUND_SWDEV_275126
template <typename Lengths, typename Strides, index_t I, typename AccOld>
__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
                                                                     const Strides& strides,
                                                                     Number<I> i,
                                                                     AccOld acc_old)
{
    auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];

    if constexpr(i.value < Lengths::Size() - 1)
    {
        return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new);
    }
    else
    {
        return acc_new;
    }
}
#endif
// Lengths..., Strides... could be:
// 1) index_t, which is known at run-time, or
// 2) Number<>, which is known at compile-time
// element_space_size could be:
// 1) long_index_t, or
// 2) LongNumber<>
template <typename... Lengths,
          typename... Strides,
          typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Lengths...>& lengths,
                                                                const Tuple<Strides...>& strides)
{
    constexpr index_t N = sizeof...(Lengths);

    const auto transforms = make_tuple(make_embed_transform(lengths, strides));

    constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

#if !CK_WORKAROUND_SWDEV_275126
    // the rocm-4.1 compiler would crash for a recursive lambda
    // recursive function for reduction
    auto f = [&](auto fs, auto i, auto acc_old) {
        auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];

        if constexpr(i.value < N - 1)
        {
            return fs(fs, i + Number<1>{}, acc_new);
        }
        else
        {
            return acc_new;
        }
    };

    const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{});
#else
    const auto element_space_size =
        calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
#endif

    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                            remove_cv_t<decltype(low_dim_hidden_idss)>,
                            remove_cv_t<decltype(up_dim_hidden_idss)>,
                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
                            remove_cv_t<decltype(element_space_size)>>{transforms,
                                                                       element_space_size};
}
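Illustrative note (not part of this header): a hedged sketch contrasting run-time and compile-time construction, which is what the comment at the top of this header is about; the lengths and strides are made up.

// Run-time lengths/strides (index_t): build the descriptor on the host, then pass it to the kernel.
const auto desc_rt = ck::make_naive_tensor_descriptor(ck::make_tuple(4, 6), ck::make_tuple(6, 1));

// Compile-time lengths/strides (Number<>): the whole descriptor is known at compile time,
// so it can be constructed inside a kernel without scratch-memory traffic.
constexpr auto desc_ct = ck::make_naive_tensor_descriptor(
    ck::make_tuple(ck::Number<4>{}, ck::Number<6>{}),
    ck::make_tuple(ck::Number<6>{}, ck::Number<1>{}));
static_assert(decltype(desc_ct)::IsKnownAtCompileTime(), "expected a fully compile-time descriptor");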
// Lengths... could be:
// 1) index_t, which is known at run-time, or
// 2) Number<>, which is known at compile-time
// element_space_size could be:
// 1) long_index_t, or
// 2) LongNumber<>
template <typename... Lengths>
__host__ __device__ constexpr auto
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
{
    constexpr index_t N = sizeof...(Lengths);

    const auto transforms = make_tuple(make_unmerge_transform(lengths));

    constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

    const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{});

    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                            remove_cv_t<decltype(low_dim_hidden_idss)>,
                            remove_cv_t<decltype(up_dim_hidden_idss)>,
                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
                            remove_cv_t<decltype(element_space_size)>>{transforms,
                                                                       element_space_size};
}
// Lengths... could be:
// 1) index_t, which is known at run-time, or
// 2) Number<>, which is known at compile-time
// align could be:
// 1) index_t, or
// 2) Number<>
template <typename... Lengths, typename Align>
__host__ __device__ constexpr auto
make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)
{
    constexpr auto I1 = Number<1>{};

    constexpr index_t N = sizeof...(Lengths);

    const auto stride_n_minus_2 = math::integer_least_multiple(lengths[Number<N - 1>{}], align);

    auto strides = generate_tuple(
        [&](auto i) {
            if constexpr(i.value == N - 1)
            {
                return I1;
            }
            else if constexpr(i.value == N - 2)
            {
                return Number<stride_n_minus_2>{};
            }
            else
            {
                return container_reduce(lengths,
                                        math::multiplies{},
                                        Number<stride_n_minus_2>{},
                                        i + I1,
                                        Number<N - 1>{},
                                        I1);
            }
        },
        Number<N>{});

    return make_naive_tensor_descriptor(lengths, strides);
}

} // namespace ck
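A hedged sketch of the aligned variant, assuming compile-time Number<> lengths so that the stride arithmetic above stays constexpr; the shape and alignment are illustrative only.

// A 3x5 row-major tensor whose rows are padded so that each row starts on an 8-element boundary.
const auto desc = ck::make_naive_tensor_descriptor_aligned(
    ck::make_tuple(ck::Number<3>{}, ck::Number<5>{}), ck::Number<8>{});
// The generated strides are expected to be (8, 1): the last dimension stays contiguous and the
// next-to-last stride is integer_least_multiple(5, 8) == 8.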
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_description/tensor_space_filling_curve.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/statically_indexed_array_multi_index.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
template <typename TensorLengths,
          typename DimAccessOrder,
          typename ScalarsPerAccess> // # of scalars per access in each dimension
struct SpaceFillingCurve
{
    static constexpr index_t nDim = TensorLengths::Size();

    using Index = MultiIndex<nDim>;

    static constexpr index_t ScalarPerVector =
        reduce_on_sequence(ScalarsPerAccess{}, math::multiplies{}, Number<1>{});

    static constexpr auto access_lengths   = TensorLengths{} / ScalarsPerAccess{};
    static constexpr auto dim_access_order = DimAccessOrder{};
    static constexpr auto ordered_access_lengths =
        container_reorder_given_new2old(access_lengths, dim_access_order);

    static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(ordered_access_lengths)),
        make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
        make_tuple(Sequence<0>{}));

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ static constexpr index_t GetNumOfAccess()
    {
        static_assert(TensorLengths::Size() == ScalarsPerAccess::Size());
        static_assert(TensorLengths{} % ScalarsPerAccess{} ==
                      typename uniform_sequence_gen<TensorLengths::Size(), 0>::type{});

        return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) /
               ScalarPerVector;
    }

    template <index_t AccessIdx1dBegin, index_t AccessIdx1dEnd>
    static __device__ __host__ constexpr auto GetStepBetween(Number<AccessIdx1dBegin>,
                                                             Number<AccessIdx1dEnd>)
    {
        static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative");
        static_assert(AccessIdx1dBegin < GetNumOfAccess(), "1D index should be within range");
        static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative");
        static_assert(AccessIdx1dEnd < GetNumOfAccess(), "1D index should be within range");

        constexpr auto idx_begin = GetIndex(Number<AccessIdx1dBegin>{});
        constexpr auto idx_end   = GetIndex(Number<AccessIdx1dEnd>{});
        return idx_end - idx_begin;
    }

    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr auto GetForwardStep(Number<AccessIdx1d>)
    {
        static_assert(AccessIdx1d < GetNumOfAccess(), "1D index should be within range");
        return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d + 1>{});
    }

    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr auto GetBackwardStep(Number<AccessIdx1d>)
    {
        static_assert(AccessIdx1d > 0, "1D index should be larger than 0");

        return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d - 1>{});
    }

    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr Index GetIndex(Number<AccessIdx1d>)
    {
#if 0
        /*
         * \todo: TensorAdaptor::CalculateBottomIndex does NOT return constexpr as expected.
         */
        constexpr auto ordered_access_idx = to_index_adaptor.CalculateBottomIndex(make_multi_index(Number<AccessIdx1d>{}));
#else
        constexpr auto access_strides = container_reverse_exclusive_scan(
            ordered_access_lengths, math::multiplies{}, Number<1>{});

        constexpr auto idx_1d = Number<AccessIdx1d>{};

        // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
        // idim-th element of multidimensional index.
        // All constexpr variables have to be captured by VALUE.
        constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
            constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
                auto res = idx_1d.value;
                auto id  = 0;

                static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
                    id = res / access_strides[kdim].value;
                    res -= id * access_strides[kdim].value;
                });

                return id;
            };

            constexpr auto id = compute_index_impl(idim);
            return Number<id>{};
        };

        constexpr auto ordered_access_idx = generate_tuple(compute_index, Number<nDim>{});
#endif

        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto idim) {
                index_t tmp = ordered_access_idx[I0];

                static_for<1, idim, 1>{}(
                    [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });

                forward_sweep_(idim) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate multi-dim tensor index
        auto idx_md = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto idim) {
                ordered_idx(idim) =
                    forward_sweep[idim] ? ordered_access_idx[idim]
                                        : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
            });

            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   ScalarsPerAccess{};
        }();

        return idx_md;
    }
// FIXME: rename this function
template
<
index_t
AccessIdx1d
>
static
__device__
__host__
constexpr
auto
GetIndexTupleOfNumber
(
Number
<
AccessIdx1d
>
)
{
constexpr
auto
idx
=
GetIndex
(
Number
<
AccessIdx1d
>
{});
return
generate_tuple
([
&
](
auto
i
)
{
return
Number
<
idx
[
i
]
>
{};
},
Number
<
nDim
>
{});
}
};
}
// namespace ck
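// Usage sketch for SpaceFillingCurve (illustrative; the alias and constants below are not part
// of the original header). For an 8x8 tile traversed with dim 1 fastest and 4 scalars per access:
//
//   using SFC = ck::SpaceFillingCurve<ck::Sequence<8, 8>,   // tensor lengths
//                                     ck::Sequence<0, 1>,   // dim access order
//                                     ck::Sequence<1, 4>>;  // scalars per access per dim
//
//   constexpr ck::index_t num_access = SFC::GetNumOfAccess();     // 8 * 8 / 4 = 16 accesses
//   constexpr auto idx0   = SFC::GetIndex(ck::Number<0>{});       // multi-index of access 0
//   constexpr auto step01 = SFC::GetForwardStep(ck::Number<0>{}); // delta from access 0 to 1
//
// Because each access index is a compile-time constant, the forward/backward steps can be fed
// directly to threadwise copies that move their slice window by a constexpr multi-index.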
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp"
namespace
ck
{
// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// A and B are visible to the whole block; C is distributed among the threads
// Assume:
// 1. A:
// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2
// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
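// A schematic view of that 2x2 "ABBA" pipeline (editorial summary of the Run() body below,
// not part of the original source): reads of the next sub-tiles are interleaved with FMAs on
// the sub-tiles already held in registers, so the math hides the read latency.
//
//   read A_sub_0, read B_sub_0, read B_sub_1, read A_sub_1
//   C_sub_00 += transpose(A_sub_0) * B_sub_0
//   C_sub_01 += transpose(A_sub_0) * B_sub_1
//   for each remaining bk0 slice:
//       read next A_sub_0;  C_sub_10 += transpose(A_sub_1) * B_sub_0
//       read next B_sub_0;  C_sub_11 += transpose(A_sub_1) * B_sub_1
//       read next B_sub_1;  read next A_sub_1
//       C_sub_00 += transpose(A_sub_0) * B_sub_0
//       C_sub_01 += transpose(A_sub_0) * B_sub_1
//   C_sub_10 += transpose(A_sub_1) * B_sub_0
//   C_sub_11 += transpose(A_sub_1) * B_sub_1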
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ABlockDesc_BK0_BM_BK1,
          typename BBlockDesc_BK0_BN_BK1,
          index_t BM1PerThreadBM11,
          index_t BN1PerThreadBN11,
          index_t BK0PerThread,
          typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
                                                //          BM10BN10ThreadClusterBM101, ...>
          typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
                                                //          BM10BN10ThreadClusterBN101, ...>
          index_t AThreadCopyScalarPerVector_BM11,
          index_t BThreadCopyScalarPerVector_BN11,
          typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
                                 BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
                             bool>::type = false>
struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0);
    static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2);
    static constexpr index_t BM  = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
    static constexpr index_t BN  = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);

    static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
    static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];

    static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
    static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];

    static constexpr index_t BM11 = BM1PerThreadBM11;
    static constexpr index_t BN11 = BN1PerThreadBN11;

    static constexpr index_t BM1 = BM100 * BM101 * BM11;
    static constexpr index_t BN1 = BN100 * BN101 * BN11;

    static constexpr index_t BM0 = BM / BM1;
    static constexpr index_t BN0 = BN / BN1;
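    // Worked example of the constants above (illustrative numbers, not from the original
    // source): with BM10BN10ThreadClusterBM10Xs = Sequence<2, 8>,
    // BM10BN10ThreadClusterBN10Xs = Sequence<2, 8>, BM1PerThreadBM11 = 4 and
    // BN1PerThreadBN11 = 4:
    //   BM100 = BN100 = 2, BM101 = BN101 = 8, BM11 = BN11 = 4,
    //   BM1 = BN1 = 2 * 8 * 4 = 64, and the constructor's static_assert requires
    //   BlockSize == BM101 * BM100 * BN101 * BN100 = 8 * 2 * 8 * 2 = 256 threads.
    //   For BM = BN = 128 this gives BM0 = BN0 = 128 / 64 = 2, which matches the pipeline's
    //   BM0 == 2 && BN0 == 2 restriction.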
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1
(
const
ABlockDesc_BK0_BM_BK1
&
a_block_desc_bk0_bm_bk1
)
{
const
auto
a_block_bk0_bm0_bm1_bk1
=
transform_tensor_descriptor
(
a_block_desc_bk0_bm_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BM0
>
{},
Number
<
BM1
>
{})),
make_pass_through_transform
(
Number
<
BK1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
return
a_block_bk0_bm0_bm1_bk1
;
}
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1
(
const
BBlockDesc_BK0_BN_BK1
&
b_block_desc_bk0_bn_bk1
)
{
const
auto
b_block_desc_bk0_bn0_bn1_bk1
=
transform_tensor_descriptor
(
b_block_desc_bk0_bn_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BN0
>
{},
Number
<
BN1
>
{})),
make_pass_through_transform
(
Number
<
BK1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
return
b_block_desc_bk0_bn0_bn1_bk1
;
}
__host__
__device__
static
constexpr
auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN
()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM, BN]
constexpr
auto
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
BM0
>
{},
Number
<
BM100
>
{},
Number
<
BM101
>
{},
Number
<
BM11
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
BN0
>
{},
Number
<
BN100
>
{},
Number
<
BN101
>
{},
Number
<
BN11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
,
5
,
6
,
7
>
{}));
return
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n
;
}
__host__
__device__
static
constexpr
auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1
()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM0, BM1, BN0, BN1]
constexpr
auto
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_pass_through_transform
(
Number
<
BM0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BM100
>
{},
Number
<
BM101
>
{},
Number
<
BM11
>
{})),
make_pass_through_transform
(
Number
<
BN0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BN100
>
{},
Number
<
BN101
>
{},
Number
<
BN11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
,
7
>
{}));
return
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1
;
}
__host__
__device__
static
constexpr
auto
GetCThreadTensorLengths_BM0_BM1_BN0_BN1
()
{
return
Sequence
<
BM0
,
BM11
,
BN0
,
BN11
>
{};
}
static
constexpr
auto
a_block_desc_bk0_bm0_bm1_bk1_
=
MakeABlockDescriptor_BK0_BM0_BM1_BK1
(
ABlockDesc_BK0_BM_BK1
{});
static
constexpr
auto
b_block_desc_bk0_bn0_bn1_bk1_
=
MakeBBlockDescriptor_BK0_BN0_BN1_BK1
(
BBlockDesc_BK0_BN_BK1
{});
public:
__device__
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
()
:
c_thread_origin_data_idx_
{
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
get_thread_local_1d_id
())},
a_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I0
],
c_thread_origin_data_idx_
[
I1
],
0
)},
b_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I2
],
c_thread_origin_data_idx_
[
I3
],
0
)}
{
static_assert
(
ABlockDesc_BK0_BM_BK1
::
IsKnownAtCompileTime
()
&&
BBlockDesc_BK0_BN_BK1
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
BlockSize
==
BM101
*
BM100
*
BN101
*
BN100
,
"wrong! blocksize and cluster size not consistent"
);
static_assert
(
BM
%
BM1
==
0
&&
BN
%
BN1
==
0
,
"wrong!"
);
static_assert
(
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I0
)
==
BBlockDesc_BK0_BN_BK1
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
// TODO remove this restriction
static_assert
(
BM10BN10ThreadClusterBM10Xs
::
Size
()
==
2
&&
BM10BN10ThreadClusterBN10Xs
::
Size
()
==
2
,
"wrong!"
);
// TODO: remove this restriction
static_assert
(
BM0
==
2
,
"wrong"
);
static_assert
(
BM0
==
2
&&
BN0
==
2
,
"wrong"
);
}
__device__
static
CIndex
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
index_t
thread_id
)
{
// lower: [BM0, BM1, BN0, BN1]
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
constexpr
auto
adaptor0
=
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1
();
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// upper: [Tid, BM0, BM11, BN0, BN11]
constexpr
auto
adaptor1
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
BM100
,
BN100
,
BM101
,
BN101
)),
make_pass_through_transform
(
BM0
),
make_pass_through_transform
(
BM11
),
make_pass_through_transform
(
BN0
),
make_pass_through_transform
(
BN11
)),
make_tuple
(
Sequence
<
1
,
5
,
2
,
6
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
constexpr
auto
adaptor
=
chain_tensor_adaptors
(
adaptor0
,
adaptor1
);
return
adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
,
0
,
0
,
0
,
0
));
}
template
<
typename
CThreadDesc_BM0_BM11_BN0_BN11
,
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
CThreadDesc_BM0_BM11_BN0_BN11
&
,
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
CThreadDesc_BM0_BM11_BN0_BN11
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
// TODO: remove this restriction
static_assert
(
BM0
==
2
&&
BN0
==
2
&&
CThreadDesc_BM0_BM11_BN0_BN11
{}.
GetLength
(
I0
)
==
BM0
&&
CThreadDesc_BM0_BM11_BN0_BN11
{}.
GetLength
(
I2
)
==
BN0
,
"wrong"
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_thread_desc_bk0_bm0_bm1_bk1_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_thread_desc_bk0_bn0_bn1_bk1_
.
GetElementSpaceSize
());
constexpr
auto
threadwise_contraction
=
ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
<
FloatA
,
FloatB
,
FloatC
,
decltype
(
a_thread_desc_bk0_bm0_bm1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
CThreadDesc_BM0_BM11_BN0_BN11
,
Sequence
<
BK0PerThread
,
BK1
>
,
Sequence
<
1
,
BM1PerThreadBM11
>
,
Sequence
<
1
,
BN1PerThreadBN11
>>
{};
// read A_sub_0
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
// read B_sub_0
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
// read B_sub_1
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I1
,
I0
));
// loop over rest of bk0
static_for
<
BK0PerThread
,
BK0
,
BK0PerThread
>
{}([
&
](
auto
bk0
)
{
// read A_sub_0
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
bk0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I0
,
I0
));
// read B_sub_0
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
bk0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I1
,
I0
));
// read B_sub_1
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
bk0
,
I1
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
bk0
,
I1
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I1
,
I0
));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I0
,
I0
));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I1
,
I0
));
}
private:
// A[BK0, BM0, BM1, BK1]
static
constexpr
auto
a_thread_desc_bk0_bm0_bm1_bk1_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BM0
>
{},
Number
<
BM1PerThreadBM11
>
{},
Number
<
BK1
>
{}));
// B[BK0, BN0, BN1, BK1]
static
constexpr
auto
b_thread_desc_bk0_bn0_bn1_bk1_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BN0
>
{},
Number
<
BN1PerThreadBN11
>
{},
Number
<
BK1
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_bk0_bm0_bm1_bk1_
),
decltype
(
a_thread_desc_bk0_bm0_bm1_bk1_
),
Sequence
<
BK0PerThread
,
1
,
BM1PerThreadBM11
,
BK1
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder
Sequence
<
1
,
1
,
BM1PerThreadBM11
,
BK1
>
,
// SrcVectorTensorLengths
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_bk0_bn0_bn1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
Sequence
<
BK0PerThread
,
1
,
BN1PerThreadBN11
,
BK1
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder
Sequence
<
1
,
1
,
BN1PerThreadBN11
,
BK1
>
,
// SrcVectorTensorLengths
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
CIndex
c_thread_origin_data_idx_
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
}
// namespace ck
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace
ck
{
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block; C is distributed among the threads
// Assume:
// 1. A:
// 1. AKMBlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BKNBlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CM0M1N0N1ThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
AKMBlockDesc
,
typename
BKNBlockDesc
,
index_t
M1PerThreadM11
,
index_t
N1PerThreadN11
,
index_t
KPerThread
,
index_t
M1N1ThreadClusterM100
,
index_t
M1N1ThreadClusterN100
,
index_t
M1N1ThreadClusterM101
,
index_t
M1N1ThreadClusterN101
,
index_t
AThreadCopyScalarPerVector_M11
,
index_t
BThreadCopyScalarPerVector_N11
,
typename
enable_if
<
AKMBlockDesc
::
IsKnownAtCompileTime
()
&&
BKNBlockDesc
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{
using
AIndex
=
MultiIndex
<
3
>
;
using
BIndex
=
MultiIndex
<
3
>
;
using
CIndex
=
MultiIndex
<
4
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
K
=
AKMBlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
M
=
AKMBlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
N
=
BKNBlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
M100
=
M1N1ThreadClusterM100
;
static
constexpr
index_t
N100
=
M1N1ThreadClusterN100
;
static
constexpr
index_t
M101
=
M1N1ThreadClusterM101
;
static
constexpr
index_t
N101
=
M1N1ThreadClusterN101
;
static
constexpr
index_t
M11
=
M1PerThreadM11
;
static
constexpr
index_t
N11
=
N1PerThreadN11
;
static
constexpr
index_t
M1
=
M1N1ThreadClusterM100
*
M1N1ThreadClusterM101
*
M1PerThreadM11
;
static
constexpr
index_t
N1
=
M1N1ThreadClusterN100
*
M1N1ThreadClusterN101
*
N1PerThreadN11
;
static
constexpr
index_t
M0
=
M
/
M1
;
static
constexpr
index_t
N0
=
N
/
N1
;
__host__
__device__
static
constexpr
auto
MakeAKM0M1BlockDescriptor
(
const
AKMBlockDesc
&
/* a_k_m_block_desc */
)
{
const
auto
a_k_m0_m1_block_desc
=
transform_tensor_descriptor
(
AKMBlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
K
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
M0
>
{},
Number
<
M1
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{}));
return
a_k_m0_m1_block_desc
;
}
__host__
__device__
static
constexpr
auto
MakeBKN0N1BlockDescriptor
(
const
BKNBlockDesc
&
/* b_k_n_block_desc */
)
{
const
auto
b_k_n0_n1_block_desc
=
transform_tensor_descriptor
(
BKNBlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
K
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
N0
>
{},
Number
<
N1
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{}));
return
b_k_n0_n1_block_desc
;
}
__host__
__device__
static
constexpr
auto
MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor
()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M, N]
constexpr
auto
c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
M0
>
{},
Number
<
M100
>
{},
Number
<
M101
>
{},
Number
<
M11
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
N0
>
{},
Number
<
N100
>
{},
Number
<
N101
>
{},
Number
<
N11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
,
5
,
6
,
7
>
{}));
return
c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor
;
}
__host__
__device__
static
constexpr
auto
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor
()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M0, M1, N0, N1]
constexpr
auto
c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_pass_through_transform
(
Number
<
M0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
M100
>
{},
Number
<
M101
>
{},
Number
<
M11
>
{})),
make_pass_through_transform
(
Number
<
N0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
N100
>
{},
Number
<
N101
>
{},
Number
<
N11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
,
7
>
{}));
return
c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor
;
}
__host__
__device__
static
constexpr
auto
GetCM0M1N0N1ThreadTensorLengths
()
{
return
Sequence
<
M0
,
M11
,
N0
,
N11
>
{};
}
static
constexpr
auto
a_k_m0_m1_block_desc_
=
MakeAKM0M1BlockDescriptor
(
AKMBlockDesc
{});
static
constexpr
auto
b_k_n0_n1_block_desc_
=
MakeBKN0N1BlockDescriptor
(
BKNBlockDesc
{});
public:
__device__
BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
()
:
c_thread_origin_data_idx_
{
CalculateCM0M1N0N1ThreadOriginOnBlock
(
get_thread_local_1d_id
())},
a_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I0
],
c_thread_origin_data_idx_
[
I1
])},
b_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I2
],
c_thread_origin_data_idx_
[
I3
])}
{
static_assert
(
AKMBlockDesc
::
IsKnownAtCompileTime
()
&&
BKNBlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
BlockSize
==
M101
*
M100
*
N101
*
N100
,
"wrong! blocksize and cluster size not consistent"
);
static_assert
(
M
%
M1
==
0
&&
N
%
N1
==
0
,
"wrong!"
);
static_assert
(
AKMBlockDesc
{}.
GetLength
(
I0
)
==
BKNBlockDesc
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
// TODO: remove this restriction
static_assert
(
M0
==
2
&&
N0
==
2
,
"wrong"
);
}
__device__
static
CIndex
CalculateCM0M1N0N1ThreadOriginOnBlock
(
index_t
thread_id
)
{
// lower: [M0, M1, N0, N1]
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
constexpr
auto
adaptor0
=
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor
();
// lower: [M0, M100, M101, M11, N0, N100, N101, N11]
// upper: [Tid, M0, M11, N0, N11]
constexpr
auto
adaptor1
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M100
,
N100
,
M101
,
N101
)),
make_pass_through_transform
(
M0
),
make_pass_through_transform
(
M11
),
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
N11
)),
make_tuple
(
Sequence
<
1
,
5
,
2
,
6
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
constexpr
auto
adaptor
=
chain_tensor_adaptors
(
adaptor0
,
adaptor1
);
return
adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
,
0
,
0
,
0
,
0
));
}
__host__
__device__
static
constexpr
index_t
GetABlockAlignment
()
{
return
M1PerThreadM11
;
}
__host__
__device__
static
constexpr
auto
GetBBlockAlignment
()
{
return
N1PerThreadN11
;
}
template
<
typename
CM0M1N0N1ThreadDesc
,
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
CM0M1N0N1ThreadDesc
&
/* c_m0_m1_n0_n1_thread_desc */
,
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
CM0M1N0N1ThreadDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
// TODO: remove this restriction
static_assert
(
M0
==
2
&&
N0
==
2
&&
CM0M1N0N1ThreadDesc
{}.
GetLength
(
I0
)
==
M0
&&
CM0M1N0N1ThreadDesc
{}.
GetLength
(
I2
)
==
N0
,
"wrong"
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_k_m0_m1_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_k_n0_n1_thread_desc_
.
GetElementSpaceSize
());
constexpr
auto
threadwise_gemm
=
ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
<
FloatA
,
FloatB
,
FloatC
,
decltype
(
a_k_m0_m1_thread_desc_
),
decltype
(
b_k_n0_n1_thread_desc_
),
CM0M1N0N1ThreadDesc
,
Sequence
<
KPerThread
>
,
Sequence
<
1
,
M1PerThreadM11
>
,
Sequence
<
1
,
N1PerThreadN11
>>
{};
// read A_sub_0
a_thread_copy_
.
Run
(
a_k_m0_m1_block_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
a_block_buf
,
a_k_m0_m1_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
a_thread_buf
);
// read B_sub_0
b_thread_copy_
.
Run
(
b_k_n0_n1_block_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
b_block_buf
,
b_k_n0_n1_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
);
// read B_sub_1
b_thread_copy_
.
Run
(
b_k_n0_n1_block_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
b_block_buf
,
b_k_n0_n1_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
a_k_m0_m1_block_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
a_block_buf
,
a_k_m0_m1_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I1
,
I0
));
// loop over rest of k
static_for
<
KPerThread
,
K
,
KPerThread
>
{}([
&
](
auto
k
)
{
// read A_sub_0
a_thread_copy_
.
Run
(
a_k_m0_m1_block_desc_
,
make_tuple
(
k
,
I0
,
I0
),
a_block_buf
,
a_k_m0_m1_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
a_thread_buf
);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I0
,
I0
));
// read B_sub_0
b_thread_copy_
.
Run
(
b_k_n0_n1_block_desc_
,
make_tuple
(
k
,
I0
,
I0
),
b_block_buf
,
b_k_n0_n1_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I1
,
I0
));
// read B_sub_1
b_thread_copy_
.
Run
(
b_k_n0_n1_block_desc_
,
make_tuple
(
k
,
I1
,
I0
),
b_block_buf
,
b_k_n0_n1_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
a_k_m0_m1_block_desc_
,
make_tuple
(
k
,
I1
,
I0
),
a_block_buf
,
a_k_m0_m1_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I1
,
I0
));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I0
,
I0
));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I1
,
I0
));
}
private:
// A[K, M0, M1]
static
constexpr
auto
a_k_m0_m1_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
M0
>
{},
Number
<
M1PerThreadM11
>
{}));
// B[K, N0, N1]
static
constexpr
auto
b_k_n0_n1_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
N0
>
{},
Number
<
N1PerThreadN11
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_k_m0_m1_block_desc_
),
decltype
(
a_k_m0_m1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
M1PerThreadM11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
AThreadCopyScalarPerVector_M11
,
1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_k_n0_n1_block_desc_
),
decltype
(
b_k_n0_n1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
N1PerThreadN11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
BThreadCopyScalarPerVector_N11
,
1
>
;
CIndex
c_thread_origin_data_idx_
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
}
// namespace ck
#endif
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#include "common_header.hpp"
#include "threadwise_gemm_dlops_v3.hpp"
namespace
ck
{
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
ABlockDesc_E1_K1_E2
,
typename
BBlockDesc_E1_N_Ho_Wo_E2
,
typename
CThreadDesc_K_N_Ho_Wo
,
index_t
EPerThreadLoop
,
index_t
KPerThreadLoop
>
struct
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
using
AIndex
=
MultiIndex
<
3
>
;
using
BIndex
=
MultiIndex
<
3
>
;
using
CIndex
=
MultiIndex
<
4
>
;
static
constexpr
auto
E1
=
ABlockDesc_E1_K1_E2
{}.
GetLength
(
I0
);
static
constexpr
auto
KPerBlock
=
ABlockDesc_E1_K1_E2
{}.
GetLength
(
I1
);
static
constexpr
auto
E2
=
ABlockDesc_E1_K1_E2
{}.
GetLength
(
I2
);
static
constexpr
auto
HoPerBlock
=
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I2
);
static
constexpr
auto
WoPerBlock
=
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I3
);
static
constexpr
auto
KPerThread
=
CThreadDesc_K_N_Ho_Wo
{}.
GetLength
(
I0
);
static
constexpr
auto
HoPerThread
=
CThreadDesc_K_N_Ho_Wo
{}.
GetLength
(
I2
);
static
constexpr
auto
WoPerThread
=
CThreadDesc_K_N_Ho_Wo
{}.
GetLength
(
I3
);
static
constexpr
auto
a_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerThreadLoop
>
{},
Number
<
KPerThreadLoop
>
{},
Number
<
E2
>
{}));
static
constexpr
auto
b_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerThreadLoop
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{},
Number
<
E2
>
{}));
static
constexpr
auto
c_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThreadLoop
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
__device__
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
()
:
c_thread_origin_data_idx_
{
GetBeginOfCThreadDesc_K_N_Ho_Wo
(
get_thread_local_1d_id
())},
a_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I0
]
*
KPerThread
,
0
)}
{
static_assert
(
ABlockDesc_E1_K1_E2
::
IsKnownAtCompileTime
()
&&
BBlockDesc_E1_N_Ho_Wo_E2
::
IsKnownAtCompileTime
()
&&
CThreadDesc_K_N_Ho_Wo
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ABlockDesc_E1_K1_E2
{}.
GetLength
(
I0
)
==
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I0
)
&&
ABlockDesc_E1_K1_E2
{}.
GetLength
(
I2
)
==
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I4
),
"wrong! E dimension not consistent
\n
"
);
static_assert
(
E1
%
EPerThreadLoop
==
0
,
""
);
static_assert
(
KPerThread
%
KPerThreadLoop
==
0
,
""
);
static_assert
(
KPerBlock
%
KPerThread
==
0
&&
HoPerBlock
%
HoPerThread
==
0
&&
WoPerBlock
%
WoPerThread
==
0
,
"wrong! Cannot evenly divide work among
\n
"
);
constexpr
auto
KThreadCluster
=
KPerBlock
/
KPerThread
;
constexpr
auto
HThreadCluster
=
HoPerBlock
/
HoPerThread
;
constexpr
auto
WThreadCluster
=
WoPerBlock
/
WoPerThread
;
static_assert
(
BlockSize
==
KThreadCluster
*
HThreadCluster
*
WThreadCluster
,
"wrong! wrong blocksize
\n
"
);
}
__device__
static
constexpr
auto
GetCThreadDesc_K_N_Ho_WoLengths
()
{
return
Sequence
<
KPerThread
,
I1
,
HoPerThread
,
WoPerThread
>
{};
}
__device__
static
CIndex
GetBeginOfCThreadDesc_K_N_Ho_Wo
(
index_t
thread_id
)
{
constexpr
auto
K0
=
KPerBlock
/
KPerThread
;
constexpr
auto
N0
=
I1
;
constexpr
auto
H0
=
HoPerBlock
/
HoPerThread
;
constexpr
auto
W0
=
WoPerBlock
/
WoPerThread
;
constexpr
auto
c_threadid_to_k_n_h_w_thread_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
N0
,
H0
,
W0
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
c_k_n_h_w_thread_cluster_idx
=
c_threadid_to_k_n_h_w_thread_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
return
c_k_n_h_w_thread_cluster_idx
;
}
template
<
typename
ABlockBuffer
,
typename
BThreadBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BThreadBuffer
&
b_thread_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
is_same
<
remove_cvref_t
<
typename
ABlockBuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
a_block_mtx
=
ABlockDesc_E1_K1_E2
{};
// thread A buffer for GEMM
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
,
a_thread_mtx_
.
GetElementSpaceSize
(),
true
>
a_thread_buf
;
constexpr
auto
threadwise_gemm
=
ThreadwiseGemmDlops_km_kn_mn_v3
<
FloatA
,
FloatB
,
FloatC
,
decltype
(
a_thread_mtx_
),
decltype
(
b_thread_mtx_
),
decltype
(
c_thread_mtx_
)
>
{};
static_for
<
0
,
E1
,
EPerThreadLoop
>
{}([
&
](
auto
e_begin
)
{
static_for
<
0
,
KPerThread
,
KPerThreadLoop
>
{}([
&
](
auto
k_begin
)
{
a_thread_copy_
.
Run
(
a_block_mtx
,
make_tuple
(
e_begin
,
k_begin
,
I0
),
a_block_buf
,
a_thread_mtx_
,
make_tuple
(
I0
,
I0
,
I0
),
a_thread_buf
);
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
e_begin
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
k_begin
,
I0
,
I0
,
I0
));
});
});
}
template
<
typename
ABlockSliceMoveStepIdx
>
__device__
void
MoveABlockSliceWindow
(
const
ABlockSliceMoveStepIdx
&
a_block_slice_move_step_idx
)
{
a_thread_copy_
.
MoveSrcSliceWindow
(
ABlockDesc_E1_K1_E2
{},
a_block_slice_move_step_idx
);
}
private:
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
ABlockDesc_E1_K1_E2
,
decltype
(
a_thread_mtx_
),
Sequence
<
EPerThreadLoop
,
KPerThreadLoop
,
E2
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
E2
,
E2
>
;
CIndex
c_thread_origin_data_idx_
;
AThreadCopy
a_thread_copy_
;
};
}
// namespace ck
#endif
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace
ck
{
enum struct LoopScheduler
{
    Default,
    Interwave,
};

constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
    return LoopScheduler::Interwave;
#else
    return LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
}
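// Usage sketch (illustrative, not part of the original header): the scheduler returned above
// is normally passed straight into the blockwise-GEMM selector defined further below.
//
//   constexpr ck::LoopScheduler sched = ck::make_default_loop_scheduler();
//   // sched == LoopScheduler::Interwave only when the translation unit is compiled with
//   // -DCK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1; otherwise it is LoopScheduler::Default.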
template
<
index_t
MNXdlPerWave
,
index_t
MNWaves
,
index_t
MNPerXdl
,
typename
TileDesc_K0_MN_K1
>
__host__
__device__
static
constexpr
auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
(
const
TileDesc_K0_MN_K1
&
)
{
constexpr
index_t
K0
=
TileDesc_K0_MN_K1
{}.
GetLength
(
Number
<
0
>
{});
constexpr
index_t
K1
=
TileDesc_K0_MN_K1
{}.
GetLength
(
Number
<
2
>
{});
return
transform_tensor_descriptor
(
TileDesc_K0_MN_K1
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
K0
>
{},
Number
<
K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MNXdlPerWave
>
{},
Number
<
MNWaves
>
{},
Number
<
MNPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
>
struct
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
static
constexpr
index_t
WaveSize
=
get_warp_size
();
static
constexpr
index_t
MPerBlock
=
AK0MK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
NPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
KPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
)
*
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
A_K0
=
AK0MK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
B_K0
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
xdlops_a_idx
=
xdlops_gemm
.
CalculateAThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
KPerThread
*
xdlops_a_idx
[
I0
]);
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
xdlops_b_idx
=
xdlops_gemm
.
CalculateBThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_n
,
xdlops_b_idx
[
I1
],
KPerThread
*
xdlops_b_idx
[
I0
]);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk
(
xdlops_i
,
blk_i
);
constexpr
auto
mrepeat_mwave_mperxdl_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperxdl_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperxdl_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperxdl_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
__host__
__device__
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
()
{
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
BK0NK1BlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
"wrong!"
);
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_g_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_g_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_G_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
{
const
auto
G
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_g_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_g_m_n
.
GetLength
(
I2
);
const
auto
c_grid_desc_g_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_g_m_n
,
make_tuple
(
make_pass_through_transform
(
G
),
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
2
,
4
,
6
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_g_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_M0_M1_M2_K
()
{
return
transform_tensor_descriptor
(
AK0MK1BlockDesc
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
A_K0
>
{},
Number
<
A_K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerXDL
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_N0_N1_N2_K
()
{
return
transform_tensor_descriptor
(
BK0NK1BlockDesc
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
B_K0
>
{},
Number
<
B_K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerXDL
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
static
constexpr
auto
a_block_desc_m0_m1_m2_k
=
MakeABlockDescriptor_M0_M1_M2_K
();
static
constexpr
auto
b_block_desc_n0_n1_n2_k
=
MakeBBlockDescriptor_N0_N1_N2_K
();
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
protected:
// A[M0, M1, M2, KPerThread]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
KPerThread
>
{}));
// B[N0, N1, N2, KPerThread]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
KPerThread
>
{}));
// C[M, N, NumRegXdlops]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
CalculateBThreadOriginDataIndex
()};
};
// Note: To enable the inter-wave loop scheduler, the macro
// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 must be set explicitly, because a few required
// intrinsics are not yet available in the latest ROCm release. On unsupported compilers, the
// inter-wave loop scheduler falls back to the default loop scheduler, which corresponds to
// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0.
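// In other words (editorial sketch, not part of the original source): when the macro is 0, the
// derived struct below contains no Run() of its own, so calls resolve to the base class's
// default-scheduler Run(); when the macro is 1, the inter-wave Run() below (with its
// sched_barrier/setprio hints) hides the base implementation.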
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
index_t
NumMacClusters
=
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
>
struct
BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
:
public
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
AK0MK1BlockDesc
,
BK0NK1BlockDesc
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
AK0MK1BlockDesc
,
BK0NK1BlockDesc
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
A_K1
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
B_K1
;
using
Base
::
c_thread_buf_
;
using
Base
::
c_thread_desc_
;
using
Base
::
CalculateAThreadOriginDataIndex
;
using
Base
::
CalculateBThreadOriginDataIndex
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KPerThread
;
using
Base
::
xdlops_gemm
;
static
constexpr
index_t
KPerInnerLoop
=
math
::
max
(
KPerThread
/
NumMacClusters
,
KPack
);
// 2-wave optimized blockwise gemm
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
KPerThread
,
KPerInnerLoop
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
k
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
k
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
});
__builtin_amdgcn_sched_barrier
(
0
);
// NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, except the
// first one, since the non-MAC cluster can be shortened a bit with no observable negative
// impact. The desired effect is that waves in a workgroup execute their MAC work in sync.
// This prevents out-of-sync waves from hijacking MAC resources from other workgroups and
// reducing the chance of latency hiding by waiting for the rest of the workgroup at the
// eventual sync point.
if
constexpr
(
k
.
value
!=
0
||
KPerInnerLoop
==
KPerThread
)
{
asm
volatile
(
"s_barrier"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
0
,
0
,
k_
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
0
,
0
,
k_
+
i
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazards, because the barrier from blockwise_gemm is moved here;
// B) reduce VMEM FIFO congestion by applying small delays to different wavefronts.
// It is performed near the end of the MAC cluster to minimize the lgkmcnt penalty.
if
constexpr
(
k
.
value
==
KPerThread
-
KPerInnerLoop
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
// TODO: insert setprio in a more precise manner, since we could have more than one MFMA
// instruction in a single call
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
}
    protected:
    // A[M0, M1, M2, KPerInnerLoop]
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));

    // B[N0, N1, N2, KPerInnerLoop]
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, FloatAB,
        decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_),
        Sequence<1, 1, 1, KPerInnerLoop>, Sequence<0, 1, 2, 3>, 3, A_K1, A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, FloatAB,
        decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_),
        Sequence<1, 1, 1, KPerInnerLoop>, Sequence<0, 1, 2, 3>, 3, B_K1, B_K1>;

    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
};
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename AK0MK1BlockDesc,
          typename BK0NK1BlockDesc,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          LoopScheduler LoopSched>
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
{
    if constexpr(LoopSched == LoopScheduler::Default)
    {
        return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, FloatAB, FloatAcc,
            AK0MK1BlockDesc, BK0NK1BlockDesc, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>{};
    }
    else if constexpr(LoopSched == LoopScheduler::Interwave)
    {
        return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize, FloatAB,
            FloatAcc, AK0MK1BlockDesc, BK0NK1BlockDesc, MPerXDL, NPerXDL, MRepeat, NRepeat,
            KPack>{};
    }
};
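
// Editor's note: an illustrative sketch only, not part of the original header. A gridwise GEMM
// typically resolves its blockwise GEMM type at compile time through the selector above; the
// block size, XDL tile parameters and descriptor type names below are placeholder assumptions.
template <typename AK0MK1BlockDesc, typename BK0NK1BlockDesc>
struct ExampleBlockwiseGemmAlias
{
    static constexpr auto LoopSched = LoopScheduler::Interwave; // or LoopScheduler::Default

    // decltype() of the selector call yields either the default or the inter-wave implementation
    using BlockwiseGemm = decltype(
        BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<256,    // BlockSize (assumed)
                                                                  half_t, // FloatAB   (assumed)
                                                                  float,  // FloatAcc  (assumed)
                                                                  AK0MK1BlockDesc,
                                                                  BK0NK1BlockDesc,
                                                                  32, 32, // MPerXDL, NPerXDL
                                                                  2, 2,   // MRepeat, NRepeat
                                                                  8,      // KPack
                                                                  LoopSched>());
};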
// Blockwise gemm supporting
// 1. regular XDL output M2_M3_M4_N2 and transposed XDL output M2_N2_N3_N4
// 2. decoupled input tile descriptor and mma tile descriptor in order to support both vgpr and
//    LDS source buffers
// 3. configurable k index starting position and step size after each FMA/XDL instruction
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          bool TransposeC = false,
          index_t AMmaKStride =
              KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops,
          index_t BMmaKStride =
              KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_v2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    static constexpr index_t WaveSize = get_warp_size();

    static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
    static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
    static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
    static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);

    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};

    static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

    static_assert(KPerThread % KPack == 0,
                  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, FloatAcc, MRepeat * NRepeat,
                              xdlops_gemm.GetRegSizePerXdlops(), true>
        c_thread_buf_;
    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = ThisThreadBlock::GetThreadId();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto wave_idx     = GetWaveIdx();
        const auto waveId_m     = wave_idx[I0];
        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto wave_idx     = GetWaveIdx();
        const auto waveId_n     = wave_idx[I1];
        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);

        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }
    using Tuple4 = decltype(CalculateAThreadOriginDataIndex());

    __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
                                               Tuple4 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
        static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
                      "wrong!");
    }

    __host__ __device__ BlockwiseGemmXdlops_v2(const BlockwiseGemmXdlops_v2& other)
        : a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
    {
    }
    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
    }

    // XDL output supporting C_xdl = A_xdl * B_xdl
    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<MWaves>{}, Number<NWaves>{},
                       Number<MPerXDL>{}, Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    // XDL output supporting C_xdl = A_xdl * B_xdl
    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<MWaves>{}, Number<NWaves>{},
                       Number<MPerXDL>{}, Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, Number<MWaves>{}, Number<NWaves>{},
                       Number<MPerXDL>{}, Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_block_desc_g_m0_n0_m1_n1_m2_n2);
    }
    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_G_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_g_m_n,
            make_tuple(make_pass_through_transform(G),
                       make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_grid_desc_g_m0_n0_m1_n1_m2_n2);
    }
    static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
    static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_thread_desc_.GetElementSpaceSize());

        // k = 0, 1, 2, ... instead of k = 0, KPack * 1, ...
        static_for<0, KPerThread / KPack, 1>{}([&](auto k) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                // read A
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   a_thread_buf);

                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    // read B
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf,
                                       b_thread_desc_,
                                       make_tuple(I0, I0, I0, I0),
                                       b_thread_buf);

                    vector_type<FloatAB, KPack> a_thread_vec;
                    vector_type<FloatAB, KPack> b_thread_vec;

                    static_for<0, KPack, 1>{}([&](auto i) {
                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
                    });

                    using mfma_input_type =
                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                    xdlops_gemm.template Run(
                        a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });
        });
    }
    protected:
    // A[M0, M1, M2, KPack]
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));

    // B[N0, N1, N2, KPack]
    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));

    // C[M, N, NumRegXdlops]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, FloatAB,
        decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_),
        Sequence<1, 1, 1, KPack>, Sequence<0, 1, 2, 3>, 3, A_K1, A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, FloatAB,
        decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_),
        Sequence<1, 1, 1, KPack>, Sequence<0, 1, 2, 3>, 3, B_K1, B_K1>;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck
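
// Editor's note: a rough, hypothetical sketch of how a gridwise kernel drives a blockwise GEMM
// such as BlockwiseGemmXdlops_v2. The buffer types, the global->LDS copies (omitted) and the
// loop count are placeholders; only the sync/accumulate call pattern is taken from how Run() and
// GetCThreadBuffer() are meant to be used.
template <typename BlockwiseGemm, typename ABlockBuf, typename BBlockBuf>
__device__ void example_gemm_main_loop(BlockwiseGemm& blockwise_gemm,
                                       const ABlockBuf& a_block_buf, // LDS tile of A (assumed)
                                       const BBlockBuf& b_block_buf, // LDS tile of B (assumed)
                                       index_t num_k_loops)
{
    auto& c_thread_buf = blockwise_gemm.GetCThreadBuffer(); // accumulator lives in VGPRs

    for(index_t i = 0; i < num_k_loops; ++i)
    {
        // ... global -> LDS copies of the next K slice would happen here ...
        block_sync_lds();                                           // tile visible to all waves
        blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); // accumulate this K slice
        block_sync_lds();                                           // safe to overwrite the LDS tile
    }
}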
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename AK0MK1BlockDesc,
          typename BK0K0BN0N1N2N3K1BlockDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t WaveSize  = 64;
    static constexpr index_t KPerBlock = K0PerBlock * KPack;

    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);

    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};

    static constexpr index_t KPerThread  = KPerBlock / xdlops_gemm.K0PerXdlops;
    static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, FloatAcc, MRepeat * NRepeat,
                              xdlops_gemm.GetRegSizePerXdlops(), true>
        c_thread_buf_;
    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = get_thread_local_1d_id();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto wave_idx     = GetWaveIdx();
        const auto waveId_m     = wave_idx[I0];
        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto wave_idx     = GetWaveIdx();
        const auto waveId_n     = wave_idx[I1];
        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);

        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }
    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1()
    {
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                          BK0K0BN0N1N2N3K1BlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == MWaves * NWaves * WaveSize,
                      "BlockSize != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
                      "wrong!");
    }
    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, Number<MWaves>{}, Number<NWaves>{},
                       Number<MPerXDL>{}, Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, Number<MWaves>{}, Number<NWaves>{},
                       Number<MPerXDL>{}, Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_block_desc_g_m0_n0_m1_n1_m2_n2);
    }
    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_G_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_g_m_n,
            make_tuple(make_pass_through_transform(G),
                       make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_grid_desc_g_m0_n0_m1_n1_m2_n2);
    }
    __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
    {
        return transform_tensor_descriptor(
            AK0MK1BlockDesc{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    __device__ void MoveABlockSliceWindow()
    {
        a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k,
                                          make_multi_index(0, 0, 0, K0PerBlock * KPack));
    }

    __device__ void ResetABlockStartWindow()
    {
        a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex());
    }

    static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_thread_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            a_thread_desc_.GetElementSpaceSize());

        static_for<0, MRepeat, 1>{}([&](auto m0) {
            // read A
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               make_tuple(m0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0, I0),
                               a_thread_buf);

            static_for<0, NRepeat, 1>{}([&](auto n0) {
                // read B
                static_for<0, KPerThread, KPack>{}([&](auto k) {
                    vector_type<FloatAB, KPack> a_thread_vec;
                    vector_type<FloatAB, KPack> b_thread_vec;

                    constexpr index_t k0 = k / KPack;

                    static_for<0, KPack, 1>{}([&](auto i) {
                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(k0, n0, i))>{}];
                    });

                    using mfma_input_type =
                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                    xdlops_gemm.template Run(
                        a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });
        });
    }
    private:
    // A[M0, M1, M2, KPerThread]
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));

    // B[N0, N1, N2, KPerThread]
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<K0PerThread>{}, // KPerThread
                   Number<NRepeat>{},     // repeat
                   Number<KPack>{}));

    // C[M, N, NumRegXdlops]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB, FloatAB,
        decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_),
        Sequence<1, 1, 1, KPerThread>, Sequence<0, 1, 2, 3>, 3, A_K1, A_K1>;

    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
};

} // namespace ck
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
namespace ck {

template <index_t BlockSize,
          typename AccDataType,
          typename ThreadMap_M_K, // thread_id to m_k
          typename ThreadClusterDesc_M_K,
          typename ThreadSliceDesc_M_K,
          bool IgnoreNaN = false>
struct BlockwiseSoftmax
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0);
    static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1);

    using ThreadSliceDesc_M = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0))));

    using ThreadwiseMaxReduce = typename conditional<
        IgnoreNaN,
        ThreadwiseReduction<AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Max, false,
                            detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>,
        ThreadwiseReduction<AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Max,
                            false>>::type;

    using ThreadwiseSumReduce = typename conditional<
        IgnoreNaN,
        ThreadwiseReduction<AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Add, false,
                            detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>,
        ThreadwiseReduction<AccDataType, ThreadSliceDesc_M_K, ThreadSliceDesc_M, reduce::Add,
                            false>>::type;

    using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());

    using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2<AccDataType, BlockSize,
        ThreadClusterLengths_M_K, ThreadMap_M_K, reduce::Max, false>;

    using BlockwiseSumReduce = PartitionedBlockwiseReduction_v2<AccDataType, BlockSize,
        ThreadClusterLengths_M_K, ThreadMap_M_K, reduce::Add, false>;

    using BufferType = StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MRepeat, true>;
    template <typename CThreadBuffer, typename WorkspaceBuffer>
    __host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf)
    {
        // find max value
        static_for<0, MRepeat, 1>{}([&](auto I) {
            max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
        });
        ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
        static_for<0, MRepeat, 1>{}([&](auto I) {
            BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
            block_sync_lds();
        });

        // calculate exp for elements, P = exp(s - max)
        static_for<0, MRepeat, 1>{}([&](auto iM) {
            static_for<0, KRepeat, 1>{}([&](auto iK) {
                auto offset = Number<ThreadSliceDesc_M_K{}.CalculateOffset(make_tuple(iM, iK))>{};
                in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset])
                                            ? 0
                                            : math::exp(in_thread_buf[offset] - max_value_buf(iM));
            });
        });

        // sum data
        static_for<0, MRepeat, 1>{}([&](auto I) {
            sum_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
        });
        ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf);
        static_for<0, MRepeat, 1>{}([&](auto I) {
            BlockwiseSumReduce::Reduce(reduce_work_buf, sum_value_buf(I));
            block_sync_lds();
        });
    }

    BufferType max_value_buf;
    BufferType sum_value_buf;
};

} // namespace ck
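
// Editor's note: a host-side reference (illustrative only) of the row-wise softmax numerics that
// BlockwiseSoftmax distributes across a workgroup: subtract the row max before exponentiating for
// numerical stability, then reduce the row sum. Function and variable names are placeholders.
#include <algorithm>
#include <cmath>
#include <vector>

inline std::vector<float> softmax_row_reference(const std::vector<float>& s)
{
    const float max_val = s.empty() ? 0.f : *std::max_element(s.begin(), s.end()); // blockwise max

    std::vector<float> p(s.size());
    float sum = 0.f;
    for(std::size_t i = 0; i < s.size(); ++i)
    {
        p[i] = std::exp(s[i] - max_val); // P = exp(s - max)
        sum += p[i];                     // blockwise sum
    }
    // BlockwiseSoftmax::Run itself stops after the exp and the row sum; the caller is expected to
    // perform this final normalization.
    for(float& x : p)
        x /= sum;
    return p;
}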
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp"
namespace ck {

// this version does the following things to avoid the scratch memory issue
// 1. Use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate
template <index_t BlockSize,
          InMemoryDataOperationEnum DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          typename SrcVectorTensorLengths,
          typename DstVectorTensorLengths,
          typename SrcVectorTensorContiguousDimOrder,
          typename DstVectorTensorContiguousDimOrder,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseTensorSliceTransfer_v5r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    using Index = MultiIndex<nDim>;

    __device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc,
                                                           const Index& src_block_slice_origin,
                                                           const DstDesc& dst_desc,
                                                           const Index& dst_block_slice_origin)
        : threadwise_transfer_(
              src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
    {
        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover the entire slicing window");

        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }
    template <typename SrcBuffer>
    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_desc, src_buf);
        }
    }

    template <typename DstBuffer>
    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunWrite(dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    // SrcMoveSliceWindowStepHack controls the index calculation when moving the slice window
    template <typename SrcMoveSliceWindowStepHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& step,
                       const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(
                src_desc, step, src_move_slice_window_step_hack);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v5r1<ThreadSliceLengths,
        DstInMemOp, SrcData, DstData, SrcDesc, DstDesc, SrcDimAccessOrder, DstDimAccessOrder,
        SrcVectorTensorLengths, DstVectorTensorLengths, SrcVectorTensorContiguousDimOrder,
        DstVectorTensorContiguousDimOrder, ThreadTransferSrcResetCoordinateAfterRun,
        ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp"
namespace ck {

// clang-format off
// Assume:
//  1) work_buffer is a buffer (typically LDS) allocated outside as workspace
//  2) work_buffer has T elements, and its space size is no less than 3 * BlockSize
//  3) mean_value, var_value and count are the input data in vgpr from each thread
//  4) mean_value, var_value and count are the over-written reduced output in vgpr for each thread
//  5) Merge mean and M from ThreadwiseWelford
// clang-format on
template <typename T,
          index_t BlockSize,
          typename ThreadClusterLengths_M_K,
          typename ThreadClusterArrangeOrder,
          bool GetActualVariance = true>
struct BlockwiseWelford
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
                  "The product of cluster lengths should be same as BlockSize!");

    static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
    static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    __device__ static inline void
    Merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b)
    {
        int count            = count_a + count_b;
        T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
        T delta              = mean_b - mean_a;
        mean_a += delta * count_b_over_count;
        var_a += var_b + delta * delta * count_a * count_b_over_count;
        count_a = count;
    }
    __device__ static void Run(T& mean_value, T& var_value, int& count)
    {
        __shared__ T mean_block_buf[BlockSize];
        __shared__ T var_block_buf[BlockSize];
        __shared__ int count_block_buf[BlockSize];

        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

        index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);

        mean_block_buf[offset1]  = mean_value;
        var_block_buf[offset1]   = var_value;
        count_block_buf[offset1] = count;

        block_sync_lds();

        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

            if(thread_k_cluster_id < indOffset)
            {
                index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                     make_tuple(0, indOffset));

                T mean1    = mean_block_buf[offset1];
                T var1     = var_block_buf[offset1];
                int count1 = count_block_buf[offset1];

                T mean2    = mean_block_buf[offset2];
                T var2     = var_block_buf[offset2];
                int count2 = count_block_buf[offset2];

                Merge(mean1, var1, count1, mean2, var2, count2);

                mean_block_buf[offset1]  = mean1;
                var_block_buf[offset1]   = var1;
                count_block_buf[offset1] = count1;
            }

            block_sync_lds();
        });

        index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

        count      = count_block_buf[offset];
        mean_value = mean_block_buf[offset];

        if constexpr(GetActualVariance)
            var_value = var_block_buf[offset] / count;
        else
            var_value = var_block_buf[offset];
    };
};

} // namespace ck
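
// Editor's note: a standalone host-side check (illustrative only) of the pairwise Welford merge
// used by BlockwiseWelford::Merge: combining two partial (mean, M2, count) results must equal the
// statistics of the concatenated data. All names here are placeholders.
#include <cassert>
#include <cmath>

struct WelfordState { double mean = 0.0; double m2 = 0.0; int count = 0; };

inline WelfordState welford_merge(const WelfordState& a, const WelfordState& b)
{
    WelfordState out;
    out.count          = a.count + b.count;
    const double ratio = out.count == 0 ? 0.0 : static_cast<double>(b.count) / out.count;
    const double delta = b.mean - a.mean;
    out.mean = a.mean + delta * ratio;                        // same update as Merge()
    out.m2   = a.m2 + b.m2 + delta * delta * a.count * ratio; // merged sum of squared deviations
    return out;
}

inline void welford_merge_self_test()
{
    const double xs[] = {1, 2, 3, 4, 5, 6};
    WelfordState lo, hi;
    for(int i = 0; i < 3; ++i) { ++lo.count; const double d = xs[i] - lo.mean; lo.mean += d / lo.count; lo.m2 += d * (xs[i] - lo.mean); }
    for(int i = 3; i < 6; ++i) { ++hi.count; const double d = xs[i] - hi.mean; hi.mean += d / hi.count; hi.m2 += d * (xs[i] - hi.mean); }
    const WelfordState all = welford_merge(lo, hi);
    assert(std::abs(all.mean - 3.5) < 1e-12);                   // mean of 1..6
    assert(std::abs(all.m2 / all.count - 35.0 / 12.0) < 1e-12); // population variance of 1..6
}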
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace ck {

// clang-format off
// Assume:
//  1) work_buffer is a buffer (typically LDS) allocated outside as workspace; it does not include any in/out data
//  2) work_buffer has AccDataType elements, and its space size is no less than BlockSize
//  3) in_out_value is the input data in vgpr from each thread
//  4) in_out_value is the over-written reduced output in vgpr for each thread
// clang-format on
template <typename AccDataType,
          index_t BlockSize,
          typename ThreadClusterLengths_M_K,
          typename ThreadClusterArrangeOrder,
          typename OpReduce,
          bool PropagateNan,
          typename Accumulation =
              detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct PartitionedBlockwiseReduction
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
                  "The product of cluster lengths should be same as BlockSize!");

    static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
    static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

    static_assert(BufferLength_K > 1, "Parallel reduction needs to work on at least two elements");

    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    template <typename BufferType>
    __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
    {
        static_assert(is_same<typename BufferType::type, AccDataType>{},
                      "Buffer data type should be consistent with AccDataType!");

        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

        work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;

        __syncthreads();

        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

            if(thread_k_cluster_id < indOffset)
            {
                index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
                index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                     make_tuple(0, indOffset));

                AccDataType opData1 = work_buffer[offset1];
                AccDataType opData2 = work_buffer[offset2];
                Accumulation::Calculate(opData1, opData2);
                work_buffer(offset1) = opData1;
            }

            __syncthreads();
        });

        index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

        in_out_value = work_buffer[offset];
    };
};
// clang-format off
// Assume:
//  1) work_buffer is a buffer (typically LDS) allocated outside as workspace; it does not include any in/out data
//  2) work_buffer has AccDataType elements, and its space size is no less than BlockSize
//  3) in_out_value is the input data in vgpr from each thread
//  4) in_out_value is the over-written reduced output in vgpr for each thread
// clang-format on
template <typename AccDataType,
          index_t BlockSize,
          typename ThreadClusterLengths_M_K,
          typename ThreadClusterDesc,
          typename OpReduce,
          bool PropagateNan,
          typename Accumulation =
              detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>>
struct PartitionedBlockwiseReduction_v2
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
                  "The product of cluster lengths should be same as BlockSize!");

    static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
    static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

    static_assert(BufferLength_K > 1, "Parallel reduction needs to work on at least two elements");

    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

    static constexpr auto thread_cluster_desc = ThreadClusterDesc{};

    template <typename BufferType>
    __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value)
    {
        static_assert(is_same<typename BufferType::type, AccDataType>{},
                      "Buffer data type should be consistent with AccDataType!");

        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

        work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;

        __syncthreads();

        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I());

            if(thread_k_cluster_id < indOffset)
            {
                index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
                index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                     make_tuple(0, indOffset));

                AccDataType opData1 = work_buffer[offset1];
                AccDataType opData2 = work_buffer[offset2];
                Accumulation::Calculate(opData1, opData2);
                work_buffer(offset1) = opData1;
            }

            __syncthreads();
        });

        index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

        in_out_value = work_buffer[offset];
    };
};
// clang-format off
// Assume:
//  1) work_val_buffer/work_idx_buffer is a buffer (typically LDS) allocated outside as workspace; it does not include any in/out data
//  2) work_val_buffer/work_idx_buffer has AccDataType/IndexDataType elements, and its space size is no less than BlockSize
//  3) in_out_value/in_out_index is the input data in vgpr from each thread
//  4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread
// clang-format on
template <typename AccDataType,
          typename IndexDataType,
          index_t BlockSize,
          typename ThreadClusterLengths_M_K,
          typename ThreadClusterArrangeOrder,
          typename OpReduce,
          bool PropagateNan,
          typename Accumulation = detail::
              AccumulateWithIndexAndNanCheck<PropagateNan, OpReduce, AccDataType, IndexDataType>>
struct PartitionedBlockwiseReductionWithIndex
{
    static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1),
                  "The product of cluster lengths should be same as BlockSize!");

    static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0);
    static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1);

    static_assert(BufferLength_K > 1, "Parallel reduction needs to work on at least two elements");

    static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed(
        make_tuple(Number<BufferLength_M>{}, Number<BufferLength_K>{}));

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    // This interface accumulates on both data values and indices
    template <typename BufferType, typename IdxBufferType>
    __device__ static void Reduce(BufferType& work_val_buffer,
                                  IdxBufferType& work_idx_buffer,
                                  AccDataType& in_out_value,
                                  IndexDataType& in_out_index)
    {
        static_assert(is_same<typename BufferType::type, AccDataType>{},
                      "Buffer data type should be consistent with AccDataType!");
        static_assert(is_same<typename IdxBufferType::type, IndexDataType>{},
                      "Buffer data type should be consistent with IndexDataType!");

        constexpr auto cluster_len_shift = get_shift<BufferLength_K>();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}];
        const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}];

        work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value;
        work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index;

        __syncthreads();

        static_for<0, cluster_len_shift, 1>{}([&](auto I) {
            constexpr index_t indOffset = 1 << I();

            if(thread_k_cluster_id % (indOffset * 2) == 0)
            {
                index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx);
                index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx +
                                                                     make_tuple(0, indOffset));

                AccDataType opData1      = work_val_buffer[offset1];
                AccDataType opData2      = work_val_buffer[offset2];
                IndexDataType currIndex1 = work_idx_buffer[offset1];
                IndexDataType currIndex2 = work_idx_buffer[offset2];

                Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2);
                work_val_buffer(offset1) = opData1;
                work_idx_buffer(offset1) = currIndex1;
            }

            __syncthreads();
        });

        index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0));

        in_out_value = work_val_buffer[offset];
        in_out_index = work_idx_buffer[offset];
    };
};

} // namespace ck
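
// Editor's note: a host-side analogue (illustrative only) of the K-dimension tree reduction used
// by the PartitionedBlockwiseReduction variants above: at each step the lower lanes fold in a
// partner element indOffset away, halving the active range until one value per row remains. The
// GPU version derives the step count with get_shift<BufferLength_K>() and syncs between steps.
#include <vector>

inline float tree_reduce_sum(std::vector<float> lanes) // lanes.size() assumed to be a power of two
{
    for(std::size_t ind_offset = lanes.size() / 2; ind_offset > 0; ind_offset /= 2)
    {
        for(std::size_t k = 0; k < ind_offset; ++k) // only the "low" lanes accumulate, mirroring
            lanes[k] += lanes[k + ind_offset];      // the thread_k_cluster_id < indOffset branch
        // on the GPU, __syncthreads() separates the steps
    }
    return lanes.empty() ? 0.f : lanes[0];
}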
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp"
namespace ck {

// this version does the following things to avoid the scratch memory issue
// 1. Use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate
template <typename ThreadGroup,
          typename SrcElementwiseOperation,
          typename DstElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector,
          index_t SrcScalarStrideInVector,
          index_t DstScalarStrideInVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun,
          index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const SrcElementwiseOperation& src_element_op,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const DstElementwiseOperation& dst_element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               src_element_op,
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               dst_element_op)
    {
        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover the entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }
    template <typename SrcBuffer, index_t ThreadScratchId = 0>
    __device__ void RunRead(const SrcDesc& src_desc,
                            const SrcBuffer& src_buf,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
        }
    }

    template <typename DstBuffer, index_t ThreadScratchId = 0>
    __device__ void RunWrite(const DstDesc& dst_desc,
                             DstBuffer& dst_buf,
                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
        }
    }

    template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf,
                        Number<ThreadScratchId> thread_scratch_id)
    {
        RunRead(src_desc, src_buf, thread_scratch_id);
        RunWrite(dst_desc, dst_buf, thread_scratch_id);
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer = ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
        SrcElementwiseOperation, DstElementwiseOperation, DstInMemOp, SrcData, DstData,
        SrcDesc, DstDesc, SrcDimAccessOrder, DstDimAccessOrder, SrcVectorDim, DstVectorDim,
        SrcScalarPerVector, DstScalarPerVector, SrcScalarStrideInVector, DstScalarStrideInVector,
        ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferDstResetCoordinateAfterRun,
        NumThreadScratch>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
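
// Editor's note: an illustrative-only call pattern for a global -> LDS copy built on
// ThreadGroupTensorSliceTransfer_v4r1. The descriptor, buffer and step objects are placeholders
// supplied by the surrounding kernel; only the RunRead / MoveSrcSliceWindow / RunWrite sequencing
// reflects the intended use of this class.
template <typename BlockwiseCopy, typename SrcGridDesc, typename SrcGridBuf,
          typename DstBlockDesc, typename DstBlockBuf, typename Step>
__device__ void example_copy_k_slices(BlockwiseCopy& blockwise_copy,
                                      const SrcGridDesc& src_desc, const SrcGridBuf& src_buf,
                                      const DstBlockDesc& dst_desc, DstBlockBuf& dst_buf,
                                      const Step& k_step, index_t num_k_slices)
{
    for(index_t i = 0; i < num_k_slices; ++i)
    {
        blockwise_copy.RunRead(src_desc, src_buf);           // global memory -> thread registers
        blockwise_copy.MoveSrcSliceWindow(src_desc, k_step); // advance the window to the next slice
        block_sync_lds();                                    // previous LDS tile no longer in use
        blockwise_copy.RunWrite(dst_desc, dst_buf);          // thread registers -> LDS
        block_sync_lds();                                    // new tile visible to the workgroup
    }
}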
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp"
namespace ck {

// this version does the following things to avoid the scratch memory issue
// 1. Use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v6r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
                                                             const Index& src_block_slice_origin,
                                                             const DstDesc& dst_desc,
                                                             const Index& dst_block_slice_origin,
                                                             const ElementwiseOperation& element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover the entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }
template
<
typename
SrcBuffer
,
typename
DstBuffer
>
__device__
void
Run
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
Run
(
src_desc
,
src_buf
,
dst_desc
,
dst_buf
);
}
}
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_desc
,
step
);
}
}
__device__
void
MoveDstSliceWindow
(
const
DstDesc
&
dst_desc
,
const
Index
&
step
)
{
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveDstSliceWindow
(
dst_desc
,
step
);
}
}
private:
static
constexpr
auto
thread_cluster_desc_
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v6r1
<
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
ElementwiseOperation
,
decltype
(
thread_slice_lengths
),
DimAccessOrder
,
VectorDim
,
ScalarPerVector
,
DstInMemOp
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
}
// namespace ck
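For orientation, a hedged sketch of how this block-level transfer might be instantiated. Every concrete argument below (256 threads, the 4x256 tile, `ThisThreadBlock`, the local `PassThrough` functor, the packed descriptors) is an illustrative assumption, not something taken from this commit:

// Hedged sketch: alias for a 4x256 f32 block copy handled by 256 threads,
// vectorized by 4 along the innermost dimension. ThisThreadBlock<256> is
// assumed to be the usual ck thread-group helper.
struct PassThrough // local stand-in for an element-wise pass-through op
{
    template <typename Y, typename X>
    __host__ __device__ constexpr void operator()(Y& y, const X& x) const { y = x; }
};

using TileDesc = decltype(ck::make_naive_tensor_descriptor_packed(
    ck::make_tuple(ck::Number<4>{}, ck::Number<256>{})));

using BlockwiseCopy = ck::ThreadGroupTensorSliceTransfer_v6r1<
    ck::ThisThreadBlock<256>,           // ThreadGroup (assumed helper type)
    PassThrough,                        // ElementwiseOperation
    ck::InMemoryDataOperationEnum::Set, // DstInMemOp
    ck::Sequence<4, 256>,               // SliceLengths of the block tile
    ck::Sequence<4, 64>,                // ThreadClusterLengths (4 * 64 = 256 threads)
    ck::Sequence<0, 1>,                 // ThreadClusterArrangeOrder
    float,                              // SrcData
    float,                              // DstData
    TileDesc,                           // SrcDesc (placeholder)
    TileDesc,                           // DstDesc (placeholder)
    ck::Sequence<0, 1>,                 // DimAccessOrder
    1,                                  // VectorDim (innermost)
    4,                                  // ScalarPerVector, so each thread moves a 1x4 sub-slice
    true,                               // ThreadTransferSrcResetCoordinateAfterRun
    true>;                              // ThreadTransferDstResetCoordinateAfterRun

With this choice, thread_slice_lengths works out to Sequence<1, 4>, so the static_asserts on slice coverage and thread count in the constructor are satisfied.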
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp"

namespace ck {

// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. It does not keep reference to tensor descriptor
// 3. Run() does not construct new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename Src0Data,
          typename Src1Data,
          typename DstData,
          typename Src0Desc,
          typename Src1Desc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrc0ResetCoordinateAfterRun,
          bool ThreadTransferSrc1ResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v6r2
{
    static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(
        const Src0Desc& src0_desc,
        const Index& src0_block_slice_origin,
        const Src1Desc& src1_desc,
        const Index& src1_block_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const ElementwiseOperation& element_op)
        : threadwise_transfer_(src0_desc,
                               make_zero_multi_index<nDim>(),
                               src1_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrc0SliceOrigin(
                src0_desc, src0_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetSrc1SliceOrigin(
                src1_desc, src1_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename Src0Buffer, typename Src1Buffer, typename DstBuffer>
    __device__ void Run(const Src0Desc& src0_desc,
                        const Src0Buffer& src0_buf,
                        const Src1Desc& src1_desc,
                        const Src1Buffer& src1_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
        }
    }

    __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r2<Src0Data,
                                           Src1Data,
                                           DstData,
                                           Src0Desc,
                                           Src1Desc,
                                           DstDesc,
                                           ElementwiseOperation,
                                           decltype(thread_slice_lengths),
                                           DimAccessOrder,
                                           VectorDim,
                                           ScalarPerVector,
                                           DstInMemOp,
                                           ThreadTransferSrc0ResetCoordinateAfterRun,
                                           ThreadTransferSrc1ResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
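The v6r2 variant is the binary form: a single Run call streams both source windows through the element-wise op into the destination, and each window can be advanced independently. A hedged call-pattern sketch, where `blockwise_binary_copy` is assumed to be an instance of this class and all other identifiers are placeholders:

// Hedged sketch; only member functions shown in the header above are used.
blockwise_binary_copy.Run(a_desc, a_buf, b_desc, b_buf, c_desc, c_buf); // c = op(a, b), tile-wise
blockwise_binary_copy.MoveSrc0SliceWindow(a_desc, step);                // slide window over source 0
blockwise_binary_copy.MoveSrc1SliceWindow(b_desc, step);                // slide window over source 1
blockwise_binary_copy.MoveDstSliceWindow(c_desc, step);                 // slide destination window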
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp"

namespace ck {

// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v6r3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v6r3::Run() does not construct new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename Src0Data,
          typename Src1Data,
          typename Src2Data,
          typename DstData,
          typename Src0Desc,
          typename Src1Desc,
          typename Src2Desc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrc0ResetCoordinateAfterRun,
          bool ThreadTransferSrc1ResetCoordinateAfterRun,
          bool ThreadTransferSrc2ResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v6r3
{
    static constexpr index_t nDim = remove_reference_t<Src0Desc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(
        const Src0Desc& src0_desc,
        const Index& src0_block_slice_origin,
        const Src1Desc& src1_desc,
        const Index& src1_block_slice_origin,
        const Src2Desc& src2_desc,
        const Index& src2_block_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const ElementwiseOperation& element_op)
        : threadwise_transfer_(src0_desc,
                               make_zero_multi_index<nDim>(),
                               src1_desc,
                               make_zero_multi_index<nDim>(),
                               src2_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == remove_cvref_t<Src0Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<Src1Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<Src2Desc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrc0SliceOrigin(
                src0_desc, src0_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetSrc1SliceOrigin(
                src1_desc, src1_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetSrc2SliceOrigin(
                src2_desc, src2_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
    __device__ void Run(const Src0Desc& src0_desc,
                        const Src0Buffer& src0_buf,
                        const Src1Desc& src1_desc,
                        const Src1Buffer& src1_buf,
                        const Src2Desc& src2_desc,
                        const Src2Buffer& src2_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(
                src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step);
        }
    }

    __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step);
        }
    }

    __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r3<Src0Data,
                                           Src1Data,
                                           Src2Data,
                                           DstData,
                                           Src0Desc,
                                           Src1Desc,
                                           Src2Desc,
                                           DstDesc,
                                           ElementwiseOperation,
                                           decltype(thread_slice_lengths),
                                           DimAccessOrder,
                                           VectorDim,
                                           ScalarPerVector,
                                           DstInMemOp,
                                           ThreadTransferSrc0ResetCoordinateAfterRun,
                                           ThreadTransferSrc1ResetCoordinateAfterRun,
                                           ThreadTransferSrc2ResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
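v6r3 extends the same pattern to three sources, which is the shape used for epilogue-style fusions (one main result plus two auxiliary tensors). A hedged call-pattern sketch, where `blockwise_ternary_copy` is assumed to be an instance of this class and all tensor names are placeholders:

// Hedged sketch; only member functions shown in the header above are used.
blockwise_ternary_copy.Run(e_desc, e_buf, d0_desc, d0_buf, d1_desc, d1_buf, out_desc, out_buf);
blockwise_ternary_copy.MoveSrc2SliceWindow(d1_desc, step); // each of the four windows moves independently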
deps/cget/pkg/ROCmSoftwarePlatform__composable_kernel/install/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp"

namespace ck {

// Thread-group level multi-source, multi-destination tensor slice data movement
// Assume:
// 1. All sources and destinations are DynamicBuffer
// 2. Same VectorDim and ScalarPerVector for all sources and destinations
// 3. DstInMemOps are per destination tensor
// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor
// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor
//
// Does following things to avoid scratch memory issue
// 1. Pass tensor descriptors by reference (or tuple of references)
// 2. Does not keep reference to tensor descriptor
// 3. Does not construct new tensor coordinate when calling Run()
template <typename ThreadGroup,
          typename SrcDatas,
          typename DstDatas,
          typename SrcDescs,
          typename DstDescs,
          typename ElementwiseOperation,
          typename DstInMemOps, // Sequence<InMemoryDataOperationEnum ...>
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          typename ThreadTransferSrcResetCoordinateAfterRunFlags,
          typename ThreadTransferDstResetCoordinateAfterRunFlags>
struct ThreadGroupTensorSliceTransfer_v7
{
    static constexpr index_t nDim =
        remove_cvref_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();

    static constexpr index_t nSrc = remove_cvref_t<SrcDescs>::Size();
    static constexpr index_t nDst = remove_cvref_t<DstDescs>::Size();

    using Index = MultiIndex<nDim>;

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    __device__ constexpr ThreadGroupTensorSliceTransfer_v7(
        const SrcDescs& src_descs,
        const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
        const DstDescs& dst_descs,
        const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
        const ElementwiseOperation& element_op)
        : threadwise_transfer_(src_descs,
                               StaticallyIndexedArray<Index, nSrc>{},
                               dst_descs,
                               StaticallyIndexedArray<Index, nDst>{},
                               element_op)
    {
        static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
                          nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
                          nDst == DstDatas::Size() && nDst == DstDescs::Size() &&
                          nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(),
                      "wrong!");

        static_for<0, nSrc, 1>{}([&](auto i) {
            static_assert(
                nDim == remove_cvref_t<tuple_element_t<i.value, SrcDescs>>::GetNumOfDimension(),
                "wrong!");
        });

        static_for<0, nDst, 1>{}([&](auto i) {
            static_assert(
                nDim == remove_cvref_t<tuple_element_t<i.value, DstDescs>>::GetNumOfDimension(),
                "wrong!");
        });

        static_assert(nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            const auto src_thread_slice_origins = generate_tuple(
                [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
                Number<nSrc>{});

            const auto dst_thread_slice_origins = generate_tuple(
                [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
                Number<nDst>{});

            threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
            threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
        }
    }

    template <typename SrcBuffers, typename DstBuffers>
    __device__ void Run(const SrcDescs& src_descs,
                        const SrcBuffers& src_bufs,
                        const DstDescs& dst_descs,
                        DstBuffers dst_bufs)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs);
        }
    }

    template <index_t ISrc>
    __device__ void
    MoveSrcSliceWindow(const SrcDescs& src_descs, Number<ISrc> iSrc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step);
        }
    }

    template <index_t IDst>
    __device__ void
    MoveDstSliceWindow(const DstDescs& dst_descs, Number<IDst> iDst, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v7<SrcDatas,
                                         DstDatas,
                                         SrcDescs,
                                         DstDescs,
                                         ElementwiseOperation,
                                         DstInMemOps,
                                         decltype(thread_slice_lengths),
                                         DimAccessOrder,
                                         VectorDim,
                                         ScalarPerVector,
                                         ThreadTransferSrcResetCoordinateAfterRunFlags,
                                         ThreadTransferDstResetCoordinateAfterRunFlags>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
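Unlike the fixed-arity v6 variants, v7 takes tuples of source and destination descriptors/buffers and selects per-tensor slice windows with a compile-time index. A hedged call-pattern sketch, where `blockwise_multi_copy` is assumed to be an instance of this class and the tuple names are placeholders:

// Hedged sketch; only member functions shown in the header above are used.
blockwise_multi_copy.Run(src_descs, src_bufs, dst_descs, dst_bufs);         // all sources -> all destinations
blockwise_multi_copy.MoveSrcSliceWindow(src_descs, Number<0>{}, step);      // advance source #0 only
blockwise_multi_copy.MoveDstSliceWindow(dst_descs, Number<0>{}, step);      // advance destination #0 only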