gaoqiong / composable_kernel_onnxruntime / Commits / 78e355fd

Commit 78e355fd, authored Dec 20, 2022 by gaoqiong
Commit message: onnxruntime
Parent: fae08684
Pipeline #494 failed in 0 seconds
Changes: 358    Pipelines: 1
Showing 20 changed files with 7065 additions and 0 deletions (+7065 -0)

include/ck/stream_config.hpp                                                        +14    -0
include/ck/tensor/static_tensor.hpp                                                 +273   -0
include/ck/tensor_description/cluster_descriptor.hpp                                +34    -0
include/ck/tensor_description/multi_index_transform.hpp                             +1954  -0
include/ck/tensor_description/multi_index_transform_helper.hpp                      +130   -0
include/ck/tensor_description/tensor_adaptor.hpp                                    +482   -0
include/ck/tensor_description/tensor_descriptor.hpp                                 +615   -0
include/ck/tensor_description/tensor_descriptor_helper.hpp                          +165   -0
include/ck/tensor_description/tensor_space_filling_curve.hpp                        +162   -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp                    +412   -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp                 +397   -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp                   +178   -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp                     +998   -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp          +321   -0
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp                         +115   -0
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp      +156   -0
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp                         +108   -0
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp             +244   -0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp   +173   -0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp   +134   -0
Too many changes to show: to preserve performance, only 358 of 358+ files are displayed.
include/ck/stream_config.hpp  (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

struct StreamConfig
{
    hipStream_t stream_id_ = nullptr;
    bool time_kernel_      = false;
    int log_level_         = 0;
};
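StreamConfig bundles the HIP stream handle and the profiling/logging switches that the device operations elsewhere in this commit receive at launch time. A minimal host-side sketch of filling one in (the invoker call in the comment is an assumed CK-style API, not something defined in this file):

#include <hip/hip_runtime.h>

int main()
{
    hipStream_t stream = nullptr;
    hipStreamCreate(&stream);

    // Time the kernel on this stream and enable some logging.
    StreamConfig stream_config{stream, /*time_kernel_=*/true, /*log_level_=*/1};

    // float ave_time = invoker.Run(argument.get(), stream_config); // assumed invoker API

    hipStreamDestroy(stream);
    return 0;
}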
include/ck/tensor/static_tensor.hpp  (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#ifndef CK_STATIC_TENSOR_HPP
#define CK_STATIC_TENSOR_HPP

namespace ck {

// StaticTensor for Scalar
template <AddressSpaceEnum AddressSpace,
          typename T,
          typename TensorDesc,
          bool InvalidElementUseNumericalZeroValue,
          typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
struct StaticTensor
{
    static constexpr auto desc_                  = TensorDesc{};
    static constexpr index_t ndim_               = TensorDesc::GetNumOfDimension();
    static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();

    __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {}

    __host__ __device__ constexpr StaticTensor(T invalid_element_value)
        : invalid_element_scalar_value_{invalid_element_value}
    {
    }

    // read access
    template <typename Idx,
              typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr const T& operator[](Idx) const
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            return data_[Number<offset>{}];
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return zero_scalar_value_;
            }
            else
            {
                return invalid_element_scalar_value_;
            }
        }
    }

    // write access
    template <typename Idx,
              typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr T& operator()(Idx)
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            return data_(Number<offset>{});
        }
        else
        {
            return ignored_element_scalar_;
        }
    }

    StaticBuffer<AddressSpace, T, element_space_size_, true> data_;

    static constexpr T zero_scalar_value_ = T{0};

    const T invalid_element_scalar_value_;

    T ignored_element_scalar_;
};

// StaticTensor for vector
template <AddressSpaceEnum AddressSpace,
          typename S,
          index_t ScalarPerVector,
          typename TensorDesc,
          bool InvalidElementUseNumericalZeroValue,
          typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
struct StaticTensorTupleOfVectorBuffer
{
    static constexpr auto desc_                  = TensorDesc{};
    static constexpr index_t ndim_               = TensorDesc::GetNumOfDimension();
    static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();

    static constexpr index_t num_of_vector_ =
        math::integer_divide_ceil(element_space_size_, ScalarPerVector);

    using V = vector_type<S, ScalarPerVector>;

    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer()
        : invalid_element_scalar_value_{0}
    {
    }

    __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
        : invalid_element_scalar_value_{invalid_element_value}
    {
    }

    // Get S
    // Idx is for S, not V
    template <typename Idx,
              typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr const S& operator[](Idx) const
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            return data_[Number<offset>{}];
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return zero_scalar_value_;
            }
            else
            {
                return invalid_element_scalar_value_;
            }
        }
    }

    // Set S
    // Idx is for S, not V
    template <typename Idx,
              typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr S& operator()(Idx)
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            return data_(Number<offset>{});
        }
        else
        {
            return ignored_element_scalar_;
        }
    }

    // Get X
    // Idx is for S, not X. Idx should be aligned with X
    template <typename X,
              typename Idx,
              typename enable_if<has_same_scalar_type<S, X>::value &&
                                     is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr X GetAsType(Idx) const
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            return data_.template GetAsType<X>(Number<offset>{});
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                // TODO: is this right way to initialize a vector?
                return X{0};
            }
            else
            {
                // TODO: is this right way to initialize a vector?
                return X{invalid_element_scalar_value_};
            }
        }
    }

    // Set X
    // Idx is for S, not X. Idx should be aligned with X
    template <typename X,
              typename Idx,
              typename enable_if<has_same_scalar_type<S, X>::value &&
                                     is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
                                 bool>::type = false>
    __host__ __device__ constexpr void SetAsType(Idx, X x)
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();
        constexpr bool is_valid  = coordinate_has_valid_offset(desc_, coord);

        if constexpr(is_valid)
        {
            data_.template SetAsType<X>(Number<offset>{}, x);
        }
    }

    // Get read access to V. No is_valid check
    // Idx is for S, not V. Idx should be aligned with V
    template <typename Idx>
    __host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();

        return data_.GetVectorTypeReference(Number<offset>{});
    }

    // Get read/write access to V. No is_valid check
    // Idx is for S, not V. Idx should be aligned with V
    template <typename Idx>
    __host__ __device__ constexpr V& GetVectorTypeReference(Idx)
    {
        constexpr auto coord     = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
        constexpr index_t offset = coord.GetOffset();

        return data_.GetVectorTypeReference(Number<offset>{});
    }

    StaticBufferTupleOfVector<AddressSpace, S, num_of_vector_, ScalarPerVector, true> data_;

    static constexpr S zero_scalar_value_ = S{0};

    const S invalid_element_scalar_value_ = S{0};

    S ignored_element_scalar_;
};

template <AddressSpaceEnum AddressSpace,
          typename T,
          typename TensorDesc,
          typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
__host__ __device__ constexpr auto make_static_tensor(TensorDesc)
{
    return StaticTensor<AddressSpace, T, TensorDesc, true>{};
}

template <AddressSpaceEnum AddressSpace,
          typename T,
          typename TensorDesc,
          typename X,
          typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false,
          typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type =
              false>
__host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value)
{
    return StaticTensor<AddressSpace, T, TensorDesc, true>{invalid_element_value};
}

} // namespace ck
#endif
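make_static_tensor wraps compile-time-sized register storage behind a tensor-style interface. A device-side sketch of the intended use (the descriptor helper is assumed to come from tensor_descriptor_helper.hpp, added in this same commit; the 4x4 shape is arbitrary):

// Inside a __device__ function:
using namespace ck;

constexpr auto thread_desc =
    make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<4>{}));

// 4x4 per-thread accumulator held in VGPRs; invalid coordinates read back as 0.
auto acc = make_static_tensor<AddressSpaceEnum::Vgpr, float>(thread_desc);

acc(make_tuple(Number<0>{}, Number<1>{})) = 1.0f;          // write through operator()
const float v = acc[make_tuple(Number<0>{}, Number<1>{})]; // read through operator[]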
include/ck/tensor_description/cluster_descriptor.hpp  (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

namespace ck {

template <typename Lengths,
          typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type>
__host__ __device__ constexpr auto make_cluster_descriptor(
    const Lengths& lengths,
    ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
{
    constexpr index_t ndim_low = Lengths::Size();

    const auto reordered_lengths = container_reorder_given_new2old(lengths, order);

    const auto low_lengths = generate_tuple(
        [&](auto idim_low) { return reordered_lengths[idim_low]; }, Number<ndim_low>{});

    const auto transform = make_merge_transform(low_lengths);

    constexpr auto low_dim_old_top_ids = ArrangeOrder{};

    constexpr auto up_dim_new_top_ids = Sequence<0>{};

    return make_single_stage_tensor_adaptor(
        make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}

} // namespace ck
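make_cluster_descriptor builds an adaptor that merges the per-dimension cluster lengths into one linear index, so a flat thread id can be mapped back to a multi-dimensional cluster coordinate. A device-side sketch of that pattern (CalculateBottomIndex and get_thread_local_1d_id are assumed from tensor_adaptor.hpp and the CK utility headers; the 4x16 arrangement is arbitrary):

// Inside a kernel running with 64 threads per block:
using namespace ck;

constexpr auto thread_cluster_desc =
    make_cluster_descriptor(Sequence<4, 16>{}, Sequence<0, 1>{});

// Recover the (m, n) cluster coordinate of this thread from its flat id.
const auto thread_mn_idx =
    thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
// thread_mn_idx[Number<0>{}] is in [0, 4), thread_mn_idx[Number<1>{}] is in [0, 16).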
include/ck/tensor_description/multi_index_transform.hpp  (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/utility/multi_index.hpp"

namespace ck {

template <typename LowLength>
struct PassThrough
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(LowLength{}));

    UpLengths up_lengths_;

    __host__ __device__ constexpr PassThrough() = default;

    __host__ __device__ constexpr PassThrough(const LowLength& low_length)
        : up_lengths_{make_tuple(low_length)}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                                  const UpIdx& idx_up)
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}];
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
                                                     const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("PassThrough, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
template <typename LowLength,
          typename LeftPadLength,
          typename RightPadLength,
          bool SkipIsValidCheck = false>
struct Pad
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));

    UpLengths up_lengths_;
    LeftPadLength left_pad_length_;
    RightPadLength right_pad_length_;

    __host__ __device__ constexpr Pad() = default;

    __host__ __device__ constexpr Pad(const LowLength& low_length,
                                      const LeftPadLength& left_pad_length,
                                      const RightPadLength& right_pad_length)
        : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
          left_pad_length_{left_pad_length},
          right_pad_length_{right_pad_length}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
                                                     const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return SkipIsValidCheck;
    }

    template <typename UpIdx>
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
    {
        return SkipIsValidCheck ||
               ((idx_up[Number<0>{}] >= left_pad_length_) &&
                (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_length_));
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<LeftPadLength>::value &&
               is_known_at_compile_time<RightPadLength>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Pad, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("left_pad_length %d", index_t{left_pad_length_});
        printf("right_pad_length %d", index_t{right_pad_length_});
        printf("}");
    }
};
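As a concrete reading of Pad: with a lower length of 5, a left pad of 2 and a right pad of 1, the upper length becomes 2 + 5 + 1 = 8; an upper index u maps to lower index u - 2, and it is valid only when 2 <= u < 7 (unless SkipIsValidCheck is set). A minimal sketch using runtime lengths (direct construction is for illustration only; descriptors normally build this through the helpers in the companion multi_index_transform_helper.hpp of this commit):

// Sketch: pad a lower dimension of length 5 with 2 on the left and 1 on the right.
using namespace ck;

const auto pad = Pad<index_t, index_t, index_t>{5, 2, 1};

// Upper length is 2 + 5 + 1 = 8; upper index u maps to lower index u - 2.
auto idx_low = make_multi_index(0);
pad.CalculateLowerIndex(idx_low, make_multi_index(2)); // idx_low[Number<0>{}] == 0
// Upper indices 0..1 and 7 land in the padding, so
// pad.IsValidUpperIndexMappedToValidLowerIndex(make_multi_index(7)) == false.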
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct LeftPad
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));

    UpLengths up_lengths_;
    LeftPadLength left_pad_length_;

    __host__ __device__ constexpr LeftPad() = default;

    __host__ __device__ constexpr LeftPad(const LowLength& low_length,
                                          const LeftPadLength& left_pad_length)
        : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
                                                     const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return SkipIsValidCheck;
    }

    template <typename UpIdx>
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
    {
        return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_);
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<LeftPadLength>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("LeftPad, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("left_pad_length_ %d", index_t{left_pad_length_});
        printf("}");
    }
};
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct RightPad
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));

    UpLengths up_lengths_;
    LowLength low_length_;
    RightPadLength right_pad_length_;

    __host__ __device__ constexpr RightPad() = default;

    __host__ __device__ constexpr RightPad(const LowLength& low_length,
                                           const RightPadLength& right_pad_length)
        : up_lengths_{make_tuple(low_length + right_pad_length)},
          low_length_{low_length},
          right_pad_length_{right_pad_length}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                                  const UpIdx& idx_up)
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}];
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
                                                     const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return SkipIsValidCheck;
    }

    template <typename UpIdx>
    __host__ __device__ constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
    {
        return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_);
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<LowLength>::value &&
               is_known_at_compile_time<RightPadLength>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("RightPad, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("low_length_ %d", index_t{low_length_});
        printf("right_pad_length_ %d", index_t{right_pad_length_});
        printf("}");
    }
};
// idx_low = coefficients[0, ..., nDimUp-1] * idx_up[0, ..., nDimUp-1]
// UpLengths and Coefficients can be any of the following:
//   1) Tuple of index_t, which is known at run-time, or
//   2) Tuple of Number, which is known at compile-time, or
//   3) Tuple of mixture of index_t and Number, which is known partially at run-time and
//      partially at compile-time
template <typename UpLengths,
          typename Coefficients,
          typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
struct Embed
{
    static constexpr index_t NDimUp = UpLengths::Size();

    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<NDimUp>;

    UpLengths up_lengths_;
    Coefficients coefficients_;

    __host__ __device__ constexpr Embed() = default;

    __host__ __device__ constexpr Embed(const UpLengths& up_lengths,
                                        const Coefficients& coefficients)
        : up_lengths_{up_lengths}, coefficients_{coefficients}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = 0;

        static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
            idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i];
        });
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
                                              const UpIdx&,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&
                          LowIdx::Size() == 1 && UpIdx::Size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        idx_diff_low(Number<0>{}) = 0;

        static_for<0, NDimUp, 1>{}(
            [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; });

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<Coefficients>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Embed, ");
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("coefficients_ ");
        print_multi_index(coefficients_);
        printf("}");
    }
};
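Embed is a strided linearization: the lower index is the sum of coefficients[i] * upper[i]. For example, with up_lengths (M, N) = (4, 8) and coefficients (8, 1), the upper index (2, 3) embeds to lower offset 2*8 + 3*1 = 19, which is exactly row-major addressing. A minimal sketch (make_embed_transform is assumed to come from the companion multi_index_transform_helper.hpp in this commit):

// Sketch: embed a 2-D (row, col) index into a 1-D row-major offset with strides (8, 1).
using namespace ck;

const auto embed = make_embed_transform(make_tuple(4, 8),  // up_lengths (M, N)
                                        make_tuple(8, 1)); // coefficients = strides

auto idx_low = make_multi_index(0);
embed.CalculateLowerIndex(idx_low, make_multi_index(2, 3));
// idx_low[Number<0>{}] == 2 * 8 + 3 * 1 == 19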
// Implementation of "Merge" transformation primitive that uses regular division to do lowering of
// the multi-index and uses a carry-and-borrow check to do lowering of the multi-index delta
template <typename LowLengths>
struct Merge_v1_carry_check
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using LowLengthsScan =
        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Merge_v1_carry_check() = default;

    __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_scan_{
              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[Number<0>{}];

        // normal division
        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_low(i) = tmp / this->low_lengths_scan_[i];
            tmp -= idx_low[i] * this->low_lengths_scan_[i];
        });

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low,
                                                 const UpIdxDiff& idx_diff_up,
                                                 LowIdx& idx_low,
                                                 const UpIdx& /* idx_up_new */,
                                                 Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
        // However,
        //   1) If idx_diff_up is known at compile-time, then idx_diff_low_const
        //      can be calculated at compile-time.
        //   2) If idx_diff_up is not known at compile-time, but its value doesn't change during
        //      the whole kernel execution, then idx_diff_low_const also doesn't change during the
        //      whole kernel execution. The compiler-generated ISA should only calculate
        //      idx_diff_low_const once and keep it for the whole kernel execution.
        // If neither 1) nor 2) is satisfied, then the calculation is computed at run-time each
        // time this function is called, and can be very expensive.
        LowerIndex idx_diff_low_const;
        LowerIndex idx_low_length_minus_idx_diff_low_const;
        LowerIndex idx_low_length_plus_idx_diff_low_const;

#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;

        static_for<0, NDimLow, 1>{}([&](auto i) {
            idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];

            idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
        });
#else
        // Hack: this forces the result into SGPR. Need to make sure the result is thread-invariant
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);

        static_for<0, NDimLow, 1>{}([&](auto i) {
            idx_low_length_minus_idx_diff_low_const(i) =
                __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);

            idx_low_length_plus_idx_diff_low_const(i) =
                __builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]);
        });
#endif

        if constexpr(Hack == 1)
        {
            // do carry check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t carry = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t idx_low_tmp = idx_low[i] + carry;

                bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];

                idx_diff_low(i) += carry;

                carry = do_carry ? 1 : 0;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;

            idx_low += idx_diff_low;
        }
        else if constexpr(Hack == 2)
        {
            // do borrow check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t borrow = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t idx_low_tmp = idx_low[i] - borrow;

                bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];

                idx_diff_low(i) -= borrow;

                borrow = do_borrow ? 1 : 0;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;

            idx_low += idx_diff_low;
        }
        else
        {
            // do carry-and-borrow check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t carry = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t idx_low_tmp = idx_low[i] + carry;

                bool do_carry  = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
                bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
                idx_diff_low(i) =
                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];

                idx_diff_low(i) += carry;

                carry = do_carry ? 1 : 0;
                carry = do_borrow ? -1 : carry;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;

            idx_low += idx_diff_low;
        }
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low,
                                                 const UpIdxDiff& idx_diff_up,
                                                 LowIdx& idx_low,
                                                 const UpIdx& /* idx_up_new */,
                                                 Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
        // However,
        //   1) If idx_diff_up is known at compile-time, then idx_diff_low_const
        //      can be calculated at compile-time.
        //   2) If idx_diff_up is not known at compile-time, but its value doesn't change during
        //      the whole kernel execution, then idx_diff_low_const also doesn't change during the
        //      whole kernel execution. The compiler-generated ISA should only calculate
        //      idx_diff_low_const once and keep it for the whole kernel execution.
        // If neither 1) nor 2) is satisfied, then the calculation is computed at run-time each
        // time this function is called, and can be very expensive.
        LowerIndex idx_diff_low_const;
        LowerIndex idx_low_length_minus_idx_diff_low_const;
        LowerIndex idx_low_length_plus_idx_diff_low_const;

#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;

        static_for<0, NDimLow, 1>{}([&](auto i) {
            idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i];

            idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
        });
#else
        // Hack: this forces the result into SGPR. Need to make sure the result is thread-invariant
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);

        static_for<0, NDimLow, 1>{}([&](auto i) {
            idx_low_length_minus_idx_diff_low_const(i) =
                __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]);

            idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i];
        });
#endif

        if constexpr(Hack == 1)
        {
            // do carry check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t carry = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t idx_low_tmp = idx_low[i] + carry;

                bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];

                idx_diff_low(i) += carry;

                carry = do_carry ? 1 : 0;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;

            idx_low += idx_diff_low;
        }
        else if constexpr(Hack == 2)
        {
            // do borrow check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t borrow = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t negative_idx_low_tmp = borrow - idx_low[i];

                bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i];

                idx_diff_low(i) -= borrow;

                borrow = do_borrow ? 1 : 0;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow;

            idx_low += idx_diff_low;
        }
        else
        {
            // do carry-and-borrow check on each low dimension in reversed order
            // do not need to check the first dimension
            index_t carry = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                index_t idx_low_tmp = idx_low[i] + carry;

                bool do_carry  = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i];
                bool do_borrow = idx_low_tmp < -idx_diff_low_const[i];

                idx_diff_low(i) =
                    do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i];
                idx_diff_low(i) =
                    do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i];

                idx_diff_low(i) += carry;

                carry = do_carry ? 1 : 0;
                carry = do_borrow ? -1 : carry;
            });

            idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry;

            idx_low += idx_diff_low;
        }
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& idx_low,
                                                const UpIdx& /* idx_up_new */,
                                                Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions.
        // However,
        //   1) If idx_diff_up is known at compile-time, then idx_diff_low_const
        //      can be calculated at compile-time.
        //   2) If idx_diff_up is not known at compile-time, but its value doesn't change during
        //      the whole kernel execution, then idx_diff_low_const also doesn't change during the
        //      whole kernel execution. The compiler-generated ISA should only calculate
        //      idx_diff_low_const once and keep it for the whole kernel execution.
        // If neither 1) nor 2) is satisfied, then the calculation is computed at run-time each
        // time this function is called, and can be very expensive.
        LowerIndex idx_diff_low_const;

#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = tmp / low_lengths_scan_[i];
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = tmp;
#else
        // Hack: this forces the result into SGPR. Need to make sure the result is thread-invariant
        index_t tmp = idx_diff_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]);
            tmp -= idx_diff_low_const[i] * low_lengths_scan_[i];
        });

        idx_diff_low_const(Number<NDimLow - 1>{}) = __builtin_amdgcn_readfirstlane(tmp);
#endif

        if constexpr(Hack == 1)
        {
            // do carry check on each low dimension in reversed order
            // do not need to check the first dimension
            bool do_carry = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                idx_diff_low(i) = idx_diff_low_const[i] + do_carry;

                index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];

                do_carry = idx_low_tmp >= low_lengths_[i];

#if 0
                // TODO: use exec-mask inline asm, which uses 1 VALU
                if(do_carry)
                {
                    idx_diff_low(i) -= low_lengths_[i];
                }
#elif 1
                // this uses 2 VALU
                idx_diff_low(i) = do_carry ? idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i];
#elif 1
                // this uses 2 VALU
                index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i];
                idx_diff_low(i)          = do_carry ? idx_diff_low_tmp : idx_diff_low[i];
#endif

                idx_low(i) += idx_diff_low[i];
            });

            constexpr auto I0 = Number<0>{};

            idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry;

            idx_low(I0) += idx_diff_low[I0];
        }
        else if constexpr(Hack == 2)
        {
            // do borrow check on each low dimension in reversed order
            // do not need to check the first dimension
            bool do_borrow = 0;

            static_for<NDimLow - 1, 0, -1>{}([&](auto i) {
                idx_diff_low(i) = idx_diff_low_const[i] - do_borrow;

                index_t idx_low_tmp = idx_low[i] + idx_diff_low[i];

                do_borrow = idx_low_tmp < 0;

#if 0
                // TODO: use exec-mask inline asm
                if(do_borrow)
                {
                    idx_diff_low(i) += low_lengths_[i];
                }
#elif 1
                idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i];
#elif 1
                index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i];
                idx_diff_low(i)          = do_borrow ? idx_diff_low_tmp : idx_diff_low[i];
#endif

                idx_low(i) += idx_diff_low[i];
            });

            constexpr auto I0 = Number<0>{};

            idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow;

            idx_low(I0) += idx_diff_low[I0];
        }
        else
        {
            // not implemented
        }
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const
    {
#if 1
        UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
#elif 0
        UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
#else
        UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});
#endif
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowLengths>::value &&
               is_known_at_compile_time<LowLengthsScan>::value &&
               is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Merge_v1_carry_check, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_scan_ ");
        print_multi_index(low_lengths_scan_);
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
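To make the index arithmetic concrete: merging lower lengths (2, 3, 4) gives an upper length of 24 and low_lengths_scan_ = (12, 4, 1); the upper index 17 then lowers to (17/12, (17%12)/4, 17%4) = (1, 1, 1). A minimal sketch using the carry-check variant above (direct construction is for illustration; make_merge_transform in the companion multi_index_transform_helper.hpp of this commit is the usual entry point):

// Sketch: lower the merged index 17 back to (1, 1, 1) for lower lengths (2, 3, 4).
using namespace ck;

auto lengths     = make_tuple(2, 3, 4);
const auto merge = Merge_v1_carry_check<decltype(lengths)>{lengths};

auto idx_low = make_multi_index(0, 0, 0);
merge.CalculateLowerIndex(idx_low, make_multi_index(17));
// idx_low == (1, 1, 1), since 1 * 12 + 1 * 4 + 1 == 17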
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_multiplier
{
    template <index_t I>
    __host__ __device__ constexpr auto operator()(Number<I> i) const
    {
        return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]);
    }
};

template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_shift
{
    template <index_t I>
    __host__ __device__ constexpr auto operator()(Number<I> i) const
    {
        return MagicDivision::CalculateMagicShift(LowLengths{}[i]);
    }
};
// Implementation of "Merge" transformation primitive that uses magic-number division to do
// lowering of both the multi-index and the delta of the multi-index
// Caution:
//   1. The magic number division implementation being used only produces a correct result if the
//      dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
//   2. The magic number division for an int32_t dividend has not been implemented; an int32_t
//      dividend is bit-wise interpreted as uint32_t and the magic number division implementation
//      for uint32_t is then used.
//   3. For the Merge primitive, the upper-index is the dividend.
//   4. When the upper-index is uint32_t, its value needs to be within the 31-bit range.
//   5. When the upper-index is int32_t (when index_t is int32_t), its value needs to be
//      non-negative.
template <typename LowLengths>
struct Merge_v2_magic_division
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengths>{},
        Number<NDimLow>{}));

    using LowLengthsMagicDivisorShift = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengths>{},
        Number<NDimLow>{}));

    LowLengths low_lengths_;
    LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_;
    LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Merge_v2_magic_division() = default;

    __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_magic_divisor_multiplier_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); },
              Number<NDimLow>{})},
          low_lengths_magic_divisor_shift_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); },
              Number<NDimLow>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[Number<0>{}];

        static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
            index_t tmp2 =
                MagicDivision::DoMagicDivision(tmp,
                                               this->low_lengths_magic_divisor_multiplier_[i],
                                               this->low_lengths_magic_divisor_shift_[i]);
            idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
            tmp        = tmp2;
        });

        idx_low(Number<0>{}) = tmp;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff&,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up_new[Number<0>{}];

        static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
            index_t tmp2 =
                MagicDivision::DoMagicDivision(tmp,
                                               this->low_lengths_magic_divisor_multiplier_[i],
                                               this->low_lengths_magic_divisor_shift_[i]);

            index_t idx_low_old = idx_low[i];

            idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
            tmp        = tmp2;

            idx_diff_low(i) = idx_low[i] - idx_low_old;
        });

        idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{});

        idx_low(Number<0>{}) = tmp;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowLengths>::value &&
               is_known_at_compile_time<LowLengthsMagicDivisorMultipiler>::value &&
               is_known_at_compile_time<LowLengthsMagicDivisorShift>::value &&
               is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Merge_v2_magic_division, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_magic_divisor_multiplier_ ");
        print_multi_index(low_lengths_magic_divisor_multiplier_);
        printf("low_lengths_magic_divisor_shift_ ");
        print_multi_index(low_lengths_magic_divisor_shift_);
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
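The caution notes above limit the dividend to the 31-bit unsigned range. For context, the magic-number trick replaces n / d by a multiply and a shift: for a fixed divisor d one precomputes a pair (m, s) so that floor(n / d) == (n * m) >> s for every n in the supported range. A standalone host-side illustration for d = 3 (the constants below are the well-known pair for unsigned 32-bit division by 3, not values produced by MagicDivision in this diff):

#include <cassert>
#include <cstdint>

int main()
{
    const uint64_t magic = 0xAAAAAAABull; // multiplier for d = 3
    const unsigned shift = 33;            // total right shift

    for(uint32_t n : {0u, 1u, 5u, 6u, 1000000007u, 0x7FFFFFFFu})
    {
        const uint32_t q = static_cast<uint32_t>((n * magic) >> shift);
        assert(q == n / 3); // the multiply-and-shift reproduces the integer division
    }
    return 0;
}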
// Implementation of "Merge" transformation primitive that uses magic-number division to do
// lowering of both the multi-index and the delta of the multi-index
// Caution:
//   1. The magic number division implementation being used only produces a correct result if the
//      dividend is uint32_t and its value is within the 31-bit value range of uint32_t.
//   2. The magic number division for an int32_t dividend has not been implemented; an int32_t
//      dividend is bit-wise interpreted as uint32_t and the magic number division implementation
//      for uint32_t is then used.
//   3. For the Merge primitive, the upper-index is the dividend.
//   4. When the upper-index is uint32_t, its value needs to be within the 31-bit range.
//   5. When the upper-index is int32_t (when index_t is int32_t), its value needs to be
//      non-negative.
template <typename LowLengths>
struct Merge_v2r2_magic_division
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using LowLengthsScan =
        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_multiplier<LowLengthsScan>{},
        Number<NDimLow>{}));

    using LowLengthsScanMagicDivisorShift = decltype(generate_tuple(
        lambda_merge_generate_MagicDivision_calculate_magic_shift<LowLengthsScan>{},
        Number<NDimLow>{}));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
    LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_;
    LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Merge_v2r2_magic_division() = default;

    __host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_scan_{
              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
          low_lengths_scan_magic_divisor_multiplier_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); },
              Number<NDimLow>{})},
          low_lengths_scan_magic_divisor_shift_{generate_tuple(
              [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); },
              Number<NDimLow>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
            idx_low(i) =
                MagicDivision::DoMagicDivision(tmp,
                                               this->low_lengths_scan_magic_divisor_multiplier_[i],
                                               this->low_lengths_scan_magic_divisor_shift_[i]);
            tmp -= idx_low[i] * this->low_lengths_scan_[i];
        });

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff&,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up_new[Number<0>{}];

        static_for<0, NDimLow - 1, 1>{}([&, this](auto i) {
            index_t idx_low_old = idx_low[i];

            idx_low(i) =
                MagicDivision::DoMagicDivision(tmp,
                                               this->low_lengths_scan_magic_divisor_multiplier_[i],
                                               this->low_lengths_scan_magic_divisor_shift_[i]);

            idx_diff_low(i) = idx_low[i] - idx_low_old;

            tmp -= idx_low[i] * this->low_lengths_scan_[i];
        });

        idx_diff_low(Number<NDimLow - 1>{}) = tmp - idx_low[Number<NDimLow - 1>{}];

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowLengths>::value &&
               is_known_at_compile_time<LowLengthsScanMagicDivisorMultipiler>::value &&
               is_known_at_compile_time<LowLengthsScanMagicDivisorShift>::value &&
               is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Merge_v2r2_magic_division, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_scan ");
        print_multi_index(low_lengths_scan_);
        printf("low_lengths_scan_magic_divisor_multiplier_ ");
        print_multi_index(low_lengths_scan_magic_divisor_multiplier_);
        printf("low_lengths_scan_magic_divisor_shift_ ");
        print_multi_index(low_lengths_scan_magic_divisor_shift_);
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
template <typename LowLengths>
struct Merge_v3_division_mod
{
    static constexpr index_t NDimLow = LowLengths::Size();

    using LowerIndex = MultiIndex<NDimLow>;
    using UpperIndex = MultiIndex<1>;

    using LowLengthsScan =
        decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{}));

    using UpLengths =
        decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{})));

    LowLengths low_lengths_;
    LowLengthsScan low_lengths_scan_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Merge_v3_division_mod() = default;

    __host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths)
        : low_lengths_{low_lengths},
          low_lengths_scan_{
              container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})},
          up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))}
    {
        static_assert(LowerIndex::Size() == NDimLow, "wrong!");
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        index_t tmp = idx_up[Number<0>{}];

        // division and mod
        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            idx_low(i) = tmp / this->low_lengths_scan_[i];
            tmp %= this->low_lengths_scan_[i];
        });

        idx_low(Number<NDimLow - 1>{}) = tmp;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff&,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 &&
                          LowIdx::Size() == NDimLow && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0   = Number<0>{};
        constexpr auto INm1 = Number<NDimLow - 1>{};

        index_t tmp = idx_up_new[I0];

        static_for<0, NDimLow - 1, 1>{}([&](auto i) {
            const index_t tmp2 = idx_low[i];
            idx_low(i)         = tmp / this->low_lengths_scan_[i];
            idx_diff_low(i)    = idx_low[i] - tmp2;
            tmp %= this->low_lengths_scan_[i];
        });

        const index_t tmp2 = idx_low[INm1];
        idx_low(INm1)      = tmp;
        idx_diff_low(INm1) = idx_low[INm1] - tmp2;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowLengths>::value &&
               is_known_at_compile_time<LowLengthsScan>::value &&
               is_known_at_compile_time<UpLengths>::value;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Merge_v3_direct_division_mod, ");
        printf("low_lengths_ ");
        print_multi_index(low_lengths_);
        printf("low_lengths_scan_ ");
        print_multi_index(low_lengths_scan_);
        printf("up_lengths_ ");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
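// Illustrative sketch (not part of the original header): how Merge_v3_division_mod decomposes a
// merged index. For low_lengths (4, 8) the reverse exclusive scan is (8, 1), so an upper index of
// 27 maps to the lower index (27 / 8, 27 % 8) = (3, 3). make_multi_index is assumed to come from
// ck's common utilities pulled in by common_header.hpp.
#if 0
__host__ void merge_v3_division_mod_example()
{
    using Lengths = decltype(make_tuple(Number<4>{}, Number<8>{}));

    const auto merge = Merge_v3_division_mod<Lengths>{make_tuple(Number<4>{}, Number<8>{})};

    MultiIndex<2> idx_low;
    merge.CalculateLowerIndex(idx_low, make_multi_index(27)); // idx_low == (3, 3)
}
#endif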
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct UnMerge
{
    static constexpr index_t NDimUp = UpLengths::Size();

    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<NDimUp>;

    using UpLengthsScan =
        decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number<1>{}));

    UpLengths up_lengths_;
    UpLengthsScan up_lengths_scan_;

    __host__ __device__ constexpr UnMerge() = default;

    __host__ __device__ constexpr UnMerge(const UpLengths& up_lengths)
        : up_lengths_{up_lengths},
          up_lengths_scan_{
              container_reverse_exclusive_scan(up_lengths, math::multiplies{}, Number<1>{})}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        if constexpr(!Use24BitIntegerCalculation)
        {
            idx_low(Number<0>{}) = idx_up[Number<NDimUp - 1>{}];

            static_for<0, NDimUp - 1, 1>{}(
                [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
        }
        else
        {
            idx_low(Number<0>{}) = idx_up[Number<NDimUp - 1>{}];

            static_for<0, NDimUp - 1, 1>{}([&](auto i) {
                idx_low(Number<0>{}) =
                    (0x00ffffff & idx_low[Number<0>{}]) +
                    (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
            });
        }
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
                                              const UpIdx&,
                                              Number<Hack>) const
    {
        CalculateLowerIndex(idx_diff_low, idx_diff_up);

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<UpLengthsScan>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("UnMerge, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("up_lengths_scan_");
        print_multi_index(up_lengths_scan_);
        printf("}");
    }
};
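// Illustrative sketch (not part of the original header): UnMerge goes in the opposite direction of
// a merge. For up_lengths (2, 3, 4) the reverse exclusive scan yields the "strides" (12, 4, 1), so
// the upper index (1, 2, 3) maps to the single lower index 1*12 + 2*4 + 3 = 23. make_multi_index
// is assumed to come from ck's common utilities.
#if 0
__host__ void unmerge_example()
{
    using Lengths = decltype(make_tuple(Number<2>{}, Number<3>{}, Number<4>{}));

    const auto unmerge = UnMerge<Lengths, false>{make_tuple(Number<2>{}, Number<3>{}, Number<4>{})};

    MultiIndex<1> idx_low;
    unmerge.CalculateLowerIndex(idx_low, make_multi_index(1, 2, 3)); // idx_low == (23)
}
#endif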
template <typename LowerIndex>
struct Freeze
{
    LowerIndex low_idx_;

    __host__ __device__ constexpr Freeze() = default;

    __host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; }

    __host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& /* idx_up */) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = low_idx_;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& /* idx_diff_up */,
                                                     LowIdx& /* idx_low */,
                                                     const UpIdx& /* idx_up_new */,
                                                     Number<Hack>)
    {
        idx_diff_low(Number<0>{}) = 0;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<LowerIndex>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("Freeze");
        printf("low_idx_ %d", index_t{low_idx_});
    }
};
// Insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct Insert
{
    using UpLengths = decltype(make_tuple(UpperLength{}));

    UpLengths up_lengths_;

    __host__ __device__ constexpr Insert() = default;

    __host__ __device__ constexpr Insert(const UpperLength& up_length)
        : up_lengths_{make_tuple(up_length)}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const
    {
        static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void
    UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpperLength>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("Insert");
        print_multi_index(up_lengths_);
    }
};
template <typename VectorSize, typename UpLength>
struct Vectorize
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(UpLength{}));

    UpLengths up_lengths_;
    VectorSize vector_size_;

    __host__ __device__ constexpr Vectorize() = default;

    __host__ __device__ constexpr Vectorize(const VectorSize& vector_size,
                                            const UpLength& up_length)
        : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}];
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
                                              const UpIdx&,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = vector_size_ * idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Vectorize, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("}");
    }
};
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct Slice
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;

    using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));

    UpLengths up_lengths_;
    SliceBegin slice_begin_;
    SliceEnd slice_end_;

    __host__ __device__ constexpr Slice() = default;

    __host__ __device__ constexpr Slice(const LowLength&,
                                        const SliceBegin& slice_begin,
                                        const SliceEnd& slice_end)
        : up_lengths_{make_tuple(slice_end - slice_begin)},
          slice_begin_{slice_begin},
          slice_end_{slice_end}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
                                                     const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        idx_diff_low(I0) = idx_diff_up[I0];

        idx_low += idx_diff_low;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return true; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value &&
               is_known_at_compile_time<SliceBegin>::value &&
               is_known_at_compile_time<SliceEnd>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Slice, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("slice_begin_ %d", index_t{slice_begin_});
        printf("slice_end %d", index_t{slice_end_});
        printf("}");
    }
};
/*
* \brief lower_idx = upper_idx % modulus.
* TODO: Need an improved implementation since the modulo operation is expensive.
*/
template <typename Modulus, typename UpLength>
struct Modulo
{
    using LowerIndex = MultiIndex<1>;
    using UpperIndex = MultiIndex<1>;
    using UpLengths  = decltype(make_tuple(UpLength{}));

    Modulus modulus_;
    UpLengths up_lengths_;

    __host__ __device__ constexpr Modulo() = default;

    __host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length)
        : modulus_{modulus}, up_lengths_{make_tuple(up_length)}
    {
    }

    __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; }

    __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; }

    __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; }

    template <typename LowIdx, typename UpIdx>
    __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low,
                                                           const UpIdx& idx_up) const
    {
        static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_;
    }

    template <typename LowIdxDiff,
              typename UpIdxDiff,
              typename LowIdx,
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
                                              const UpIdx& up_idx,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&
                          UpIdx::Size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto I0 = Number<0>{};

        const auto idx_low_old = idx_low;

        idx_low(I0)      = (up_idx(I0) + idx_diff_up(I0)) % modulus_;
        idx_diff_low(I0) = idx_low - idx_low_old;
    }

    __host__ __device__ static constexpr bool IsLinearTransform() { return false; }

    __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex()
    {
        return true;
    }

    template <typename UpIdx>
    __host__ __device__ static constexpr bool
    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */)
    {
        return true;
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        return is_known_at_compile_time<UpLengths>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("Modulus, ");
        printf("up_lengths_");
        print_multi_index(up_lengths_);
        printf("}");
    }
};

} // namespace ck
include/ck/tensor_description/multi_index_transform_helper.hpp
0 → 100644
View file @
78e355fd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform.hpp"
namespace ck {

template <typename LowLength>
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length)
{
    return PassThrough<LowLength>{low_length};
}

template <typename LowLength,
          typename LeftPad,
          typename RightPad,
          bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto
make_pad_transform(const LowLength& low_length,
                   const LeftPad& left_pad,
                   const RightPad& right_pad,
                   integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
    return Pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
}

template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto make_left_pad_transform(
    const LowLength& low_length,
    const LeftPadLength& left_pad,
    integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
    return LeftPad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad};
}

template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
__host__ __device__ constexpr auto make_right_pad_transform(
    const LowLength& low_length,
    const RightPadLength& right_pad,
    integral_constant<bool, SkipIsValidCheck> = integral_constant<bool, false>{})
{
    return RightPad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad};
}

template <typename UpLengths,
          typename Coefficients,
          typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
                                                        const Coefficients& coefficients)
{
    return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
}

template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
    return make_merge_transform_v2_magic_division(low_lengths);
#else
    return make_merge_transform_v1_carry_check(low_lengths);
#endif
}

template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
{
    return Merge_v1_carry_check<LowLengths>{low_lengths};
}

template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
#if 1
    return Merge_v2_magic_division<LowLengths>{low_lengths};
#else
    return Merge_v2r2_magic_division<LowLengths>{low_lengths};
#endif
}

template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
{
    return Merge_v3_division_mod<LowLengths>{low_lengths};
}

template <typename UpLengths, bool Use24BitIntegerCalculation = false>
__host__ __device__ constexpr auto make_unmerge_transform(
    const UpLengths& up_lengths,
    integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
{
    return UnMerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}

template <typename LowerIndex>
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{
    return Freeze<LowerIndex>{low_idx};
}

template <typename UpperIndex>
__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx)
{
    return Insert<UpperIndex>{up_idx};
}

template <typename LowLength, typename SliceBegin, typename SliceEnd>
__host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length,
                                                        const SliceBegin& slice_begin,
                                                        const SliceEnd& slice_end)
{
    return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
}

template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
                                                            const UpLength& up_length)
{
    return Vectorize<VectorSize, UpLength>{vector_size, up_length};
}

template <typename Modulus, typename UpLength>
__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
                                                         const UpLength& up_length)
{
    return Modulo<Modulus, UpLength>{modulus, up_length};
}

} // namespace ck
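// Illustrative sketch (not part of the original header): the maker functions above are normally
// combined with transform_tensor_descriptor (declared in tensor_descriptor.hpp). A minimal use,
// assuming compile-time lengths carried by Number<>:
#if 0
__host__ void make_transform_example()
{
    // merge two dimensions of lengths 4 and 8 into one dimension of length 32;
    // whether this dispatches to magic division or carry-check depends on
    // CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
    const auto merge = ck::make_merge_transform(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}));

    // split one dimension of length 32 back into (4, 8)
    const auto unmerge =
        ck::make_unmerge_transform(ck::make_tuple(ck::Number<4>{}, ck::Number<8>{}));
}
#endif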
include/ck/tensor_description/tensor_adaptor.hpp
0 → 100644
View file @
78e355fd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
namespace ck {

// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename BottomDimensionHiddenIds,
          typename TopDimensionHiddenIds>
struct TensorAdaptor
{
    __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }

    __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }

    __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss()
    {
        return LowerDimensionHiddenIdss{};
    }

    __host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss()
    {
        return UpperDimensionHiddenIdss{};
    }

    __host__ __device__ static constexpr auto GetTopDimensionHiddenIds()
    {
        return TopDimensionHiddenIds{};
    }

    __host__ __device__ static constexpr auto GetBottomDimensionHiddenIds()
    {
        return BottomDimensionHiddenIds{};
    }

    __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
    {
        const auto lengths = generate_tuple(
            [&](auto idim_top) {
                constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top);

                constexpr index_t itran   = tmp[Number<0>{}];
                constexpr index_t idim_up = tmp[Number<1>{}];
                constexpr bool found      = tmp[Number<2>{}];

                static_assert(found == true,
                              "wrong! not found matching transformation and upper-dimension");

                const auto length =
                    transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];

                return length;
            },
            Number<ndim_top_>{});

        // TODO: make container_reduce support tuple of Number and index_t
        return container_reduce(lengths, math::multiplies{}, Number<1>{});
    }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
    {
        constexpr auto idim_top = Number<IDim>{};

        constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);

        index_t itran_found   = 0;
        index_t idim_up_found = 0;
        bool found            = false;

        static_for<0, ntransform_, 1>{}([&](auto itran) {
            constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];

            static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
                if constexpr(up_dim_ids[idim_up] == idim_hidden)
                {
                    itran_found   = itran;
                    idim_up_found = idim_up;
                    found         = true;
                }
            });
        });

        return make_tuple(itran_found, idim_up_found, found);
    }

    __host__ __device__ static constexpr index_t GetNumOfBottomDimension()
    {
        return BottomDimensionHiddenIds::Size();
    }

    __host__ __device__ static constexpr index_t GetNumOfTopDimension()
    {
        return TopDimensionHiddenIds::Size();
    }

    __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            LowerDimensionHiddenIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            UpperDimensionHiddenIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

        using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
                                                                      math::less<index_t>,
                                                                      math::equal<index_t>>::type;

        return unique_sort_all_dim_ids::Size();
    }

    constexpr static index_t ntransform_  = GetNumOfTransform();
    constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension();
    constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension();
    constexpr static index_t ndim_top_    = GetNumOfTopDimension();

    using HiddenIndex = MultiIndex<ndim_hidden_>;
    using BottomIndex = MultiIndex<ndim_bottom_>;
    using TopIndex    = MultiIndex<ndim_top_>;

    // may be index_t or Number<>
    using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;

    public:
#if 0 // workaround compiler complaint about constexpr
    __host__ __device__ constexpr TensorAdaptor() = default;
#else
    __host__ __device__ constexpr TensorAdaptor() : transforms_{}, element_size_{} {}
#endif

    __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms)
        : transforms_{transforms}, element_size_{InitializeElementSize(transforms)}
    {
        static_assert(Transforms::Size() == ntransform_ &&
                          LowerDimensionHiddenIdss::Size() == ntransform_ &&
                          UpperDimensionHiddenIdss::Size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    __host__ __device__ constexpr auto GetElementSize() const { return element_size_; }

#if 0 // debug
    template <index_t I>
    __host__ __device__ constexpr index_t GetTopDimensionLength(Number<I> idim) const
    {
        // TODO: not implemented
    }

    template <index_t I>
    __host__ __device__ constexpr index_t GetBottomDimensionLength(Number<I> idim) const
    {
        // TODO: not implemented
    }
#endif

    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(),
                      "wrong! # of dimension inconsistent");

        constexpr index_t ntransform  = GetNumOfTransform();
        constexpr index_t ndim_hidden = GetNumOfHiddenDimension();

        MultiIndex<ndim_hidden> idx_hidden;

        // initialize the top-most index
        set_container_subset(idx_hidden, GetTopDimensionHiddenIds(), idx_top);

        // calculate hidden index
        static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
            auto itran              = itran_p1 - Number<1>{};
            const auto& tran        = GetTransforms().At(itran);
            constexpr auto dims_low = GetLowerDimensionHiddenIdss().At(itran);
            constexpr auto dims_up  = GetUpperDimensionHiddenIdss().At(itran);

            const auto idx_up = get_container_subset(idx_hidden, dims_up);

            MultiIndex<dims_low.Size()> idx_low;

            tran.CalculateLowerIndex(idx_low, idx_up);

            set_container_subset(idx_hidden, dims_low, idx_low);
        });

        return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        bool is_known = true;

        static_for<0, Transforms::Size(), 1>{}([&](auto i) {
            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
        });

        return is_known && is_known_at_compile_time<ElementSize>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("TensorAdaptor, ");
        static_for<0, ntransform_, 1>{}([&](auto i) {
            printf("transforms: ");
            transforms_[i].Print();
            printf("LowerDimensionHiddenIds:");
            LowerDimensionHiddenIdss{}.At(i).Print();
            printf("UpperDimensionHiddenIds:");
            UpperDimensionHiddenIdss{}.At(i).Print();
        });

        printf("BottomDimensionHiddenIds:");
        BottomDimensionHiddenIds::Print();
        printf("TopDimensionHiddenIds:");
        TopDimensionHiddenIds::Print();

        printf("}");
    }

    private:
    Transforms transforms_;
    ElementSize element_size_;
};
template <typename TensorAdaptor0, typename TensorAdaptor1>
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
                                                         const TensorAdaptor1& adaptor1)
{
    static_assert(TensorAdaptor0::GetNumOfTopDimension() ==
                      TensorAdaptor1::GetNumOfBottomDimension(),
                  "wrong!");

    // all_transforms = transform0 + transform1
    const auto all_transforms =
        container_concat(adaptor0.GetTransforms(), adaptor1.GetTransforms());

    // shift
    constexpr index_t adaptor0_max_hidden_id = [&]() {
        index_t adaptor0_max_hidden_id_ = NumericLimits<index_t>::Min();

        static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension();

            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                adaptor0_max_hidden_id_ =
                    math::max(adaptor0_max_hidden_id_,
                              TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
            });

            constexpr index_t ndim_up =
                TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension();

            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor0_max_hidden_id_ =
                    math::max(adaptor0_max_hidden_id_,
                              TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
            });
        });

        return adaptor0_max_hidden_id_;
    }();

    constexpr index_t adaptor1_min_hidden_id = [&]() {
        index_t adaptor1_min_hidden_id_ = NumericLimits<index_t>::Max();

        static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension();

            // get the min of all lower dimensions, but not bottom dimensions (because their id
            // will be matched with the top id from adaptor0)
            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                constexpr index_t low_dim_hidden_id =
                    TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value;

                bool is_bottom_dim = false;
                static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) {
                    if constexpr(low_dim_hidden_id ==
                                 TensorAdaptor1::GetBottomDimensionHiddenIds()[i])
                    {
                        is_bottom_dim = true;
                    }
                });

                if(!is_bottom_dim)
                {
                    adaptor1_min_hidden_id_ =
                        math::min(adaptor1_min_hidden_id_, low_dim_hidden_id);
                }
            });

            constexpr index_t ndim_up =
                TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension();

            // get the min of all upper dimensions
            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor1_min_hidden_id_ =
                    math::min(adaptor1_min_hidden_id_,
                              TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
            });
        });

        return adaptor1_min_hidden_id_;
    }();

    constexpr index_t adaptor1_hidden_id_shift =
        adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;

    constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension();

    // all_low_dim_hidden_idss =
    // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hidden_idss_1))
    constexpr auto low_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size();

            constexpr auto low_dim_hidden_ids_1 =
                TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran];

            // sequence in, sequence out
            constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
                auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);

                // shift hidden id so every dim id is unique
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
                });

                // match hidden id
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
                        // if this low dim is a bottom dim, then do id matching
                        if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
                                     TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1])
                        {
                            low_dim_hidden_ids_1_mod_(idim_low_1) =
                                TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1];
                        }
                    });
                });

                return low_dim_hidden_ids_1_mod_;
            }();

            return generate_sequence_v2(
                [&](auto i) constexpr { return Number<low_dim_hidden_ids_1_mod[i]>{}; },
                Number<ndim_low_1>{});
        },
        Number<TensorAdaptor1::GetNumOfTransform()>{});

    constexpr auto all_low_dim_hidden_idss =
        container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1);

    // all_up_dim_hidden_idss =
    // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hidden_idss_1)
    constexpr auto up_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_up_1 = TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size();

            constexpr auto up_dim_hidden_ids_1 =
                TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran];

            // sequence in, constexpr tuple out
            constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
                auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);

                // shift hidden id
                static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
                    up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
                });

                return up_dim_hidden_ids_1_mod_;
            }();

            // constexpr tuple to sequence
            return generate_sequence_v2(
                [&](auto i) constexpr { return Number<up_dim_hidden_ids_1_mod[i]>{}; },
                Number<ndim_up_1>{});
        },
        Number<TensorAdaptor1::GetNumOfTransform()>{});

    constexpr auto all_up_dim_hidden_idss =
        container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), up_dim_hidden_idss_1);

    // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
    constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds();

    // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
    constexpr auto top_dim_hidden_ids =
        TensorAdaptor1::GetTopDimensionHiddenIds() + Number<adaptor1_hidden_id_shift>{};

    // put everything together
    return TensorAdaptor<remove_cv_t<decltype(all_transforms)>,
                         remove_cv_t<decltype(all_low_dim_hidden_idss)>,
                         remove_cv_t<decltype(all_up_dim_hidden_idss)>,
                         remove_cv_t<decltype(bottom_dim_hidden_ids)>,
                         remove_cv_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms,
          typename LowerDimensionOldTopIdss,
          typename UpperDimensionNewTopIdss>
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
                                                                    LowerDimensionOldTopIdss,
                                                                    UpperDimensionNewTopIdss)
{
    constexpr index_t ntransform = Transforms::Size();

    static_assert(LowerDimensionOldTopIdss::Size() == ntransform &&
                      UpperDimensionNewTopIdss::Size() == ntransform,
                  "wrong!");

    // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
    constexpr auto all_low_dim_old_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});

    constexpr auto all_up_dim_new_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});

    static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
                      is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
                  "wrong!");

    constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size();
    constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size();

    // low_dim_hidden_idss
    constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};

    // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
    constexpr auto up_dim_hidden_idss = generate_tuple(
        [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number<ndim_old_top>{}; },
        Number<ntransform>{});

    // bottom_dim_hidden_ids
    constexpr auto bottom_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};

    // top_dim_hidden_ids
    constexpr auto top_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number<ndim_old_top>{};

    return TensorAdaptor<remove_cv_t<Transforms>,
                         remove_cv_t<decltype(low_dim_hidden_idss)>,
                         remove_cv_t<decltype(up_dim_hidden_idss)>,
                         remove_cv_t<decltype(bottom_dim_hidden_ids)>,
                         remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
}
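// Illustrative sketch (not part of the original header): a single-stage adaptor that maps a
// (4, 8) top index down to one bottom offset. In the call, make_tuple(Sequence<0>{}) lists the old
// top (bottom) dim each transform consumes and make_tuple(Sequence<0, 1>{}) the new top dims it
// produces. make_unmerge_transform and make_multi_index are assumed available via the headers
// included above.
#if 0
__host__ void single_stage_adaptor_example()
{
    const auto adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_unmerge_transform(make_tuple(Number<4>{}, Number<8>{}))),
        make_tuple(Sequence<0>{}),
        make_tuple(Sequence<0, 1>{}));

    // top index (3, 5) -> bottom index 3 * 8 + 5 = 29
    const auto idx_bottom = adaptor.CalculateBottomIndex(make_multi_index(3, 5));
}
#endif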
template <typename X,
          typename... Xs,
          typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
    return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}

} // namespace ck
include/ck/tensor_description/tensor_descriptor.hpp
0 → 100644
View file @
78e355fd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/tensor_description/multi_index_transform.hpp"
namespace ck {

template <index_t NDimHidden, typename VisibleDimensionIds>
struct TensorCoordinate;

template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct TensorCoordinateStep;

// Transforms: Tuple<transforms...>
// LowerDimensionIdss : Tuple<Sequence<...>, ...>
// UpperDimensionIdss : Tuple<Sequence<...>, ...>
// VisibleDimensionIds : Sequence<...>
template <typename Transforms,
          typename LowerDimensionIdss,
          typename UpperDimensionIdss,
          typename VisibleDimensionIds,
          typename ElementSpaceSize>
struct TensorDescriptor
{
    // TODO make these private
    __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); }

    __host__ __device__ static constexpr index_t GetNumOfVisibleDimension()
    {
        return VisibleDimensionIds::Size();
    }

    __host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

        using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
                                                                      math::less<index_t>,
                                                                      math::equal<index_t>>::type;

        return unique_sort_all_dim_ids::Size();
    }

    __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
    {
        const auto lengths = generate_tuple(
            [&](auto idim_visible) {
                constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible);

                constexpr index_t itran   = tmp[Number<0>{}];
                constexpr index_t idim_up = tmp[Number<1>{}];
                constexpr bool found      = tmp[Number<2>{}];

                static_assert(found == true,
                              "wrong! not found matching transformation and upper-dimension");

                const auto length =
                    transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];

                return length;
            },
            Number<ndim_visible_>{});

        // TODO: make container_reduce support tuple of Number and index_t
        return container_reduce(lengths, math::multiplies{}, Number<1>{});
    }

    template <index_t IDim>
    __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number<IDim>)
    {
        constexpr auto idim_visible = Number<IDim>{};

        constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible);

        index_t itran_found   = 0;
        index_t idim_up_found = 0;
        bool found            = false;

        static_for<0, ntransform_, 1>{}([&](auto itran) {
            constexpr auto up_dim_ids = UpperDimensionIdss{}[itran];

            static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
                if constexpr(up_dim_ids[idim_up] == idim_hidden)
                {
                    itran_found   = itran;
                    idim_up_found = idim_up;
                    found         = true;
                }
            });
        });

        return make_tuple(itran_found, idim_up_found, found);
    }

    constexpr static index_t ntransform_   = GetNumOfTransform();
    constexpr static index_t ndim_visible_ = GetNumOfVisibleDimension();
    constexpr static index_t ndim_hidden_  = GetNumOfHiddenDimension();

    using VisibleIndex = MultiIndex<ndim_visible_>;
    using HiddenIndex  = MultiIndex<ndim_hidden_>;
    using Coordinate   = TensorCoordinate<ndim_hidden_, VisibleDimensionIds>;

    // may be index_t or Number<>
    using ElementSize = remove_cv_t<decltype(InitializeElementSize(Transforms{}))>;

    public:
#if 0 // workaround compiler complaint about constexpr
    __host__ __device__ constexpr TensorDescriptor() = default;
#else
    __host__ __device__ constexpr TensorDescriptor()
        : transforms_{}, element_size_{}, element_space_size_{}
    {
    }
#endif

    __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms,
                                                   ElementSpaceSize element_space_size)
        : transforms_{transforms},
          element_size_{InitializeElementSize(transforms)},
          element_space_size_{element_space_size}
    {
        static_assert(Transforms::Size() == ntransform_ &&
                          LowerDimensionIdss::Size() == ntransform_ &&
                          UpperDimensionIdss::Size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    __host__ __device__ static constexpr index_t GetNumOfDimension()
    {
        return GetNumOfVisibleDimension();
    }

    template <index_t IDim>
    __host__ __device__ constexpr auto GetLength(Number<IDim>) const
    {
        static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range");

        constexpr auto tmp = GetTransformAndItsUpperDimension(Number<IDim>{});

        constexpr index_t itran   = tmp[Number<0>{}];
        constexpr index_t idim_up = tmp[Number<1>{}];
        constexpr bool found      = tmp[Number<2>{}];

        static_assert(found == true,
                      "wrong! not found matching transformation and upper-dimension");

        return transforms_[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
    }

    __host__ __device__ constexpr auto GetLengths() const
    {
        // FIXME: use Tuple of reference instead
        return generate_sequence_v2([&](auto I) { return GetLength(I); }, Number<ndim_visible_>{});
    }

    __host__ __device__ constexpr auto GetElementSize() const { return element_size_; }

    __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }

    template <typename Idx>
    __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const
    {
        static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension");

        return make_tensor_coordinate(*this, idx).GetOffset();
    }

    // TODO make these private
    __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; }

    __host__ __device__ static constexpr auto GetLowerDimensionIdss()
    {
        return LowerDimensionIdss{};
    }

    __host__ __device__ static constexpr auto GetUpperDimensionIdss()
    {
        return UpperDimensionIdss{};
    }

    __host__ __device__ static constexpr auto GetVisibleDimensionIds()
    {
        return VisibleDimensionIds{};
    }

    __host__ __device__ static constexpr bool IsKnownAtCompileTime()
    {
        bool is_known = true;

        static_for<0, Transforms::Size(), 1>{}([&](auto i) {
            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::IsKnownAtCompileTime();
        });

        return is_known && is_known_at_compile_time<ElementSize>::value &&
               is_known_at_compile_time<ElementSpaceSize>::value;
    }

    __host__ __device__ void Print() const
    {
        printf("{");
        printf("TensorDescriptor, ");
        static_for<0, ntransform_, 1>{}([&](auto i) {
            printf("transforms: ");
            transforms_[i].Print();
            printf("LowerDimensionIds:");
            LowerDimensionIdss{}.At(i).Print();
            printf("UpperDimensionIds:");
            UpperDimensionIdss{}.At(i).Print();
        });
        printf("}");

        VisibleDimensionIds::Print();
    }

    // TODO make these private
    Transforms transforms_;
    ElementSize element_size_;
    ElementSpaceSize element_space_size_;
};
template <index_t NDimHidden, typename VisibleDimensionIds>
struct TensorCoordinate
{
    // TODO make these private
    static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size();

    using HiddenIndex  = MultiIndex<NDimHidden>;
    using VisibleIndex = MultiIndex<ndim_visible_>;

    public:
    __host__ __device__ constexpr TensorCoordinate() = default;

    __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden)
        : idx_hidden_{idx_hidden}
    {
    }

    __host__ __device__ constexpr auto GetIndex() const { return GetVisibleIndex(); }

    __host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; }

    // TODO make these private
    __host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; }

    __host__ __device__ auto& GetHiddenIndex() { return idx_hidden_; }

    __host__ __device__ constexpr auto GetVisibleIndex() const
    {
        return get_container_subset(idx_hidden_, VisibleDimensionIds{});
    }

    // TODO make these private
    HiddenIndex idx_hidden_;
};
template <index_t NTransform, index_t NDimVisible, typename UpdateLowerIndexHack>
struct TensorCoordinateStep
{
    // TODO make these private
    using VisibleIndex = MultiIndex<NDimVisible>;

    public:
    __host__ __device__ constexpr TensorCoordinateStep() = default;

    __host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible,
                                                       const MultiIndex<NTransform>& do_transforms)
        : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms}
    {
    }

    __host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); }

    // TODO make these private
    __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const
    {
        return idx_diff_visible_;
    }

    VisibleIndex idx_diff_visible_;
    MultiIndex<NTransform> do_transforms_;

    // HACK: control UpdateLowerIndex()
    static constexpr UpdateLowerIndexHack update_lower_index_hack_;
};
// TODO: How to fix this? It uses a struct instead of a lambda because a lambda does not have a
// constructor, and it is put outside the scope where it is used (transform_tensor_descriptor)
// because a template cannot be defined inside a function template.
template <typename NewTransforms>
struct lambda_get_up_dim_num
{
    template <typename I>
    __host__ __device__ constexpr auto operator()(I) const
    {
        using Tran = remove_reference_t<decltype(NewTransforms{}.At(I{}))>;
        return Number<Tran::GetNumOfUpperDimension()>{};
    }
};
template <typename OldTensorDescriptor,
          typename NewTransforms,
          typename NewLowerDimensionOldVisibleIdss,
          typename NewUpperDimensionNewVisibleIdss>
__host__ __device__ constexpr auto
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
                            const NewTransforms& new_transforms,
                            NewLowerDimensionOldVisibleIdss,
                            NewUpperDimensionNewVisibleIdss)
{
    // sanity check
    {
        static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() &&
                          NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
                      "wrong! inconsitent number of transform");

        constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewLowerDimensionOldVisibleIdss{});

        constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
                                                NewUpperDimensionNewVisibleIdss{});

        static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
                          is_valid_sequence_map<decltype(all_new_top_ids)>::value,
                      "wrong!");
    }

    // lower dimension's hidden idss
    // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
    // sequences)
    constexpr auto low_dim_hidden_idss = transform_tuples(
        // convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
        [](auto low_dim_visible_ids) constexpr {
            return transform_sequences(
                // convert lower dimension visible id to hidden id
                [](auto low_dim_visible_id) constexpr {
                    return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
                },
                low_dim_visible_ids);
        },
        NewLowerDimensionOldVisibleIdss{});

    constexpr index_t num_new_transform = NewTransforms::Size();

    // upper dimension's hidden idss
    constexpr index_t old_hidden_dim_number = OldTensorDescriptor::GetNumOfHiddenDimension();

    constexpr auto up_dim_numbers =
        generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, Number<num_new_transform>{});

    constexpr auto up_dim_numbers_scan = merge_sequences(
        Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));

    constexpr auto up_dim_hidden_idss = generate_tuple(
        [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
            return
                typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
                                                 old_hidden_dim_number + up_dim_numbers_scan[i + 1],
                                                 1>::type{};
        },
        Number<num_new_transform>{});

    // new visible dimension's hidden ids
    constexpr auto unordered_new_visible_dim_hidden_ids = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);

    constexpr auto new_visible_dim_unordered2ordered = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); },
        NewUpperDimensionNewVisibleIdss{});

    constexpr auto new_visible_dim_hidden_ids =
        unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);

    // put everything together
    const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms);

    constexpr auto all_low_dim_hidden_idss =
        container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss);

    constexpr auto all_up_dim_hidden_idss =
        container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss);

    const auto element_space_size = old_tensor_desc.GetElementSpaceSize();

    return TensorDescriptor<remove_cv_t<decltype(all_transforms)>,
                            remove_cv_t<decltype(all_low_dim_hidden_idss)>,
                            remove_cv_t<decltype(all_up_dim_hidden_idss)>,
                            remove_cv_t<decltype(new_visible_dim_hidden_ids)>,
                            remove_cv_t<decltype(element_space_size)>>{all_transforms,
                                                                       element_space_size};
}
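// Illustrative sketch (not part of the original header): flattening a 2D descriptor into a 1D
// view. make_naive_tensor_descriptor_packed, make_merge_transform and make_multi_index are assumed
// to be available from tensor_descriptor_helper.hpp / multi_index_transform_helper.hpp and the
// common utilities.
#if 0
__host__ void transform_descriptor_example()
{
    const auto desc_4x8 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));

    // merge old visible dims {0, 1} into new visible dim {0}
    const auto desc_32 = transform_tensor_descriptor(
        desc_4x8,
        make_tuple(make_merge_transform(make_tuple(Number<4>{}, Number<8>{}))),
        make_tuple(Sequence<0, 1>{}),
        make_tuple(Sequence<0>{}));

    // the offset of element (3, 5) in the original view equals the offset of 3 * 8 + 5 = 29 here
    const index_t offset = desc_32.CalculateOffset(make_multi_index(29));
}
#endif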
template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
                                                          const VisibleIndex& idx_visible)
{
    static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
                  "wrong! # of dimension inconsistent");

    constexpr index_t ntransform   = TensorDesc::GetNumOfTransform();
    constexpr index_t ndim_hidden  = TensorDesc::GetNumOfHiddenDimension();
    constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();

    MultiIndex<ndim_hidden> idx_hidden;

    // initialize visible index
    set_container_subset(idx_hidden, visible_dim_ids, idx_visible);

    // calculate hidden index
    static_for<ntransform, 0, -1>{}([&tensor_desc, &idx_hidden](auto itran_p1) {
        auto itran              = itran_p1 - Number<1>{};
        const auto& tran        = tensor_desc.GetTransforms().At(itran);
        constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
        constexpr auto dims_up  = TensorDesc::GetUpperDimensionIdss().At(itran);

        const auto idx_up = get_container_subset(idx_hidden, dims_up);

        MultiIndex<dims_low.Size()> idx_low;

        tran.CalculateLowerIndex(idx_low, idx_up);

        set_container_subset(idx_hidden, dims_low, idx_low);
    });

    return TensorCoordinate<ndim_hidden, decltype(visible_dim_ids)>{idx_hidden};
}
// UpdateLowerIndexHack: Sequence<...>
// HACK: control UpdateLowerIndex
template <typename TensorDesc, typename VisibleIndex, typename UpdateLowerIndexHack>
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
                                                               const VisibleIndex& idx_diff_visible,
                                                               UpdateLowerIndexHack)
{
    static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(),
                  "wrong! # of dimension inconsistent");

    constexpr index_t ntransform   = TensorDesc::GetNumOfTransform();
    constexpr index_t ndim_hidden  = TensorDesc::GetNumOfHiddenDimension();
    constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
    constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();

    static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!");

    // use index_t for boolean type
    auto do_transforms    = make_zero_multi_index<ntransform>();
    auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();

    // decide do_transform by checking the non-zero index diff components
    MultiIndex<VisibleIndex::Size()> non_zero_diff_pick_visible;

    static_for<0, ndim_visible, 1>{}(
        [&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); });

    set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible);

    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
        constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
        constexpr auto dims_up  = TensorDesc::GetUpperDimensionIdss().At(itran);

        const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);

        MultiIndex<dims_low.Size()> non_zero_diff_pick_low;

        // if any of the upper index diff components is non-zero, then
        // 1) this transform needs to be applied
        // 2) all components of the lower index diff are assumed to be non-zero and need to be
        //    computed
        const bool idx_diff_up_has_non_zero = container_reduce(
            non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);

        do_transforms(itran) = idx_diff_up_has_non_zero;

        static_for<0, dims_low.Size(), 1>{}(
            [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });

        set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
    });

    return TensorCoordinateStep<ntransform, ndim_visible, UpdateLowerIndexHack>{idx_diff_visible,
                                                                                do_transforms};
}

template <typename TensorDesc, typename VisibleIndex>
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&,
                                                               const VisibleIndex& idx_diff_visible)
{
    constexpr index_t ntransform = TensorDesc::GetNumOfTransform();

    return make_tensor_coordinate_step(
        TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen<ntransform, 0>::type{});
}
template <typename TensorDesc, typename TensorCoord, typename TensorCoordStep>
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc,
                                                          TensorCoord& coord,
                                                          const TensorCoordStep& coord_step)
{
    constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
    constexpr index_t ntransform  = TensorDesc::GetNumOfTransform();

    // this is what needs to be calculated
    auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();

    // initialize visible index diff
    set_container_subset(
        idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff());

    // this is what needs to be updated
    auto& idx_hidden = coord.GetHiddenIndex();

    // update visible index
    auto idx_hidden_pick_visible =
        get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds());

    idx_hidden_pick_visible += coord_step.GetIndexDiff();

    set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible);

    // update rest of hidden index
    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
        if(coord_step.do_transforms_[itran])
        {
            const auto& tran        = tensor_desc.GetTransforms().At(itran);
            constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran);
            constexpr auto dims_up  = TensorDesc::GetUpperDimensionIdss().At(itran);

            const auto idx_up_new  = get_container_subset(idx_hidden, dims_up);
            auto idx_low           = get_container_subset(idx_hidden, dims_low);
            const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);

            MultiIndex<dims_low.Size()> idx_diff_low;

            // HACK: control UpdateLowerIndex for Merge using hack
            constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran);

            tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number<Hack>{});

            set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
            set_container_subset(idx_hidden, dims_low, idx_low);
        }
    });
}
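// Illustrative sketch (not part of the original header): the usual access pattern is to build a
// coordinate once and then repeatedly move it by a fixed step instead of recomputing the full
// index-to-offset mapping each time. make_naive_tensor_descriptor_packed and make_multi_index are
// assumed to be available from tensor_descriptor_helper.hpp and the common utilities.
#if 0
__host__ void coordinate_walk_example()
{
    const auto desc = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));

    auto coord            = make_tensor_coordinate(desc, make_multi_index(0, 0));
    const auto coord_step = make_tensor_coordinate_step(desc, make_multi_index(0, 1));

    for(index_t i = 0; i < 8; ++i)
    {
        // coord.GetOffset() advances 0, 1, 2, ... along the fastest-moving dimension
        move_tensor_coordinate(desc, coord, coord_step);
    }
}
#endif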
template
<
typename
TensorDesc
,
typename
TensorCoord
>
__host__
__device__
constexpr
bool
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
const
TensorDesc
&
tensor_desc
,
const
TensorCoord
&
coord
)
{
bool
valid
=
true
;
constexpr
index_t
ntransform
=
TensorDesc
::
GetNumOfTransform
();
const
auto
&
idx_hidden
=
coord
.
GetHiddenIndex
();
static_for
<
ntransform
-
1
,
-
1
,
-
1
>
{}([
&
tensor_desc
,
&
idx_hidden
,
&
valid
](
auto
itran
)
{
const
auto
tran
=
tensor_desc
.
GetTransforms
().
At
(
itran
);
// check validity, only if current transformation does not always has a valid mapping
if
constexpr
(
!
decltype
(
tran
)
::
IsValidUpperIndexAlwaysMappedToValidLowerIndex
())
{
const
auto
idx_up
=
get_container_subset
(
idx_hidden
,
TensorDesc
::
GetUpperDimensionIdss
().
At
(
itran
));
// Comment: using valid = valid && .. will result in weird control flow in ISA
valid
&=
tran
.
IsValidUpperIndexMappedToValidLowerIndex
(
idx_up
);
}
});
return
valid
;
}
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
                                                               const TensorCoord& coord)
{
    // check visible index
    const auto& idx_visible = coord.GetVisibleIndex();

    bool is_visible_index_valid = true;

    static_for<0, TensorDesc::GetNumOfDimension(), 1>{}(
        [&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) {
            is_visible_index_valid =
                is_visible_index_valid &&
                (idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i));
        });

    // check other hidden index
    return is_visible_index_valid &&
           coordinate_has_valid_offset_assuming_visible_index_is_valid(tensor_desc, coord);
}
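// Example (illustrative sketch): the cheaper "assuming visible index is valid" variant above is
// meant for callers that have already bounds-checked the visible index; otherwise the full check
// is the one to use before dereferencing an offset.
#if 0
const bool is_valid      = coordinate_has_valid_offset(desc, coord); // full check
const bool is_valid_fast =
    coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);
#endif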
template <typename TensorDesc>
using TensorCoordinate_t = decltype(make_tensor_coordinate(
    TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));

template <typename TensorDesc>
using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step(
    TensorDesc{}, MultiIndex<remove_cvref_t<TensorDesc>::GetNumOfDimension()>{}));

} // namespace ck
include/ck/tensor_description/tensor_descriptor_helper.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"

namespace ck {

/*
 * These functions create tensor descriptors at runtime. If they are not constexpr, you will
 * likely see usage of scratch memory during construction of these tensor descriptors. So
 * it's better to call these functions on the host and then pass the constructed tensor
 * descriptors to the GPU. If the tensor descriptors being constructed are constexpr, then
 * you can call these functions on the GPU without worrying about scratch memory usage.
 */

#if CK_WORKAROUND_SWDEV_275126
template <typename Lengths, typename Strides, index_t I, typename AccOld>
__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
                                                                     const Strides& strides,
                                                                     Number<I> i,
                                                                     AccOld acc_old)
{
    auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];

    if constexpr(i.value < Lengths::Size() - 1)
    {
        return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new);
    }
    else
    {
        return acc_new;
    }
}
#endif
// Lengths..., Strides... could be:
//   1) index_t, which is known at run-time, or
//   2) Number<>, which is known at compile-time
// element_space_size could be:
//   1) long_index_t, or
//   2) LongNumber<>
template <typename... Lengths,
          typename... Strides,
          typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple<Lengths...>& lengths,
                                                                const Tuple<Strides...>& strides)
{
    constexpr index_t N = sizeof...(Lengths);

    const auto transforms = make_tuple(make_embed_transform(lengths, strides));

    constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

#if !CK_WORKAROUND_SWDEV_275126
    // the rocm-4.1 compiler would crash on a recursive lambda; see the workaround above
    // recursive function for reduction
    auto f = [&](auto fs, auto i, auto acc_old) {
        auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i];

        if constexpr(i.value < N - 1)
        {
            return fs(fs, i + Number<1>{}, acc_new);
        }
        else
        {
            return acc_new;
        }
    };

    const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{});
#else
    const auto element_space_size =
        calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{});
#endif

    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                            remove_cv_t<decltype(low_dim_hidden_idss)>,
                            remove_cv_t<decltype(up_dim_hidden_idss)>,
                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
                            remove_cv_t<decltype(element_space_size)>>{transforms,
                                                                       element_space_size};
}
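// Example (illustrative sketch): a 2-D row-major M x N view with a leading stride of N. With
// run-time index_t arguments the descriptor is built at run time (preferably on the host, per
// the note above); with Number<> arguments everything folds at compile time. The accumulated
// element_space_size here is 1 + (M - 1) * N + (N - 1) * 1 == M * N.
#if 0
const index_t M = 128, N = 64;
const auto desc_runtime =
    make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, index_t{1}));

constexpr auto desc_static = make_naive_tensor_descriptor(
    make_tuple(Number<128>{}, Number<64>{}), make_tuple(Number<64>{}, Number<1>{}));
#endif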
// Lengths... could be:
//   1) index_t, which is known at run-time, or
//   2) Number<>, which is known at compile-time
// element_space_size could be:
//   1) long_index_t, or
//   2) LongNumber<>
template <typename... Lengths>
__host__ __device__ constexpr auto
make_naive_tensor_descriptor_packed(const Tuple<Lengths...>& lengths)
{
    constexpr index_t N = sizeof...(Lengths);

    const auto transforms = make_tuple(make_unmerge_transform(lengths));

    constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

    const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{});

    return TensorDescriptor<remove_cv_t<decltype(transforms)>,
                            remove_cv_t<decltype(low_dim_hidden_idss)>,
                            remove_cv_t<decltype(up_dim_hidden_idss)>,
                            remove_cv_t<decltype(visible_dim_hidden_ids)>,
                            remove_cv_t<decltype(element_space_size)>>{transforms,
                                                                       element_space_size};
}
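// Example (illustrative sketch): "packed" means the strides are implied by the lengths, i.e.
// for lengths (L0, L1, L2) the unmerge transform yields strides (L1 * L2, L2, 1), and
// element_space_size is simply the product of the lengths.
#if 0
constexpr auto desc_packed = make_naive_tensor_descriptor_packed(
    make_tuple(Number<4>{}, Number<8>{}, Number<2>{})); // 64 contiguous elements
#endif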
// Lengths... could be:
//   1) index_t, which is known at run-time, or
//   2) Number<>, which is known at compile-time
// align could be:
//   1) index_t, or
//   2) Number<>
template <typename... Lengths, typename Align>
__host__ __device__ constexpr auto
make_naive_tensor_descriptor_aligned(const Tuple<Lengths...>& lengths, Align align)
{
    constexpr auto I1 = Number<1>{};

    constexpr index_t N = sizeof...(Lengths);

    const auto stride_n_minus_2 = math::integer_least_multiple(lengths[Number<N - 1>{}], align);

    auto strides = generate_tuple(
        [&](auto i) {
            if constexpr(i.value == N - 1)
            {
                return I1;
            }
            else if constexpr(i.value == N - 2)
            {
                return Number<stride_n_minus_2>{};
            }
            else
            {
                return container_reduce(lengths,
                                        math::multiplies{},
                                        Number<stride_n_minus_2>{},
                                        i + I1,
                                        Number<N - 1>{},
                                        I1);
            }
        },
        Number<N>{});

    return make_naive_tensor_descriptor(lengths, strides);
}
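// Example (illustrative sketch): only the stride of dimension N-2 is rounded up to the requested
// alignment; the innermost dimension stays contiguous. For lengths (4, 8, 30) and align = 32,
// the innermost length 30 is rounded up to 32, giving strides (8 * 32, 32, 1) = (256, 32, 1).
#if 0
constexpr auto desc_aligned = make_naive_tensor_descriptor_aligned(
    make_tuple(Number<4>{}, Number<8>{}, Number<30>{}), Number<32>{});
#endif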
} // namespace ck
include/ck/tensor_description/tensor_space_filling_curve.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/statically_indexed_array_multi_index.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

namespace ck {

template <typename TensorLengths,
          typename DimAccessOrder,
          typename ScalarsPerAccess, // # of scalars per access in each dimension
          bool SnakeCurved = true>
struct SpaceFillingCurve
{
    static constexpr index_t nDim = TensorLengths::Size();

    using Index = MultiIndex<nDim>;

    static constexpr index_t ScalarPerVector =
        reduce_on_sequence(ScalarsPerAccess{}, math::multiplies{}, Number<1>{});

    static constexpr auto access_lengths   = TensorLengths{} / ScalarsPerAccess{};
    static constexpr auto dim_access_order = DimAccessOrder{};
    static constexpr auto ordered_access_lengths =
        container_reorder_given_new2old(access_lengths, dim_access_order);

    static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
        make_tuple(make_merge_transform(ordered_access_lengths)),
        make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
        make_tuple(Sequence<0>{}));

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    __host__ __device__ static constexpr index_t GetNumOfAccess()
    {
        static_assert(TensorLengths::Size() == ScalarsPerAccess::Size());
        static_assert(TensorLengths{} % ScalarsPerAccess{} ==
                      typename uniform_sequence_gen<TensorLengths::Size(), 0>::type{});

        return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) /
               ScalarPerVector;
    }
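    // Worked example (illustrative): for TensorLengths = Sequence<4, 8> and
    // ScalarsPerAccess = Sequence<1, 4>, ScalarPerVector = 4, access_lengths = {4, 2},
    // and GetNumOfAccess() = (4 * 8) / 4 = 8 vectorized accesses in total.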
    template <index_t AccessIdx1dBegin, index_t AccessIdx1dEnd>
    static __device__ __host__ constexpr auto GetStepBetween(Number<AccessIdx1dBegin>,
                                                             Number<AccessIdx1dEnd>)
    {
        static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative");
        static_assert(AccessIdx1dBegin < GetNumOfAccess(),
                      "1D index should be less than the number of accesses");
        static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative");
        static_assert(AccessIdx1dEnd < GetNumOfAccess(),
                      "1D index should be less than the number of accesses");

        constexpr auto idx_begin = GetIndex(Number<AccessIdx1dBegin>{});
        constexpr auto idx_end   = GetIndex(Number<AccessIdx1dEnd>{});
        return idx_end - idx_begin;
    }

    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr auto GetForwardStep(Number<AccessIdx1d>)
    {
        static_assert(AccessIdx1d < GetNumOfAccess(),
                      "1D index should be less than the number of accesses");
        return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d + 1>{});
    }

    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr auto GetBackwardStep(Number<AccessIdx1d>)
    {
        static_assert(AccessIdx1d > 0, "1D index should be larger than 0");

        return GetStepBetween(Number<AccessIdx1d>{}, Number<AccessIdx1d - 1>{});
    }

    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr Index GetIndex(Number<AccessIdx1d>)
    {
#if 0
        /*
         * \todo: TensorAdaptor::CalculateBottomIndex does NOT return constexpr as expected.
         */
        constexpr auto ordered_access_idx =
            to_index_adaptor.CalculateBottomIndex(make_multi_index(Number<AccessIdx1d>{}));
#else
        constexpr auto access_strides = container_reverse_exclusive_scan(
            ordered_access_lengths, math::multiplies{}, Number<1>{});

        constexpr auto idx_1d = Number<AccessIdx1d>{};

        // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute
        // the idim-th element of multidimensional index.
        // All constexpr variables have to be captured by VALUE.
        constexpr auto compute_index = [idx_1d, access_strides](auto idim) constexpr {
            constexpr auto compute_index_impl = [idx_1d, access_strides](auto jdim) constexpr {
                auto res = idx_1d.value;
                auto id  = 0;

                static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
                    id = res / access_strides[kdim].value;
                    res -= id * access_strides[kdim].value;
                });

                return id;
            };

            constexpr auto id = compute_index_impl(idim);

            return Number<id>{};
        };

        constexpr auto ordered_access_idx = generate_tuple(compute_index, Number<nDim>{});
#endif

        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep_;

            forward_sweep_(I0) = true;

            static_for<1, nDim, 1>{}([&](auto idim) {
                index_t tmp = ordered_access_idx[I0];

                static_for<1, idim, 1>{}(
                    [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });

                forward_sweep_(idim) = tmp % 2 == 0;
            });

            return forward_sweep_;
        }();

        // calculate multi-dim tensor index
        auto idx_md = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto idim) {
                ordered_idx(idim) =
                    !SnakeCurved || forward_sweep[idim]
                        ? ordered_access_idx[idim]
                        : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
            });

            return container_reorder_given_old2new(ordered_idx, dim_access_order) *
                   ScalarsPerAccess{};
        }();

        return idx_md;
    }

    // FIXME: rename this function
    template <index_t AccessIdx1d>
    static __device__ __host__ constexpr auto GetIndexTupleOfNumber(Number<AccessIdx1d>)
    {
        constexpr auto idx = GetIndex(Number<AccessIdx1d>{});

        return generate_tuple([&](auto i) { return Number<idx[i]>{}; }, Number<nDim>{});
    }
};

} // namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp"
namespace
ck
{
// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1]
// A and B are visible to the whole block, C is distributed among the threads
// Assume:
// 1. A:
// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2
// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
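// The 2x2 "ABBA" pipeline interleaves reads and math so that each sub-tile read is issued while
// the previously fetched sub-tiles are being consumed; a rough outline (the exact ordering is
// spelled out in Run() below):
//
//   read A0; read B0; read B1; read A1
//   C00 += A0^T * B0
//   C01 += A0^T * B1
//   for each remaining bk0 chunk:
//       read next A0;      C10 += A1^T * B0
//       read next B0;      C11 += A1^T * B1
//       read next B1; read next A1
//       C00 += A0^T * B0;  C01 += A0^T * B1
//   C10 += A1^T * B0
//   C11 += A1^T * B1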
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
ABlockDesc_BK0_BM_BK1
,
typename
BBlockDesc_BK0_BN_BK1
,
index_t
BM1PerThreadBM11
,
index_t
BN1PerThreadBN11
,
index_t
BK0PerThread
,
typename
BM10BN10ThreadClusterBM10Xs
,
// Sequence<BM10BN10ThreadClusterBM100,
// BM10BN10ThreadClusterBM101, ...>
typename
BM10BN10ThreadClusterBN10Xs
,
// Sequence<BM10BN10ThreadClusterBN100,
// BM10BN10ThreadClusterBN101, ...>
index_t
AThreadCopyScalarPerVector_BM11
,
index_t
BThreadCopyScalarPerVector_BN11
,
typename
enable_if
<
ABlockDesc_BK0_BM_BK1
::
IsKnownAtCompileTime
()
&&
BBlockDesc_BK0_BN_BK1
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
using
AIndex
=
MultiIndex
<
3
>
;
using
BIndex
=
MultiIndex
<
3
>
;
using
CIndex
=
MultiIndex
<
4
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
BK0
=
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I0
);
static
constexpr
index_t
BK1
=
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I2
);
static
constexpr
index_t
BM
=
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I1
);
static
constexpr
index_t
BN
=
BBlockDesc_BK0_BN_BK1
{}.
GetLength
(
I1
);
static
constexpr
index_t
BM100
=
BM10BN10ThreadClusterBM10Xs
{}[
I0
];
static
constexpr
index_t
BN100
=
BM10BN10ThreadClusterBN10Xs
{}[
I0
];
static
constexpr
index_t
BM101
=
BM10BN10ThreadClusterBM10Xs
{}[
I1
];
static
constexpr
index_t
BN101
=
BM10BN10ThreadClusterBN10Xs
{}[
I1
];
static
constexpr
index_t
BM11
=
BM1PerThreadBM11
;
static
constexpr
index_t
BN11
=
BN1PerThreadBN11
;
static
constexpr
index_t
BM1
=
BM100
*
BM101
*
BM11
;
static
constexpr
index_t
BN1
=
BN100
*
BN101
*
BN11
;
static
constexpr
index_t
BM0
=
BM
/
BM1
;
static
constexpr
index_t
BN0
=
BN
/
BN1
;
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_BK0_BM0_BM1_BK1
(
const
ABlockDesc_BK0_BM_BK1
&
a_block_desc_bk0_bm_bk1
)
{
const
auto
a_block_bk0_bm0_bm1_bk1
=
transform_tensor_descriptor
(
a_block_desc_bk0_bm_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BM0
>
{},
Number
<
BM1
>
{})),
make_pass_through_transform
(
Number
<
BK1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
return
a_block_bk0_bm0_bm1_bk1
;
}
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_BK0_BN0_BN1_BK1
(
const
BBlockDesc_BK0_BN_BK1
&
b_block_desc_bk0_bn_bk1
)
{
const
auto
b_block_desc_bk0_bn0_bn1_bk1
=
transform_tensor_descriptor
(
b_block_desc_bk0_bn_bk1
,
make_tuple
(
make_pass_through_transform
(
Number
<
BK0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BN0
>
{},
Number
<
BN1
>
{})),
make_pass_through_transform
(
Number
<
BK1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
return
b_block_desc_bk0_bn0_bn1_bk1
;
}
__host__
__device__
static
constexpr
auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN
()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM, BN]
constexpr
auto
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
BM0
>
{},
Number
<
BM100
>
{},
Number
<
BM101
>
{},
Number
<
BM11
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
BN0
>
{},
Number
<
BN100
>
{},
Number
<
BN101
>
{},
Number
<
BN11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
,
5
,
6
,
7
>
{}));
return
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n
;
}
__host__
__device__
static
constexpr
auto
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1
()
{
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// lower: [BM0, BM1, BN0, BN1]
constexpr
auto
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_pass_through_transform
(
Number
<
BM0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BM100
>
{},
Number
<
BM101
>
{},
Number
<
BM11
>
{})),
make_pass_through_transform
(
Number
<
BN0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
BN100
>
{},
Number
<
BN101
>
{},
Number
<
BN11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
,
7
>
{}));
return
c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1
;
}
__host__
__device__
static
constexpr
auto
GetCThreadTensorLengths_BM0_BM1_BN0_BN1
()
{
return
Sequence
<
BM0
,
BM11
,
BN0
,
BN11
>
{};
}
static
constexpr
auto
a_block_desc_bk0_bm0_bm1_bk1_
=
MakeABlockDescriptor_BK0_BM0_BM1_BK1
(
ABlockDesc_BK0_BM_BK1
{});
static
constexpr
auto
b_block_desc_bk0_bn0_bn1_bk1_
=
MakeBBlockDescriptor_BK0_BN0_BN1_BK1
(
BBlockDesc_BK0_BN_BK1
{});
public:
__device__
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
()
:
c_thread_origin_data_idx_
{
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
get_thread_local_1d_id
())},
a_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I0
],
c_thread_origin_data_idx_
[
I1
],
0
)},
b_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I2
],
c_thread_origin_data_idx_
[
I3
],
0
)}
{
static_assert
(
ABlockDesc_BK0_BM_BK1
::
IsKnownAtCompileTime
()
&&
BBlockDesc_BK0_BN_BK1
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
BlockSize
==
BM101
*
BM100
*
BN101
*
BN100
,
"wrong! blocksize and cluster size not consistent"
);
static_assert
(
BM
%
BM1
==
0
&&
BN
%
BN1
==
0
,
"wrong!"
);
static_assert
(
ABlockDesc_BK0_BM_BK1
{}.
GetLength
(
I0
)
==
BBlockDesc_BK0_BN_BK1
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
// TODO remove this restriction
static_assert
(
BM10BN10ThreadClusterBM10Xs
::
Size
()
==
2
&&
BM10BN10ThreadClusterBN10Xs
::
Size
()
==
2
,
"wrong!"
);
// TODO: remove this restriction
static_assert
(
BM0
==
2
,
"wrong"
);
static_assert
(
BM0
==
2
&&
BN0
==
2
,
"wrong"
);
}
__device__
static
CIndex
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
index_t
thread_id
)
{
// lower: [BM0, BM1, BN0, BN1]
// upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
constexpr
auto
adaptor0
=
MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1
();
// lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
// upper: [Tid, BM0, BM11, BN0, BN11]
constexpr
auto
adaptor1
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
BM100
,
BN100
,
BM101
,
BN101
)),
make_pass_through_transform
(
BM0
),
make_pass_through_transform
(
BM11
),
make_pass_through_transform
(
BN0
),
make_pass_through_transform
(
BN11
)),
make_tuple
(
Sequence
<
1
,
5
,
2
,
6
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
constexpr
auto
adaptor
=
chain_tensor_adaptors
(
adaptor0
,
adaptor1
);
return
adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
,
0
,
0
,
0
,
0
));
}
template
<
typename
CThreadDesc_BM0_BM11_BN0_BN11
,
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
CThreadDesc_BM0_BM11_BN0_BN11
&
,
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
CThreadDesc_BM0_BM11_BN0_BN11
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
// TODO: remove this restriction
static_assert
(
BM0
==
2
&&
BN0
==
2
&&
CThreadDesc_BM0_BM11_BN0_BN11
{}.
GetLength
(
I0
)
==
BM0
&&
CThreadDesc_BM0_BM11_BN0_BN11
{}.
GetLength
(
I2
)
==
BN0
,
"wrong"
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_thread_desc_bk0_bm0_bm1_bk1_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_thread_desc_bk0_bn0_bn1_bk1_
.
GetElementSpaceSize
());
constexpr
auto
threadwise_contraction
=
ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
<
FloatA
,
FloatB
,
FloatC
,
decltype
(
a_thread_desc_bk0_bm0_bm1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
CThreadDesc_BM0_BM11_BN0_BN11
,
Sequence
<
BK0PerThread
,
BK1
>
,
Sequence
<
1
,
BM1PerThreadBM11
>
,
Sequence
<
1
,
BN1PerThreadBN11
>>
{};
// read A_sub_0
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
// read B_sub_0
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
// read B_sub_1
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I1
,
I0
));
// loop over rest of bk0
static_for
<
BK0PerThread
,
BK0
,
BK0PerThread
>
{}([
&
](
auto
bk0
)
{
// read A_sub_0
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
bk0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I0
,
I0
));
// read B_sub_0
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
bk0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I1
,
I0
));
// read B_sub_1
b_thread_copy_
.
Run
(
b_block_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
bk0
,
I1
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_bk0_bn0_bn1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
a_block_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
bk0
,
I1
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_bk0_bm0_bm1_bk1_
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I1
,
I0
));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I0
,
I0
));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_contraction
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I1
,
I0
));
}
private:
// A[BK0, BM0, BM1, BK1]
static
constexpr
auto
a_thread_desc_bk0_bm0_bm1_bk1_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BM0
>
{},
Number
<
BM1PerThreadBM11
>
{},
Number
<
BK1
>
{}));
// B[BK0, BN0, BN1, BK1]
static
constexpr
auto
b_thread_desc_bk0_bn0_bn1_bk1_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BK0PerThread
>
{},
Number
<
BN0
>
{},
Number
<
BN1PerThreadBN11
>
{},
Number
<
BK1
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_bk0_bm0_bm1_bk1_
),
decltype
(
a_thread_desc_bk0_bm0_bm1_bk1_
),
Sequence
<
BK0PerThread
,
1
,
BM1PerThreadBM11
,
BK1
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder
Sequence
<
1
,
1
,
BM1PerThreadBM11
,
BK1
>
,
// SrcVectorTensorLengths
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4r1
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_bk0_bn0_bn1_bk1_
),
decltype
(
b_thread_desc_bk0_bn0_bn1_bk1_
),
Sequence
<
BK0PerThread
,
1
,
BN1PerThreadBN11
,
BK1
>
,
// SliceLengths
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder
Sequence
<
1
,
1
,
BN1PerThreadBN11
,
BK1
>
,
// SrcVectorTensorLengths
Sequence
<
0
,
1
,
2
,
3
>>
;
// SrcVectorTensorContiguousDimOrder
CIndex
c_thread_origin_data_idx_
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_contraction_dlops.hpp"
namespace
ck
{
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block, C is distributed among the threads
// Assume:
// 1. A:
// 1. AKMBlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. BKNBlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CM0M1N0N1ThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
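// In this decomposition M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11 and
// M0 = M / M1 (and likewise for N). For instance, with M = 128, M100 = 2, M101 = 8, M11 = 4 one
// gets M1 = 64 and M0 = 2, which satisfies the M0 == 2 restriction asserted below. BlockSize
// must equal M100 * M101 * N100 * N101, the number of thread clusters covering the M1 x N1 tile.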
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
AKMBlockDesc
,
typename
BKNBlockDesc
,
index_t
M1PerThreadM11
,
index_t
N1PerThreadN11
,
index_t
KPerThread
,
index_t
M1N1ThreadClusterM100
,
index_t
M1N1ThreadClusterN100
,
index_t
M1N1ThreadClusterM101
,
index_t
M1N1ThreadClusterN101
,
index_t
AThreadCopyScalarPerVector_M11
,
index_t
BThreadCopyScalarPerVector_N11
,
typename
enable_if
<
AKMBlockDesc
::
IsKnownAtCompileTime
()
&&
BKNBlockDesc
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{
using
AIndex
=
MultiIndex
<
3
>
;
using
BIndex
=
MultiIndex
<
3
>
;
using
CIndex
=
MultiIndex
<
4
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
K
=
AKMBlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
M
=
AKMBlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
N
=
BKNBlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
M100
=
M1N1ThreadClusterM100
;
static
constexpr
index_t
N100
=
M1N1ThreadClusterN100
;
static
constexpr
index_t
M101
=
M1N1ThreadClusterM101
;
static
constexpr
index_t
N101
=
M1N1ThreadClusterN101
;
static
constexpr
index_t
M11
=
M1PerThreadM11
;
static
constexpr
index_t
N11
=
N1PerThreadN11
;
static
constexpr
index_t
M1
=
M1N1ThreadClusterM100
*
M1N1ThreadClusterM101
*
M1PerThreadM11
;
static
constexpr
index_t
N1
=
M1N1ThreadClusterN100
*
M1N1ThreadClusterN101
*
N1PerThreadN11
;
static
constexpr
index_t
M0
=
M
/
M1
;
static
constexpr
index_t
N0
=
N
/
N1
;
__host__
__device__
static
constexpr
auto
MakeAKM0M1BlockDescriptor
(
const
AKMBlockDesc
&
/* a_k_m_block_desc */
)
{
const
auto
a_k_m0_m1_block_desc
=
transform_tensor_descriptor
(
AKMBlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
K
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
M0
>
{},
Number
<
M1
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{}));
return
a_k_m0_m1_block_desc
;
}
__host__
__device__
static
constexpr
auto
MakeBKN0N1BlockDescriptor
(
const
BKNBlockDesc
&
/* b_k_n_block_desc */
)
{
const
auto
b_k_n0_n1_block_desc
=
transform_tensor_descriptor
(
BKNBlockDesc
{},
make_tuple
(
make_pass_through_transform
(
Number
<
K
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
N0
>
{},
Number
<
N1
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{}));
return
b_k_n0_n1_block_desc
;
}
__host__
__device__
static
constexpr
auto
MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor
()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M, N]
constexpr
auto
c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
M0
>
{},
Number
<
M100
>
{},
Number
<
M101
>
{},
Number
<
M11
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
N0
>
{},
Number
<
N100
>
{},
Number
<
N101
>
{},
Number
<
N11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
,
5
,
6
,
7
>
{}));
return
c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor
;
}
__host__
__device__
static
constexpr
auto
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor
()
{
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
// lower: [M0, M1, N0, N1]
constexpr
auto
c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_pass_through_transform
(
Number
<
M0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
M100
>
{},
Number
<
M101
>
{},
Number
<
M11
>
{})),
make_pass_through_transform
(
Number
<
N0
>
{}),
make_unmerge_transform
(
make_tuple
(
Number
<
N100
>
{},
Number
<
N101
>
{},
Number
<
N11
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{},
Sequence
<
5
,
6
,
7
>
{}));
return
c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor
;
}
__host__
__device__
static
constexpr
auto
GetCM0M1N0N1ThreadTensorLengths
()
{
return
Sequence
<
M0
,
M11
,
N0
,
N11
>
{};
}
static
constexpr
auto
a_k_m0_m1_block_desc_
=
MakeAKM0M1BlockDescriptor
(
AKMBlockDesc
{});
static
constexpr
auto
b_k_n0_n1_block_desc_
=
MakeBKN0N1BlockDescriptor
(
BKNBlockDesc
{});
public:
__device__
BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
()
:
c_thread_origin_data_idx_
{
CalculateCM0M1N0N1ThreadOriginOnBlock
(
get_thread_local_1d_id
())},
a_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I0
],
c_thread_origin_data_idx_
[
I1
])},
b_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I2
],
c_thread_origin_data_idx_
[
I3
])}
{
static_assert
(
AKMBlockDesc
::
IsKnownAtCompileTime
()
&&
BKNBlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
BlockSize
==
M101
*
M100
*
N101
*
N100
,
"wrong! blocksize and cluster size not consistent"
);
static_assert
(
M
%
M1
==
0
&&
N
%
N1
==
0
,
"wrong!"
);
static_assert
(
AKMBlockDesc
{}.
GetLength
(
I0
)
==
BKNBlockDesc
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
// TODO: remove this restriction
static_assert
(
M0
==
2
&&
N0
==
2
,
"wrong"
);
}
__device__
static
CIndex
CalculateCM0M1N0N1ThreadOriginOnBlock
(
index_t
thread_id
)
{
// lower: [M0, M1, N0, N1]
// upper: [M0, M100, M101, M11, N0, N100, N101, N11]
constexpr
auto
adaptor0
=
MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor
();
// lower: [M0, M100, M101, M11, N0, N100, N101, N11]
// upper: [Tid, M0, M11, N0, N11]
constexpr
auto
adaptor1
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
M100
,
N100
,
M101
,
N101
)),
make_pass_through_transform
(
M0
),
make_pass_through_transform
(
M11
),
make_pass_through_transform
(
N0
),
make_pass_through_transform
(
N11
)),
make_tuple
(
Sequence
<
1
,
5
,
2
,
6
>
{},
Sequence
<
0
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{},
Sequence
<
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
constexpr
auto
adaptor
=
chain_tensor_adaptors
(
adaptor0
,
adaptor1
);
return
adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
,
0
,
0
,
0
,
0
));
}
__host__
__device__
static
constexpr
index_t
GetABlockAlignment
()
{
return
M1PerThreadM11
;
}
__host__
__device__
static
constexpr
auto
GetBBlockAlignment
()
{
return
N1PerThreadN11
;
}
template
<
typename
CM0M1N0N1ThreadDesc
,
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
CM0M1N0N1ThreadDesc
&
/* c_m0_m1_n0_n1_thread_desc */
,
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
CM0M1N0N1ThreadDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
// TODO: remove this restriction
static_assert
(
M0
==
2
&&
N0
==
2
&&
CM0M1N0N1ThreadDesc
{}.
GetLength
(
I0
)
==
M0
&&
CM0M1N0N1ThreadDesc
{}.
GetLength
(
I2
)
==
N0
,
"wrong"
);
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_k_m0_m1_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_k_n0_n1_thread_desc_
.
GetElementSpaceSize
());
constexpr
auto
threadwise_gemm
=
ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
<
FloatA
,
FloatB
,
FloatC
,
decltype
(
a_k_m0_m1_thread_desc_
),
decltype
(
b_k_n0_n1_thread_desc_
),
CM0M1N0N1ThreadDesc
,
Sequence
<
KPerThread
>
,
Sequence
<
1
,
M1PerThreadM11
>
,
Sequence
<
1
,
N1PerThreadN11
>>
{};
// read A_sub_0
a_thread_copy_
.
Run
(
a_k_m0_m1_block_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
a_block_buf
,
a_k_m0_m1_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
a_thread_buf
);
// read B_sub_0
b_thread_copy_
.
Run
(
b_k_n0_n1_block_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
b_block_buf
,
b_k_n0_n1_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
);
// read B_sub_1
b_thread_copy_
.
Run
(
b_k_n0_n1_block_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
b_block_buf
,
b_k_n0_n1_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
a_k_m0_m1_block_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
a_block_buf
,
a_k_m0_m1_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I1
,
I0
));
// loop over rest of k
static_for
<
KPerThread
,
K
,
KPerThread
>
{}([
&
](
auto
k
)
{
// read A_sub_0
a_thread_copy_
.
Run
(
a_k_m0_m1_block_desc_
,
make_tuple
(
k
,
I0
,
I0
),
a_block_buf
,
a_k_m0_m1_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
a_thread_buf
);
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I0
,
I0
));
// read B_sub_0
b_thread_copy_
.
Run
(
b_k_n0_n1_block_desc_
,
make_tuple
(
k
,
I0
,
I0
),
b_block_buf
,
b_k_n0_n1_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
);
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I1
,
I0
));
// read B_sub_1
b_thread_copy_
.
Run
(
b_k_n0_n1_block_desc_
,
make_tuple
(
k
,
I1
,
I0
),
b_block_buf
,
b_k_n0_n1_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
);
// read A_sub_1
a_thread_copy_
.
Run
(
a_k_m0_m1_block_desc_
,
make_tuple
(
k
,
I1
,
I0
),
a_block_buf
,
a_k_m0_m1_thread_desc_
,
make_tuple
(
I0
,
I1
,
I0
),
a_thread_buf
);
// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
));
// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I1
,
I0
));
});
// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I0
,
I0
));
// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I1
,
I0
),
c_thread_buf
,
make_tuple
(
I1
,
I0
,
I1
,
I0
));
}
private:
// A[K, M0, M1]
static
constexpr
auto
a_k_m0_m1_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
M0
>
{},
Number
<
M1PerThreadM11
>
{}));
// B[K, N0, N1]
static
constexpr
auto
b_k_n0_n1_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThread
>
{},
Number
<
N0
>
{},
Number
<
N1PerThreadN11
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_k_m0_m1_block_desc_
),
decltype
(
a_k_m0_m1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
M1PerThreadM11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
AThreadCopyScalarPerVector_M11
,
1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_k_n0_n1_block_desc_
),
decltype
(
b_k_n0_n1_thread_desc_
),
Sequence
<
KPerThread
,
1
,
N1PerThreadN11
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
BThreadCopyScalarPerVector_N11
,
1
>
;
CIndex
c_thread_origin_data_idx_
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#include "common_header.hpp"
#include "threadwise_gemm_dlops_v3.hpp"
namespace
ck
{
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
ABlockDesc_E1_K1_E2
,
typename
BBlockDesc_E1_N_Ho_Wo_E2
,
typename
CThreadDesc_K_N_Ho_Wo
,
index_t
EPerThreadLoop
,
index_t
KPerThreadLoop
>
struct
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
using
AIndex
=
MultiIndex
<
3
>
;
using
BIndex
=
MultiIndex
<
3
>
;
using
CIndex
=
MultiIndex
<
4
>
;
static
constexpr
auto
E1
=
ABlockDesc_E1_K1_E2
{}.
GetLength
(
I0
);
static
constexpr
auto
KPerBlock
=
ABlockDesc_E1_K1_E2
{}.
GetLength
(
I1
);
static
constexpr
auto
E2
=
ABlockDesc_E1_K1_E2
{}.
GetLength
(
I2
);
static
constexpr
auto
HoPerBlock
=
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I2
);
static
constexpr
auto
WoPerBlock
=
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I3
);
static
constexpr
auto
KPerThread
=
CThreadDesc_K_N_Ho_Wo
{}.
GetLength
(
I0
);
static
constexpr
auto
HoPerThread
=
CThreadDesc_K_N_Ho_Wo
{}.
GetLength
(
I2
);
static
constexpr
auto
WoPerThread
=
CThreadDesc_K_N_Ho_Wo
{}.
GetLength
(
I3
);
static
constexpr
auto
a_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerThreadLoop
>
{},
Number
<
KPerThreadLoop
>
{},
Number
<
E2
>
{}));
static
constexpr
auto
b_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerThreadLoop
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{},
Number
<
E2
>
{}));
static
constexpr
auto
c_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThreadLoop
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
__device__
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
()
:
c_thread_origin_data_idx_
{
GetBeginOfCThreadDesc_K_N_Ho_Wo
(
get_thread_local_1d_id
())},
a_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I0
]
*
KPerThread
,
0
)}
{
static_assert
(
ABlockDesc_E1_K1_E2
::
IsKnownAtCompileTime
()
&&
BBlockDesc_E1_N_Ho_Wo_E2
::
IsKnownAtCompileTime
()
&&
CThreadDesc_K_N_Ho_Wo
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ABlockDesc_E1_K1_E2
{}.
GetLength
(
I0
)
==
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I0
)
&&
ABlockDesc_E1_K1_E2
{}.
GetLength
(
I2
)
==
BBlockDesc_E1_N_Ho_Wo_E2
{}.
GetLength
(
I4
),
"wrong! E dimension not consistent
\n
"
);
static_assert
(
E1
%
EPerThreadLoop
==
0
,
""
);
static_assert
(
KPerThread
%
KPerThreadLoop
==
0
,
""
);
static_assert
(
KPerBlock
%
KPerThread
==
0
&&
HoPerBlock
%
HoPerThread
==
0
&&
WoPerBlock
%
WoPerThread
==
0
,
"wrong! Cannot evenly divide work among
\n
"
);
constexpr
auto
KThreadCluster
=
KPerBlock
/
KPerThread
;
constexpr
auto
HThreadCluster
=
HoPerBlock
/
HoPerThread
;
constexpr
auto
WThreadCluster
=
WoPerBlock
/
WoPerThread
;
static_assert
(
BlockSize
==
KThreadCluster
*
HThreadCluster
*
WThreadCluster
,
"wrong! wrong blocksize
\n
"
);
}
__device__
static
constexpr
auto
GetCThreadDesc_K_N_Ho_WoLengths
()
{
return
Sequence
<
KPerThread
,
I1
,
HoPerThread
,
WoPerThread
>
{};
}
__device__
static
CIndex
GetBeginOfCThreadDesc_K_N_Ho_Wo
(
index_t
thread_id
)
{
constexpr
auto
K0
=
KPerBlock
/
KPerThread
;
constexpr
auto
N0
=
I1
;
constexpr
auto
H0
=
HoPerBlock
/
HoPerThread
;
constexpr
auto
W0
=
WoPerBlock
/
WoPerThread
;
constexpr
auto
c_threadid_to_k_n_h_w_thread_cluster_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
K0
,
N0
,
H0
,
W0
))),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
const
auto
c_k_n_h_w_thread_cluster_idx
=
c_threadid_to_k_n_h_w_thread_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
return
c_k_n_h_w_thread_cluster_idx
;
}
template
<
typename
ABlockBuffer
,
typename
BThreadBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BThreadBuffer
&
b_thread_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
static_assert
(
is_same
<
remove_cvref_t
<
typename
ABlockBuffer
::
type
>
,
remove_cvref_t
<
FloatA
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
BThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatB
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
a_block_mtx
=
ABlockDesc_E1_K1_E2
{};
// thread A buffer for GEMM
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
,
a_thread_mtx_
.
GetElementSpaceSize
(),
true
>
a_thread_buf
;
constexpr
auto
threadwise_gemm
=
ThreadwiseGemmDlops_km_kn_mn_v3
<
FloatA
,
FloatB
,
FloatC
,
decltype
(
a_thread_mtx_
),
decltype
(
b_thread_mtx_
),
decltype
(
c_thread_mtx_
)
>
{};
static_for
<
0
,
E1
,
EPerThreadLoop
>
{}([
&
](
auto
e_begin
)
{
static_for
<
0
,
KPerThread
,
KPerThreadLoop
>
{}([
&
](
auto
k_begin
)
{
a_thread_copy_
.
Run
(
a_block_mtx
,
make_tuple
(
e_begin
,
k_begin
,
I0
),
a_block_buf
,
a_thread_mtx_
,
make_tuple
(
I0
,
I0
,
I0
),
a_thread_buf
);
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
e_begin
,
I0
,
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
k_begin
,
I0
,
I0
,
I0
));
});
});
}
template
<
typename
ABlockSliceMoveStepIdx
>
__device__
void
MoveABlockSliceWindow
(
const
ABlockSliceMoveStepIdx
&
a_block_slice_move_step_idx
)
{
a_thread_copy_
.
MoveSrcSliceWindow
(
ABlockDesc_E1_K1_E2
{},
a_block_slice_move_step_idx
);
}
private:
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
ABlockDesc_E1_K1_E2
,
decltype
(
a_thread_mtx_
),
Sequence
<
EPerThreadLoop
,
KPerThreadLoop
,
E2
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
E2
,
E2
>
;
CIndex
c_thread_origin_data_idx_
;
AThreadCopy
a_thread_copy_
;
};
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace
ck
{
enum
struct
LoopScheduler
{
Default
,
Interwave
,
};
constexpr
LoopScheduler
make_default_loop_scheduler
()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
return
LoopScheduler
::
Interwave
;
#else
return
LoopScheduler
::
Default
;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}
template
<
index_t
MNXdlPerWave
,
index_t
MNWaves
,
index_t
MNPerXdl
,
typename
TileDesc_K0_MN_K1
>
__host__
__device__
static
constexpr
auto
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K
(
const
TileDesc_K0_MN_K1
&
)
{
constexpr
index_t
K0
=
TileDesc_K0_MN_K1
{}.
GetLength
(
Number
<
0
>
{});
constexpr
index_t
K1
=
TileDesc_K0_MN_K1
{}.
GetLength
(
Number
<
2
>
{});
return
transform_tensor_descriptor
(
TileDesc_K0_MN_K1
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
K0
>
{},
Number
<
K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MNXdlPerWave
>
{},
Number
<
MNWaves
>
{},
Number
<
MNPerXdl
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
>
struct
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
static
constexpr
index_t
WaveSize
=
get_warp_size
();
static
constexpr
index_t
MPerBlock
=
AK0MK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
NPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
KPerBlock
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
)
*
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
A_K0
=
AK0MK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
B_K0
=
BK0NK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BK0NK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
xdlops_a_idx
=
xdlops_gemm
.
CalculateAThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
KPerThread
*
xdlops_a_idx
[
I0
]);
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
xdlops_b_idx
=
xdlops_gemm
.
CalculateBThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_n
,
xdlops_b_idx
[
I1
],
KPerThread
*
xdlops_b_idx
[
I0
]);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk
(
xdlops_i
,
blk_i
);
constexpr
auto
mrepeat_mwave_mperxdl_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperxdl_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperxdl_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperxdl_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex8D
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk4D
(
xdlops_i
,
blk_i
);
return
make_tuple
(
Number
<
m0
>
{},
Number
<
n0
>
{},
waveId_m
,
waveId_n
,
blk_idx
[
I0
],
blk_idx
[
I1
],
blk_idx
[
I2
],
blk_idx
[
I3
]);
}
__host__
__device__
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
()
{
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
BK0NK1BlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
"wrong!"
);
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_g_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_g_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_G_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
{
const
auto
G
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_g_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_g_m_n
.
GetLength
(
I2
);
const
auto
c_grid_desc_g_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_g_m_n
,
make_tuple
(
make_pass_through_transform
(
G
),
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
2
,
4
,
6
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_g_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_M0_M1_M2_K
()
{
return
transform_tensor_descriptor
(
AK0MK1BlockDesc
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
A_K0
>
{},
Number
<
A_K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerXDL
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__host__
__device__
static
constexpr
auto
MakeBBlockDescriptor_N0_N1_N2_K
()
{
return
transform_tensor_descriptor
(
BK0NK1BlockDesc
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
B_K0
>
{},
Number
<
B_K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerXDL
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
static
constexpr
auto
a_block_desc_m0_m1_m2_k
=
MakeABlockDescriptor_M0_M1_M2_K
();
static
constexpr
auto
b_block_desc_n0_n1_n2_k
=
MakeBBlockDescriptor_N0_N1_N2_K
();
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
protected:
// A[M0, M1, M2, KPerThread]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
KPerThread
>
{}));
// B[N0, N1, N2, KPerThread]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
KPerThread
>
{}));
// C[M, N, NumRegXdlops]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
CalculateBThreadOriginDataIndex
()};
};
// Note: To facilitate the inter-wave loop scheduler, the macro
// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 must be set explicitly, because a few of the required
// intrinsics are not yet available in the latest ROCm release. On unsupported compilers the
// inter-wave loop scheduler falls back to the default loop scheduler, which corresponds to
// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0.
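// A minimal usage sketch, assuming the macro is enabled at compile time (for example via
// -DCK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1); the selector defined later in this file returns
// the inter-wave variant when LoopScheduler::Interwave is requested:
//
//   auto blockwise_gemm =
//       BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<BlockSize,
//                                                                 FloatAB,
//                                                                 FloatAcc,
//                                                                 AK0MK1BlockDesc,
//                                                                 BK0NK1BlockDesc,
//                                                                 MPerXDL,
//                                                                 NPerXDL,
//                                                                 MRepeat,
//                                                                 NRepeat,
//                                                                 KPack,
//                                                                 LoopScheduler::Interwave>();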
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
index_t
NumMacClusters
=
CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
>
struct
BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
:
public
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
AK0MK1BlockDesc
,
BK0NK1BlockDesc
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{
using
Base
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
AK0MK1BlockDesc
,
BK0NK1BlockDesc
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
;
#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
using
Base
::
a_block_desc_m0_m1_m2_k
;
using
Base
::
A_K1
;
using
Base
::
b_block_desc_n0_n1_n2_k
;
using
Base
::
B_K1
;
using
Base
::
c_thread_buf_
;
using
Base
::
c_thread_desc_
;
using
Base
::
CalculateAThreadOriginDataIndex
;
using
Base
::
CalculateBThreadOriginDataIndex
;
using
Base
::
I0
;
using
Base
::
I1
;
using
Base
::
KPerThread
;
using
Base
::
xdlops_gemm
;
static
constexpr
index_t
KPerInnerLoop
=
math
::
max
(
KPerThread
/
NumMacClusters
,
KPack
);
// 2-wave optimized blockwise gemm
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
KPerThread
,
KPerInnerLoop
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
k
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_thread_buf
);
});
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
k
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
I0
,
I0
),
b_thread_buf
);
});
__builtin_amdgcn_sched_barrier
(
0
);
// NOTE: Synchronize the threads in a workgroup at the start of each MAC cluster, except the
// first one, since the non-MAC cluster can be shortened a bit with no observable negative
// impact. The desired effect is that the waves in a workgroup execute their MACs in sync.
// This prevents out-of-sync waves from hijacking MAC resources from other workgroups and
// reducing the chance of latency hiding by waiting for the rest of the workgroup at the
// eventual sync point.
if
constexpr
(
k
.
value
!=
0
||
KPerInnerLoop
==
KPerThread
)
{
asm
volatile
(
"s_barrier"
::
);
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
0
,
0
,
k_
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
n0
,
0
,
0
,
k_
+
i
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
// The block_sync_lds() here performs double duty:
// A) it safeguards against data hazards, because the barrier from blockwise_gemm is moved here;
// B) it reduces VMEM FIFO congestion by applying small delays to different wavefronts.
// It is performed near the end of the MAC cluster to minimize the lgkmcnt penalty.
if
constexpr
(
k
.
value
==
KPerThread
-
KPerInnerLoop
&&
k_
.
value
==
KPerInnerLoop
-
KPack
&&
m0
.
value
==
MRepeat
-
1
&&
n0
.
value
==
NRepeat
-
1
)
{
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
// TODO: insert setprio in a more precise manner, since there can be more than one MFMA
// instruction in a single call
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
}
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_setprio
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
});
}
protected:
// A[M0, M1, M2, KPerInnerLoop]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
KPerInnerLoop
>
{}));
// B[N0, N1, N2, KPerInnerLoop]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
KPerInnerLoop
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerInnerLoop
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
{
CalculateAThreadOriginDataIndex
()};
BThreadCopy
b_thread_copy_
{
CalculateBThreadOriginDataIndex
()};
#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
};
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0NK1BlockDesc
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
LoopScheduler
LoopSched
>
constexpr
auto
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
()
{
if
constexpr
(
LoopSched
==
LoopScheduler
::
Default
)
{
return
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
AK0MK1BlockDesc
,
BK0NK1BlockDesc
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
else
if
constexpr
(
LoopSched
==
LoopScheduler
::
Interwave
)
{
return
BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
BlockSize
,
FloatAB
,
FloatAcc
,
AK0MK1BlockDesc
,
BK0NK1BlockDesc
,
MPerXDL
,
NPerXDL
,
MRepeat
,
NRepeat
,
KPack
>
{};
}
};
// Blockwise gemm supporting:
// 1. regular XDL output M2_M3_M4_N2 and transposed XDL output M2_N2_N3_N4
// 2. decoupled input tile descriptor and mma tile descriptor, in order to support both vgpr and
//    LDS source buffers
// 3. configurable k index starting position and step size after each FMA/XDL instruction
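// With the default strides, AMmaKStride = BMmaKStride = KPack * K0PerXdlops, so in Run() below
// the per-iteration source read offset along K is Number<k * AMmaKStride>{}; overriding the two
// strides (point 3 above) changes where each FMA/XDL instruction starts reading in K.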
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
ATileDesc
,
typename
BTileDesc
,
typename
AMmaTileDesc
,
typename
BMmaTileDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
bool
TransposeC
=
false
,
index_t
AMmaKStride
=
KPack
*
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
TransposeC
>{}.
K0PerXdlops
,
index_t
BMmaKStride
=
KPack
*
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
TransposeC
>
{}.
K0PerXdlops
>
struct
BlockwiseGemmXdlops_v2
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
static
constexpr
index_t
WaveSize
=
get_warp_size
();
static
constexpr
index_t
A_K0
=
ATileDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
B_K0
=
BTileDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
A_K1
=
ATileDesc
{}.
GetLength
(
I2
);
static
constexpr
index_t
B_K1
=
BTileDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
,
TransposeC
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
static_assert
(
KPerThread
%
KPack
==
0
,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack"
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
xdlops_a_idx
=
xdlops_gemm
.
CalculateAThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
KPack
*
xdlops_a_idx
[
I0
]);
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
xdlops_b_idx
=
xdlops_gemm
.
CalculateBThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_n
,
xdlops_b_idx
[
I1
],
KPack
*
xdlops_b_idx
[
I0
]);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk
(
xdlops_i
,
blk_i
);
constexpr
auto
mrepeat_mwave_mperxdl_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperxdl_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperxdl_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperxdl_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex8D
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk4D
(
xdlops_i
,
blk_i
);
return
make_tuple
(
m0
,
n0
,
waveId_m
,
waveId_n
,
blk_idx
[
I0
],
blk_idx
[
I1
],
blk_idx
[
I2
],
blk_idx
[
I3
]);
}
using
Tuple4
=
decltype
(
CalculateAThreadOriginDataIndex
());
__host__
__device__
BlockwiseGemmXdlops_v2
(
Tuple4
a_origin
=
CalculateAThreadOriginDataIndex
(),
Tuple4
b_origin
=
CalculateBThreadOriginDataIndex
())
:
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
)
{
static_assert
(
AMmaTileDesc
::
IsKnownAtCompileTime
()
&&
BMmaTileDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
"wrong!"
);
}
__host__
__device__
BlockwiseGemmXdlops_v2
(
const
BlockwiseGemmXdlops_v2
&
other
)
:
a_thread_copy_
(
other
.
a_origin
),
b_thread_copy_
(
other
.
b_origin
)
{
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
N
,
M0
,
M1
,
M2
));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
()
{
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_g_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_g_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_G_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
{
const
auto
G
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_g_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_g_m_n
.
GetLength
(
I2
);
const
auto
c_grid_desc_g_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_g_m_n
,
make_tuple
(
make_pass_through_transform
(
G
),
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
2
,
4
,
6
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_g_m0_n0_m1_n1_m2_n2
);
}
static
constexpr
AMmaTileDesc
a_block_desc_m0_m1_m2_k
;
static
constexpr
BMmaTileDesc
b_block_desc_n0_n1_n2_k
;
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
KPerThread
/
KPack
,
1
>
{}([
&
](
auto
k
)
{
// k is a plain loop index 0, 1, 2, ... rather than a K offset (0, KPack, ...); the actual K
// offset is k * AMmaKStride / k * BMmaKStride below
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
Number
<
k
*
AMmaKStride
>
{}),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
i
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
protected:
// A[M0, M1, M2, KPack]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
KPack
>
{}));
// B[N0, N1, N2, KPack]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
KPack
>
{}));
// C[M, N, NumRegXdlops]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
using
BThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
b_block_desc_n0_n1_n2_k
),
decltype
(
b_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPack
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
B_K1
,
B_K1
>
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace
ck
{
template
<
index_t
BlockSize
,
typename
FloatAB
,
typename
FloatAcc
,
typename
AK0MK1BlockDesc
,
typename
BK0K0BN0N1N2N3K1BlockDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
K0PerBlock
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
>
struct
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
index_t
WaveSize
=
64
;
static
constexpr
index_t
KPerBlock
=
K0PerBlock
*
KPack
;
static
constexpr
index_t
A_K0
=
AK0MK1BlockDesc
{}.
GetLength
(
I0
);
static
constexpr
index_t
A_K1
=
AK0MK1BlockDesc
{}.
GetLength
(
I2
);
static
constexpr
auto
xdlops_gemm
=
XdlopsGemm
<
FloatAB
,
MPerXDL
,
NPerXDL
,
KPack
>
{};
static
constexpr
index_t
KPerThread
=
KPerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
K0PerThread
=
K0PerBlock
/
xdlops_gemm
.
K0PerXdlops
;
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerXDL
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerXDL
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
get_thread_local_1d_id
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
xdlops_a_idx
=
xdlops_gemm
.
CalculateAThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_m
,
xdlops_a_idx
[
I1
],
KPerThread
*
xdlops_a_idx
[
I0
]);
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
xdlops_b_idx
=
xdlops_gemm
.
CalculateBThreadOriginDataIndex
();
return
make_tuple
(
0
,
waveId_n
,
xdlops_b_idx
[
I1
],
KPerThread
*
xdlops_b_idx
[
I0
]);
}
template
<
index_t
m0
,
index_t
n0
,
index_t
xdlops_i
,
index_t
blk_i
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
,
Number
<
xdlops_i
>
,
Number
<
blk_i
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
xdlops_gemm
.
GetBeginOfThreadBlk
(
xdlops_i
,
blk_i
);
constexpr
auto
mrepeat_mwave_mperxdl_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperxdl_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperxdl_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperxdl_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
__host__
__device__
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
()
{
static_assert
(
AK0MK1BlockDesc
::
IsKnownAtCompileTime
()
&&
BK0K0BN0N1N2N3K1BlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
BlockSize
==
MWaves
*
NWaves
*
WaveSize
,
"BlockSize != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerXDL
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerXDL
*
NRepeat
)
==
0
,
"wrong!"
);
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_m0_m1_m2_n_tblk_lens
=
xdlops_gemm
.
GetCM0M1M2NThreadBlkLengths
();
constexpr
auto
M0
=
c_m0_m1_m2_n_tblk_lens
[
I0
];
constexpr
auto
M1
=
c_m0_m1_m2_n_tblk_lens
[
I1
];
constexpr
auto
M2
=
c_m0_m1_m2_n_tblk_lens
[
I2
];
constexpr
auto
N
=
c_m0_m1_m2_n_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
M0
,
M1
,
M2
,
N
));
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
()
{
constexpr
auto
c_block_desc_g_m0_n0_m1_n1_m2_n2
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
NWaves
>
{},
Number
<
MPerXDL
>
{},
Number
<
NPerXDL
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_block_desc_g_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m0_n0_m1_n1_m2_n2
);
}
template
<
typename
CGridDesc_G_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_G_M_N
&
c_grid_desc_g_m_n
)
{
const
auto
G
=
c_grid_desc_g_m_n
.
GetLength
(
I0
);
const
auto
M
=
c_grid_desc_g_m_n
.
GetLength
(
I1
);
const
auto
N
=
c_grid_desc_g_m_n
.
GetLength
(
I2
);
const
auto
c_grid_desc_g_m0_n0_m1_n1_m2_n2
=
transform_tensor_descriptor
(
c_grid_desc_g_m_n
,
make_tuple
(
make_pass_through_transform
(
G
),
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerXDL
),
MWaves
,
MPerXDL
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerXDL
),
NWaves
,
NPerXDL
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
3
,
5
>
{},
Sequence
<
2
,
4
,
6
>
{}));
return
xdlops_gemm
.
MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_g_m0_n0_m1_n1_m2_n2
);
}
__host__
__device__
static
constexpr
auto
MakeABlockDescriptor_M0_M1_M2_K
()
{
return
transform_tensor_descriptor
(
AK0MK1BlockDesc
{},
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
Number
<
A_K0
>
{},
Number
<
A_K1
>
{})),
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerXDL
>
{}))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
,
1
,
2
>
{}));
}
__device__
void
MoveABlockSliceWindow
()
{
a_thread_copy_
.
MoveSrcSliceWindow
(
a_block_desc_m0_m1_m2_k
,
make_multi_index
(
0
,
0
,
0
,
K0PerBlock
*
KPack
));
}
__device__
void
ResetABlockStartWindow
()
{
a_thread_copy_
.
SetSrcCoord
(
CalculateAThreadOriginDataIndex
());
}
static
constexpr
auto
a_block_desc_m0_m1_m2_k
=
MakeABlockDescriptor_M0_M1_M2_K
();
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_thread_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAB
>
(
a_thread_desc_
.
GetElementSpaceSize
());
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
make_tuple
(
m0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// B is taken directly from b_thread_buf (B bypasses LDS in this kernel)
static_for
<
0
,
KPerThread
,
KPack
>
{}([
&
](
auto
k
)
{
vector_type
<
FloatAB
,
KPack
>
a_thread_vec
;
vector_type
<
FloatAB
,
KPack
>
b_thread_vec
;
constexpr
index_t
k0
=
k
/
KPack
;
static_for
<
0
,
KPack
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
0
,
0
,
0
,
k
+
i
))
>
{}];
b_thread_vec
.
template
AsType
<
FloatAB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
k0
,
n0
,
i
))
>
{}];
});
using
mfma_input_type
=
typename
vector_type
<
FloatAB
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
private:
// A[M0, M1, M2, KPerThread]
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
KPerThread
>
{}));
// B[K0PerThread, NRepeat, KPack]
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
K0PerThread
>
{},
// K0PerThread
Number
<
NRepeat
>
{},
// repeat
Number
<
KPack
>
{}));
// C[M, N, NumRegXdlops]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
xdlops_gemm
.
GetRegSizePerXdlops
()));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
decltype
(
a_block_desc_m0_m1_m2_k
),
decltype
(
a_thread_desc_
),
Sequence
<
1
,
1
,
1
,
KPerThread
>
,
Sequence
<
0
,
1
,
2
,
3
>
,
3
,
A_K1
,
A_K1
>
;
AThreadCopy
a_thread_copy_
{
CalculateAThreadOriginDataIndex
()};
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"
#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp"
namespace
ck
{
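// A scalar sketch of the numerically stable softmax scheme implemented below (the row max is
// subtracted before exponentiation); plain C++ for illustration only, independent of the CK
// types (assumes <cmath> and <algorithm>):
//
//   float m   = *std::max_element(x, x + n);
//   float sum = 0.f;
//   for(int i = 0; i < n; ++i)
//   {
//       x[i] = std::exp(x[i] - m);
//       sum += x[i];
//   }
//
// BlockwiseSoftmax::Run() likewise leaves exp(s - max) in the thread buffer and the row sums in
// sum_value_buf; dividing by the sum is left to the caller.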
template
<
index_t
BlockSize
,
typename
AccDataType
,
typename
ThreadMap_M_K
,
// thread_id to m_k
typename
ThreadClusterDesc_M_K
,
typename
ThreadSliceDesc_M_K
,
bool
IgnoreNaN
=
false
>
struct
BlockwiseSoftmax
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
index_t
MRepeat
=
ThreadSliceDesc_M_K
{}.
GetLength
(
I0
);
static
constexpr
index_t
KRepeat
=
ThreadSliceDesc_M_K
{}.
GetLength
(
I1
);
using
ThreadSliceDesc_M
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
ThreadSliceDesc_M_K
{}.
GetLength
(
I0
))));
using
ThreadwiseMaxReduce
=
typename
conditional
<
IgnoreNaN
,
ThreadwiseReduction
<
AccDataType
,
ThreadSliceDesc_M_K
,
ThreadSliceDesc_M
,
reduce
::
Max
,
false
,
detail
::
AccumulateWithNanIgnore
<
reduce
::
Max
,
AccDataType
>>
,
ThreadwiseReduction
<
AccDataType
,
ThreadSliceDesc_M_K
,
ThreadSliceDesc_M
,
reduce
::
Max
,
false
>>::
type
;
using
ThreadwiseSumReduce
=
typename
conditional
<
IgnoreNaN
,
ThreadwiseReduction
<
AccDataType
,
ThreadSliceDesc_M_K
,
ThreadSliceDesc_M
,
reduce
::
Add
,
false
,
detail
::
AccumulateWithNanIgnore
<
reduce
::
Add
,
AccDataType
>>
,
ThreadwiseReduction
<
AccDataType
,
ThreadSliceDesc_M_K
,
ThreadSliceDesc_M
,
reduce
::
Add
,
false
>>::
type
;
using
ThreadClusterLengths_M_K
=
decltype
(
ThreadClusterDesc_M_K
{}.
GetLengths
());
using
BlockwiseMaxReduce
=
PartitionedBlockwiseReduction_v2
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadMap_M_K
,
reduce
::
Max
,
false
>
;
using
BlockwiseSumReduce
=
PartitionedBlockwiseReduction_v2
<
AccDataType
,
BlockSize
,
ThreadClusterLengths_M_K
,
ThreadMap_M_K
,
reduce
::
Add
,
false
>
;
using
BufferType
=
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MRepeat
,
true
>
;
template
<
typename
CThreadBuffer
,
typename
WorkspaceBuffer
>
__host__
__device__
void
Run
(
CThreadBuffer
&
in_thread_buf
,
WorkspaceBuffer
&
reduce_work_buf
)
{
// find max value
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
I
)
{
max_value_buf
(
I
)
=
reduce
::
Max
::
template
GetIdentityValue
<
AccDataType
>();
});
ThreadwiseMaxReduce
::
Reduce
(
in_thread_buf
,
max_value_buf
);
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
I
)
{
BlockwiseMaxReduce
::
Reduce
(
reduce_work_buf
,
max_value_buf
(
I
));
block_sync_lds
();
});
// calculate exp for elements, P=exp(s-max)
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
iK
)
{
auto
offset
=
Number
<
ThreadSliceDesc_M_K
{}.
CalculateOffset
(
make_tuple
(
iM
,
iK
))
>
{};
in_thread_buf
(
offset
)
=
IgnoreNaN
&&
ck
::
math
::
isnan
(
in_thread_buf
[
offset
])
?
0
:
math
::
exp
(
in_thread_buf
[
offset
]
-
max_value_buf
(
iM
));
});
});
// sum data
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
I
)
{
sum_value_buf
(
I
)
=
reduce
::
Add
::
template
GetIdentityValue
<
AccDataType
>();
});
ThreadwiseSumReduce
::
Reduce
(
in_thread_buf
,
sum_value_buf
);
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
I
)
{
BlockwiseSumReduce
::
Reduce
(
reduce_work_buf
,
sum_value_buf
(
I
));
block_sync_lds
();
});
}
BufferType
max_value_buf
;
BufferType
sum_value_buf
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp"
namespace
ck
{
// this version does the following things to avoid the scratch memory issue:
// 1. use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate
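// Thread-to-data mapping used below: the thread cluster described by ThreadClusterLengths /
// ThreadClusterArrangeOrder is laid over the block slice, and each participating thread copies a
// ThreadSliceLengths sub-slice whose origin is thread_cluster_idx * ThreadSliceLengths relative
// to the block slice origin; threads with an id beyond the cluster size (possible when BlockSize
// is larger than the cluster) do not take part in the transfer.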
template
<
index_t
BlockSize
,
InMemoryDataOperationEnum
DstInMemOp
,
typename
BlockSliceLengths
,
typename
ThreadSliceLengths
,
typename
ThreadClusterLengths
,
typename
ThreadClusterArrangeOrder
,
typename
SrcData
,
typename
DstData
,
typename
SrcDesc
,
typename
DstDesc
,
typename
SrcDimAccessOrder
,
typename
DstDimAccessOrder
,
typename
SrcVectorTensorLengths
,
typename
DstVectorTensorLengths
,
typename
SrcVectorTensorContiguousDimOrder
,
typename
DstVectorTensorContiguousDimOrder
,
bool
ThreadTransferSrcResetCoordinateAfterRun
,
bool
ThreadTransferDstResetCoordinateAfterRun
>
struct
BlockwiseTensorSliceTransfer_v5r1
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
using
Index
=
MultiIndex
<
nDim
>
;
__device__
constexpr
BlockwiseTensorSliceTransfer_v5r1
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_block_slice_origin
,
const
DstDesc
&
dst_desc
,
const
Index
&
dst_block_slice_origin
)
:
threadwise_transfer_
(
src_desc
,
make_zero_multi_index
<
nDim
>
(),
dst_desc
,
make_zero_multi_index
<
nDim
>
())
{
static_assert
(
nDim
==
remove_cvref_t
<
SrcDesc
>::
GetNumOfDimension
()
&&
nDim
==
remove_cvref_t
<
DstDesc
>::
GetNumOfDimension
()
&&
nDim
==
BlockSliceLengths
::
Size
()
&&
nDim
==
ThreadSliceLengths
::
Size
()
&&
nDim
==
ThreadClusterLengths
::
Size
()
&&
nDim
==
ThreadClusterArrangeOrder
::
Size
()
&&
nDim
==
SrcDimAccessOrder
::
Size
()
&&
nDim
==
DstDimAccessOrder
::
Size
(),
"wrong! nDim not consistent"
);
static_assert
(
is_same
<
BlockSliceLengths
,
decltype
(
ThreadSliceLengths
{}
*
ThreadClusterLengths
{})
>
{},
"wrong! threads should be mapped to cover entire slicing window"
);
static_assert
(
BlockSize
>=
thread_cluster_desc_
.
GetElementSize
(),
"wrong! BlockSize too small"
);
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
ThreadSliceLengths
{};
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
threadwise_transfer_
.
SetDstSliceOrigin
(
dst_desc
,
dst_block_slice_origin
+
thread_data_idx_begin
);
}
}
template
<
typename
SrcBuffer
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
)
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
);
}
}
template
<
typename
DstBuffer
>
__device__
void
RunWrite
(
const
DstDesc
&
dst_desc
,
DstBuffer
&
dst_buf
)
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunWrite
(
dst_desc
,
dst_buf
);
}
}
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
)
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_desc
,
step
);
}
}
// SrcMoveSliceWindowStepHack controls the index calculation used when moving the slice window
template
<
typename
SrcMoveSliceWindowStepHack
>
__device__
void
MoveSrcSliceWindow
(
const
SrcDesc
&
src_desc
,
const
Index
&
step
,
const
SrcMoveSliceWindowStepHack
&
src_move_slice_window_step_hack
)
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveSrcSliceWindow
(
src_desc
,
step
,
src_move_slice_window_step_hack
);
}
}
__device__
void
MoveDstSliceWindow
(
const
DstDesc
&
dst_desc
,
const
Index
&
step
)
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
MoveDstSliceWindow
(
dst_desc
,
step
);
}
}
private:
static
constexpr
auto
thread_cluster_desc_
=
make_cluster_descriptor
(
ThreadClusterLengths
{},
ThreadClusterArrangeOrder
{});
using
ThreadwiseTransfer
=
ThreadwiseTensorSliceTransfer_v5r1
<
ThreadSliceLengths
,
DstInMemOp
,
SrcData
,
DstData
,
SrcDesc
,
DstDesc
,
SrcDimAccessOrder
,
DstDimAccessOrder
,
SrcVectorTensorLengths
,
DstVectorTensorLengths
,
SrcVectorTensorContiguousDimOrder
,
DstVectorTensorContiguousDimOrder
,
ThreadTransferSrcResetCoordinateAfterRun
,
ThreadTransferDstResetCoordinateAfterRun
>
;
ThreadwiseTransfer
threadwise_transfer_
;
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_welford.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp"
namespace
ck
{
// clang-format off
// Assume:
// 1) work_buffer is a buffer (typically LDS) allocated outside as workspace
// 2) work_buffer has T elements, and its space size is no less than 3*BlockSize
// 3) mean_value, var_value and count are the input data in vgpr from each thread
// 4) mean_value, var_value and count are the over-written reduced output in vgpr for each thread
// 5) merges the mean and M (sum of squared deviations) produced by ThreadwiseWelford
// clang-format on
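// For reference, the merge rule implemented by Merge() below is the standard parallel Welford
// (Chan et al.) update; a plain-C++ sketch, for illustration only and not used by the kernel:
//
//   // merge partition B (mean_b, m2_b, count_b) into partition A (mean_a, m2_a, count_a)
//   void welford_merge(float& mean_a, float& m2_a, int& count_a,
//                      float mean_b, float m2_b, int count_b)
//   {
//       const int   count = count_a + count_b;
//       const float ratio = (count == 0) ? 0.f : float(count_b) / count;
//       const float delta = mean_b - mean_a;
//       mean_a += delta * ratio;                        // combined mean
//       m2_a += m2_b + delta * delta * count_a * ratio; // combined M (sum of squared deviations)
//       count_a = count;
//   }
//
// Run() applies this pairwise over LDS in a binary tree of cluster_len_shift steps and, when
// GetActualVariance is true, divides the final M by the merged count to produce the variance.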
template
<
typename
T
,
index_t
BlockSize
,
typename
ThreadClusterLengths_M_K
,
typename
ThreadClusterArrangeOrder
,
bool
GetActualVariance
=
true
>
struct
BlockwiseWelford
{
static_assert
(
BlockSize
==
ThreadClusterLengths_M_K
::
At
(
0
)
*
ThreadClusterLengths_M_K
::
At
(
1
),
"The product of cluster lengths should be same as BlockSize!"
);
static
constexpr
auto
BufferLength_M
=
ThreadClusterLengths_M_K
::
At
(
0
);
static
constexpr
auto
BufferLength_K
=
ThreadClusterLengths_M_K
::
At
(
1
);
static
constexpr
auto
block_buf_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BufferLength_M
>
{},
Number
<
BufferLength_K
>
{}));
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
__device__
static
inline
void
Merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
{
int
count
=
count_a
+
count_b
;
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
type_convert
<
T
>
(
count_b
)
/
count
;
T
delta
=
mean_b
-
mean_a
;
mean_a
+=
delta
*
count_b_over_count
;
var_a
+=
var_b
+
delta
*
delta
*
count_a
*
count_b_over_count
;
count_a
=
count
;
}
__device__
static
void
Run
(
T
&
mean_value
,
T
&
var_value
,
int
&
count
)
{
__shared__
T
mean_block_buf
[
BlockSize
];
__shared__
T
var_block_buf
[
BlockSize
];
__shared__
int
count_block_buf
[
BlockSize
];
constexpr
auto
cluster_len_shift
=
get_shift
<
BufferLength_K
>
();
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
Number
<
0
>
{}];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
Number
<
1
>
{}];
index_t
offset1
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
);
mean_block_buf
[
offset1
]
=
mean_value
;
var_block_buf
[
offset1
]
=
var_value
;
count_block_buf
[
offset1
]
=
count
;
block_sync_lds
();
static_for
<
0
,
cluster_len_shift
,
1
>
{}([
&
](
auto
I
)
{
constexpr
index_t
indOffset
=
1
<<
(
cluster_len_shift
-
1
-
I
());
if
(
thread_k_cluster_id
<
indOffset
)
{
index_t
offset2
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
+
make_tuple
(
0
,
indOffset
));
T
mean1
=
mean_block_buf
[
offset1
];
T
var1
=
var_block_buf
[
offset1
];
int
count1
=
count_block_buf
[
offset1
];
T
mean2
=
mean_block_buf
[
offset2
];
T
var2
=
var_block_buf
[
offset2
];
int
count2
=
count_block_buf
[
offset2
];
Merge
(
mean1
,
var1
,
count1
,
mean2
,
var2
,
count2
);
mean_block_buf
[
offset1
]
=
mean1
;
var_block_buf
[
offset1
]
=
var1
;
count_block_buf
[
offset1
]
=
count1
;
}
block_sync_lds
();
});
index_t
offset
=
block_buf_desc_m_k
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
,
0
));
count
=
count_block_buf
[
offset
];
mean_value
=
mean_block_buf
[
offset
];
if
constexpr
(
GetActualVariance
)
var_value
=
var_block_buf
[
offset
]
/
count
;
else
var_value
=
var_block_buf
[
offset
];
};
};
}
// namespace ck
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace
ck
{
// clang-format off
// Assume:
// 1) work_buffer is a buffer (typically LDS) allocated outside as workspace; it does not include any in/out data
// 2) work_buffer has AccDataType elements, and its space size is no less than BlockSize
// 3) in_out_value is the input data in vgpr from each thread
// 4) in_out_value is the over-written reduced output in vgpr for each thread
// clang-format on
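// Reduce() below is a binary-tree reduction over the K dimension of the LDS work buffer: every
// thread first writes its in_out_value to LDS, then for cluster_len_shift steps the threads with
// thread_k_cluster_id below the shrinking half-width fold in the partner element at K offset
// +indOffset via Accumulation::Calculate, with __syncthreads() between steps; finally each
// thread reads back the reduced value of its M row (K index 0).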
template
<
typename
AccDataType
,
index_t
BlockSize
,
typename
ThreadClusterLengths_M_K
,
typename
ThreadClusterArrangeOrder
,
typename
OpReduce
,
bool
PropagateNan
,
typename
Accumulation
=
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
OpReduce
,
AccDataType
>
>
struct
PartitionedBlockwiseReduction
{
static_assert
(
BlockSize
==
ThreadClusterLengths_M_K
::
At
(
0
)
*
ThreadClusterLengths_M_K
::
At
(
1
),
"The product of cluster lengths should be same as BlockSize!"
);
static
constexpr
auto
BufferLength_M
=
ThreadClusterLengths_M_K
::
At
(
0
);
static
constexpr
auto
BufferLength_K
=
ThreadClusterLengths_M_K
::
At
(
1
);
static_assert
(
BufferLength_K
>
1
,
"Parallel reduction need work on at least two elements"
);
static
constexpr
auto
block_buf_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BufferLength_M
>
{},
Number
<
BufferLength_K
>
{}));
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
template
<
typename
BufferType
>
__device__
static
void
Reduce
(
BufferType
&
work_buffer
,
AccDataType
&
in_out_value
)
{
static_assert
(
is_same
<
typename
BufferType
::
type
,
AccDataType
>
{},
"Buffer data type should be consistent as AccDataType!"
);
constexpr
auto
cluster_len_shift
=
get_shift
<
BufferLength_K
>
();
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
Number
<
0
>
{}];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
Number
<
1
>
{}];
work_buffer
(
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
))
=
in_out_value
;
__syncthreads
();
static_for
<
0
,
cluster_len_shift
,
1
>
{}([
&
](
auto
I
)
{
constexpr
index_t
indOffset
=
1
<<
(
cluster_len_shift
-
1
-
I
());
if
(
thread_k_cluster_id
<
indOffset
)
{
index_t
offset1
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
);
index_t
offset2
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
+
make_tuple
(
0
,
indOffset
));
AccDataType
opData1
=
work_buffer
[
offset1
];
AccDataType
opData2
=
work_buffer
[
offset2
];
Accumulation
::
Calculate
(
opData1
,
opData2
);
work_buffer
(
offset1
)
=
opData1
;
}
__syncthreads
();
});
index_t
offset
=
block_buf_desc_m_k
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
,
0
));
in_out_value
=
work_buffer
[
offset
];
};
};
// clang-format off
// Assume:
// 1) work_buffer is a buffer (typically LDS) allocated outside as workspace; it does not include any in/out data
// 2) work_buffer has AccDataType elements, and its space size is no less than BlockSize
// 3) in_out_value is the input data in vgpr from each thread
// 4) in_out_value is the over-written reduced output in vgpr for each thread
// clang-format on
template
<
typename
AccDataType
,
index_t
BlockSize
,
typename
ThreadClusterLengths_M_K
,
typename
ThreadClusterDesc
,
typename
OpReduce
,
bool
PropagateNan
,
typename
Accumulation
=
detail
::
AccumulateWithNanCheck
<
PropagateNan
,
OpReduce
,
AccDataType
>
>
struct
PartitionedBlockwiseReduction_v2
{
static_assert
(
BlockSize
==
ThreadClusterLengths_M_K
::
At
(
0
)
*
ThreadClusterLengths_M_K
::
At
(
1
),
"The product of cluster lengths should be same as BlockSize!"
);
static
constexpr
auto
BufferLength_M
=
ThreadClusterLengths_M_K
::
At
(
0
);
static
constexpr
auto
BufferLength_K
=
ThreadClusterLengths_M_K
::
At
(
1
);
static_assert
(
BufferLength_K
>
1
,
"Parallel reduction need work on at least two elements"
);
static
constexpr
auto
block_buf_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BufferLength_M
>
{},
Number
<
BufferLength_K
>
{}));
static
constexpr
auto
thread_cluster_desc
=
ThreadClusterDesc
{};
template
<
typename
BufferType
>
__device__
static
void
Reduce
(
BufferType
&
work_buffer
,
AccDataType
&
in_out_value
)
{
static_assert
(
is_same
<
typename
BufferType
::
type
,
AccDataType
>
{},
"Buffer data type should be consistent as AccDataType!"
);
constexpr
auto
cluster_len_shift
=
get_shift
<
BufferLength_K
>
();
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
Number
<
0
>
{}];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
Number
<
1
>
{}];
work_buffer
(
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
))
=
in_out_value
;
__syncthreads
();
static_for
<
0
,
cluster_len_shift
,
1
>
{}([
&
](
auto
I
)
{
constexpr
index_t
indOffset
=
1
<<
(
cluster_len_shift
-
1
-
I
());
if
(
thread_k_cluster_id
<
indOffset
)
{
index_t
offset1
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
);
index_t
offset2
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
+
make_tuple
(
0
,
indOffset
));
AccDataType
opData1
=
work_buffer
[
offset1
];
AccDataType
opData2
=
work_buffer
[
offset2
];
Accumulation
::
Calculate
(
opData1
,
opData2
);
work_buffer
(
offset1
)
=
opData1
;
}
__syncthreads
();
});
index_t
offset
=
block_buf_desc_m_k
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
,
0
));
in_out_value
=
work_buffer
[
offset
];
};
};
// clang-format off
// Assume:
// 1) work_val_buffer/work_idx_buffer are buffers (typically LDS) allocated outside as workspace; they do not include any in/out data
// 2) work_val_buffer/work_idx_buffer have AccDataType/IndexDataType elements, and their space size is no less than BlockSize
// 3) in_out_value/in_out_index are the input data in vgpr from each thread
// 4) in_out_value/in_out_index are the over-written reduced output in vgpr for each thread
// clang-format on
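// This variant carries an index alongside each value (e.g. for argmax/argmin-style reductions):
// Accumulation::Calculate updates the winning value together with its index. Note that the
// pairing scheme differs from the plain reductions above: here indOffset grows as 1 << I and the
// thread with thread_k_cluster_id % (indOffset * 2) == 0 folds in its right-hand partner, so the
// surviving (value, index) pair always stays at the left element of each pair.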
template
<
typename
AccDataType
,
typename
IndexDataType
,
index_t
BlockSize
,
typename
ThreadClusterLengths_M_K
,
typename
ThreadClusterArrangeOrder
,
typename
OpReduce
,
bool
PropagateNan
,
typename
Accumulation
=
detail
::
AccumulateWithIndexAndNanCheck
<
PropagateNan
,
OpReduce
,
AccDataType
,
IndexDataType
>
>
struct
PartitionedBlockwiseReductionWithIndex
{
static_assert
(
BlockSize
==
ThreadClusterLengths_M_K
::
At
(
0
)
*
ThreadClusterLengths_M_K
::
At
(
1
),
"The product of cluster lengths should be same as BlockSize!"
);
static
constexpr
auto
BufferLength_M
=
ThreadClusterLengths_M_K
::
At
(
0
);
static
constexpr
auto
BufferLength_K
=
ThreadClusterLengths_M_K
::
At
(
1
);
static_assert
(
BufferLength_K
>
1
,
"Parallel reduction need work on at least two elements"
);
static
constexpr
auto
block_buf_desc_m_k
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
BufferLength_M
>
{},
Number
<
BufferLength_K
>
{}));
static
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
ThreadClusterLengths_M_K
{},
ThreadClusterArrangeOrder
{});
// This interface accumulates on both data values and indices
template
<
typename
BufferType
,
typename
IdxBufferType
>
__device__
static
void
Reduce
(
BufferType
&
work_val_buffer
,
IdxBufferType
&
work_idx_buffer
,
AccDataType
&
in_out_value
,
IndexDataType
&
in_out_index
)
{
static_assert
(
is_same
<
typename
BufferType
::
type
,
AccDataType
>
{},
"Buffer data type should be consistent as AccDataType!"
);
static_assert
(
is_same
<
typename
IdxBufferType
::
type
,
IndexDataType
>
{},
"Buffer data type should be consistent as IndexDataType!"
);
constexpr
auto
cluster_len_shift
=
get_shift
<
BufferLength_K
>
();
const
auto
thread_cluster_idx
=
thread_cluster_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
Number
<
0
>
{}];
const
auto
thread_k_cluster_id
=
thread_cluster_idx
[
Number
<
1
>
{}];
work_val_buffer
(
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
))
=
in_out_value
;
work_idx_buffer
(
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
))
=
in_out_index
;
__syncthreads
();
static_for
<
0
,
cluster_len_shift
,
1
>
{}([
&
](
auto
I
)
{
constexpr
index_t
indOffset
=
1
<<
I
();
if
(
thread_k_cluster_id
%
(
indOffset
*
2
)
==
0
)
{
index_t
offset1
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
);
index_t
offset2
=
block_buf_desc_m_k
.
CalculateOffset
(
thread_cluster_idx
+
make_tuple
(
0
,
indOffset
));
AccDataType
opData1
=
work_val_buffer
[
offset1
];
AccDataType
opData2
=
work_val_buffer
[
offset2
];
IndexDataType
currIndex1
=
work_idx_buffer
[
offset1
];
IndexDataType
currIndex2
=
work_idx_buffer
[
offset2
];
Accumulation
::
Calculate
(
opData1
,
opData2
,
currIndex1
,
currIndex2
);
work_val_buffer
(
offset1
)
=
opData1
;
work_idx_buffer
(
offset1
)
=
currIndex1
;
}
__syncthreads
();
});
index_t
offset
=
block_buf_desc_m_k
.
CalculateOffset
(
make_tuple
(
thread_m_cluster_id
,
0
));
in_out_value
=
work_val_buffer
[
offset
];
in_out_index
=
work_idx_buffer
[
offset
];
};
};
}
// namespace ck
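Editor's note: the comment block above fixes the contract — the LDS buffers are pure workspace of at least BlockSize AccDataType/IndexDataType elements, and each thread brings its partial (value, index) pair in registers and gets the reduced pair back in the same registers. As an illustration only, here is a minimal sketch of a block-wide argmax built on PartitionedBlockwiseReductionWithIndex; the reduce operation (ArgMaxOp), the dynamic-buffer wrapper and the chosen cluster shape are assumptions about the surrounding CK code base, not something defined in this commit.

// Hypothetical usage sketch (not part of the commit): block-wide argmax for a 256-thread block.
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp"

template <typename ArgMaxOp> // assumed: a CK max-with-index reduce operation
__device__ void block_argmax_example(float& thread_val, int32_t& thread_idx)
{
    using namespace ck;

    // one "M" row of 256 lanes, so the whole block cooperates on a single reduction
    using ClusterLengths_M_K = Sequence<1, 256>;
    using ClusterOrder       = Sequence<1, 0>;

    using BlockwiseArgmax = PartitionedBlockwiseReductionWithIndex<float,    // AccDataType
                                                                   int32_t,  // IndexDataType
                                                                   256,      // BlockSize
                                                                   ClusterLengths_M_K,
                                                                   ClusterOrder,
                                                                   ArgMaxOp,
                                                                   /*PropagateNan=*/false>;

    // workspace only (assumption 1 in the comment block): no in/out data lives here
    __shared__ float   lds_val[256];
    __shared__ int32_t lds_idx[256];

    // wrap the LDS arrays the way Reduce() expects (assumed CK helper)
    auto val_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(lds_val, 256);
    auto idx_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(lds_idx, 256);

    // in: each thread's partial (value, index); out: the block-wide maximum and its index
    BlockwiseArgmax::Reduce(val_buf, idx_buf, thread_val, thread_idx);
}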
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
0 → 100644
View file @
78e355fd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp"

namespace ck {

// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup,
          typename SrcElementwiseOperation,
          typename DstElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector,
          index_t SrcScalarStrideInVector,
          index_t DstScalarStrideInVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun,
          index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const SrcElementwiseOperation& src_element_op,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const DstElementwiseOperation& dst_element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               src_element_op,
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               dst_element_op)
    {
        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() &&
                          nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, index_t ThreadScratchId = 0>
    __device__ void RunRead(const SrcDesc& src_desc,
                            const SrcBuffer& src_buf,
                            Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id);
        }
    }

    template <typename DstBuffer, index_t ThreadScratchId = 0>
    __device__ void RunWrite(const DstDesc& dst_desc,
                             DstBuffer& dst_buf,
                             Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id);
        }
    }

    template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf,
                        Number<ThreadScratchId> thread_scratch_id)
    {
        RunRead(src_desc, src_buf, thread_scratch_id);
        RunWrite(dst_desc, dst_buf, thread_scratch_id);
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v3r1<decltype(thread_slice_lengths),
                                           SrcElementwiseOperation,
                                           DstElementwiseOperation,
                                           DstInMemOp,
                                           SrcData,
                                           DstData,
                                           SrcDesc,
                                           DstDesc,
                                           SrcDimAccessOrder,
                                           DstDimAccessOrder,
                                           SrcVectorDim,
                                           DstVectorDim,
                                           SrcScalarPerVector,
                                           DstScalarPerVector,
                                           SrcScalarStrideInVector,
                                           DstScalarStrideInVector,
                                           ThreadTransferSrcResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun,
                                           NumThreadScratch>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
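Editor's note: the split RunRead/RunWrite interface (together with the NumThreadScratch parameter) exists so a kernel's main loop can hold a tile in per-thread register scratch between the global-memory read and the LDS write. A minimal sketch of that call pattern follows; `copier`, the descriptors, buffers and step are placeholders for whatever the enclosing kernel instantiates, so treat this as an assumption-laden illustration rather than kernel code from this repository.

// Hypothetical call pattern for ThreadGroupTensorSliceTransfer_v4r1 in a tiled main loop.
#include "ck/utility/common_header.hpp"

template <typename Copier,
          typename SrcDesc, typename SrcBuf,
          typename DstDesc, typename DstBuf,
          typename Step>
__device__ void tile_copy_loop(Copier& copier,
                               const SrcDesc& src_desc, const SrcBuf& src_buf,
                               const DstDesc& dst_desc, DstBuf& dst_buf,
                               const Step& src_step, ck::index_t num_iter)
{
    for(ck::index_t i = 0; i < num_iter; ++i)
    {
        // global memory -> per-thread register scratch (slot 0); with NumThreadScratch > 1
        // a caller could keep several tiles in flight by passing different Number<> ids
        copier.RunRead(src_desc, src_buf, ck::Number<0>{});

        // advance the block's slice window to the tile for the next iteration
        copier.MoveSrcSliceWindow(src_desc, src_step);

        // register scratch -> destination (e.g. LDS), then make it visible to the block
        copier.RunWrite(dst_desc, dst_buf, ck::Number<0>{});
        __syncthreads();
    }
}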
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp
0 → 100644
View file @
78e355fd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp"

namespace ck {

// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum DstInMemOp,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v6r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const ElementwiseOperation& element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, typename DstBuffer>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r1<SrcData,
                                           DstData,
                                           SrcDesc,
                                           DstDesc,
                                           ElementwiseOperation,
                                           decltype(thread_slice_lengths),
                                           DimAccessOrder,
                                           VectorDim,
                                           ScalarPerVector,
                                           DstInMemOp,
                                           ThreadTransferSrcResetCoordinateAfterRun,
                                           ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
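Editor's note: compared with the v4r1 variant above, this class takes a single elementwise operation and fuses the read and write into one Run() call, which suits simple block-level copies such as streaming an output tile to its destination while applying the bound elementwise operation on the fly. A hedged sketch of that usage follows; `copier`, the descriptors, buffers and step are placeholders, not names from this commit.

// Hypothetical call pattern for ThreadGroupTensorSliceTransfer_v6r1 over several tiles.
#include "ck/utility/common_header.hpp"

template <typename Copier,
          typename SrcDesc, typename SrcBuf,
          typename DstDesc, typename DstBuf,
          typename Step>
__device__ void store_tiles(Copier& copier,
                            const SrcDesc& src_desc, const SrcBuf& src_buf,
                            const DstDesc& dst_desc, DstBuf& dst_buf,
                            const Step& dst_step, ck::index_t num_tiles)
{
    for(ck::index_t i = 0; i < num_tiles; ++i)
    {
        // one fused read->elementwise->write pass over the current slice window
        copier.Run(src_desc, src_buf, dst_desc, dst_buf);

        // move the destination window to the next tile
        copier.MoveDstSliceWindow(dst_desc, dst_step);
    }
}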