gaoqiong / composable_kernel_ROCM / Commits / 5a9c4962

Commit 5a9c4962, authored Apr 24, 2024 by Adam Osewski

    Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

Parents: 3970cf73, 43879b89
Changes: 269
Showing 20 changed files with 5781 additions and 0 deletions (+5781, -0)
include/ck_tile/core/tensor/null_tile_window.hpp             +88   -0
include/ck_tile/core/tensor/shuffle_tile.hpp                 +177  -0
include/ck_tile/core/tensor/slice_tile.hpp                   +92   -0
include/ck_tile/core/tensor/static_distributed_tensor.hpp    +190  -0
include/ck_tile/core/tensor/store_tile.hpp                   +93   -0
include/ck_tile/core/tensor/sweep_tile.hpp                   +30   -0
include/ck_tile/core/tensor/tensor_adaptor.hpp               +945  -0
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp    +257  -0
include/ck_tile/core/tensor/tensor_coordinate.hpp            +92   -0
include/ck_tile/core/tensor/tensor_descriptor.hpp            +467  -0
include/ck_tile/core/tensor/tensor_view.hpp                  +281  -0
include/ck_tile/core/tensor/tile_distribution.hpp            +759  -0
include/ck_tile/core/tensor/tile_distribution_encoding.hpp   +760  -0
include/ck_tile/core/tensor/tile_elementwise.hpp             +263  -0
include/ck_tile/core/tensor/tile_window.hpp                  +740  -0
include/ck_tile/core/utility/bit_cast.hpp                    +19   -0
include/ck_tile/core/utility/functional.hpp                  +208  -0
include/ck_tile/core/utility/ignore.hpp                      +22   -0
include/ck_tile/core/utility/magic_div.hpp                   +240  -0
include/ck_tile/core/utility/random.hpp                      +58   -0
Too many changes to show. To preserve performance only 269 of 269+ files are displayed.
include/ck_tile/core/tensor/null_tile_window.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"

namespace ck_tile {

// placeholder type if we want to opt-out a tile window parameter
template <typename WindowLengths_>
struct null_tile_window
{
    using BottomTensorView = null_tensor_view;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;

    using BottomTensorIndex = array<index_t, WindowLengths::size()>;

    CK_TILE_DEVICE constexpr null_tile_window() = default;

    CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths)
        : window_lengths_{window_lengths}
    {
    }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }

    WindowLengths window_lengths_;
};

// utility to check if this is a Null Tile Window
namespace impl {
template <typename>
struct is_null_tile_window : public std::false_type
{
};

template <typename T>
struct is_null_tile_window<null_tile_window<T>> : public std::true_type
{
};
} // namespace impl

template <typename T>
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
{
    return impl::is_null_tile_window<remove_cvref_t<T>>::value;
}

template <typename WindowLengths>
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}

template <typename WindowLengths, typename... Ts>
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
                                               const WindowLengths& window_lengths,
                                               const multi_index<WindowLengths::size()>& /*origin*/,
                                               Ts&&...)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}

template <typename WindowLengths>
CK_TILE_DEVICE void
move_tile_window(null_tile_window<WindowLengths>&,
                 const typename null_tile_window<WindowLengths>::BottomTensorIndex&)
{
}

} // namespace ck_tile
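A minimal usage sketch (not part of this commit), assuming device-side code and illustrative 64x64 window lengths; only make_null_tile_window, is_null_tile_window, make_tuple and number come from headers in this change:

// Inside a CK_TILE_DEVICE function of a kernel whose bias input is optional:
auto no_bias = make_null_tile_window(make_tuple(number<64>{}, number<64>{}));

// Generic epilogue code can detect the placeholder and skip the corresponding loads:
if constexpr(is_null_tile_window(decltype(no_bias){}))
{
    // no bias tile to read; this branch compiles away
}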
include/ck_tile/core/tensor/shuffle_tile.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"

namespace ck_tile {

namespace detail {

template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
{
    constexpr auto I0 = number<0>{};

    using DataType = typename InTensor::DataType;

    constexpr auto y_in_desc  = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
    constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();

    // y_dim_out_to_in
    constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
        using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;

        map<array<index_t, 2>, index_t> rh_major_minor_to_y_;

        static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
            constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
            constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];

            rh_major_minor_to_y_({rh_major, rh_minor}) = i;
        });

        return rh_major_minor_to_y_;
    };

    constexpr auto rh_major_minor_to_y_in  = get_rh_major_minor_to_y(InTensor{});
    constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});

    constexpr auto y_dim_out_to_in = [&] {
        map<index_t, index_t> y_dim_out_to_in_;

        for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
        {
            y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
        }

        return y_dim_out_to_in_;
    }();

    //
    constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();

    constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());

    // input and output vector dim in the order of input Y dims
    constexpr index_t y_dim_vec_in  = NDimY - 1;
    constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];

    // vector lengths
    constexpr index_t vec_length_in  = y_lengths[y_dim_vec_in];
    constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];

    // # of vectors
    constexpr index_t num_vec_in  = vec_length_out;
    constexpr index_t num_vec_out = vec_length_in;

    using InVec  = array<DataType, vec_length_in>;
    using OutVec = array<DataType, vec_length_out>;

    // using InVec = typename InVec::type;
    // using OutVec = typename OutVec::type;

    // SFC
    constexpr auto scalars_per_access_arr = generate_array(
        [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
        number<NDimY>{});

    constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);

    using SFC_Y = space_filling_curve<decltype(y_lengths),
                                      typename arithmetic_sequence_gen<0, NDimY, 1>::type,
                                      decltype(scalars_per_access)>;

    constexpr index_t num_access = SFC_Y::get_num_of_access();

    static_assert(num_access > 0, "wrong! num_access should be larger than 0");

    // in/out vectors to be transposed
    thread_buffer<InVec, num_vec_in>   in_vectors;
    thread_buffer<OutVec, num_vec_out> out_vectors;

    // loop over SFC and do transpose
    static_for<0, num_access, 1>{}([&](auto iAccess) {
        // data index [y0, y1, ...] in the order of input tensor
        constexpr auto idx_y_start = SFC_Y::get_index(iAccess);

        // get input vectors
        static_for<0, num_vec_in, 1>{}([&](auto i) {
            constexpr auto idx_y_in = generate_array(
                [&](auto ii) {
                    return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
                },
                number<NDimY>{});

            constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
            static_assert(in_offset % vec_length_in == 0);

            in_vectors(i).template get_as<InVec>()(I0) =
                in_tensor.get_thread_buffer()
                    .template get_as<InVec>()[number<in_offset / vec_length_in>{}];
        });

        // transpose
        transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);

        // set output vectors
        static_for<0, num_vec_out, 1>{}([&](auto i) {
            constexpr auto idx_y_out_tmp = generate_array(
                [&](auto ii) {
                    return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
                },
                number<NDimY>{});

            constexpr auto idx_y_out =
                container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);

            constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
            static_assert(out_offset % vec_length_out == 0);

            out_tensor.get_thread_buffer().template set_as<OutVec>(
                number<out_offset / vec_length_out>{},
                out_vectors[i].template get_as<OutVec>()[I0]);
        });
    });
}

} // namespace detail

template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
{
    using InDataType  = typename InTensor::DataType;
    using OutDataType = typename OutTensor::DataType;

    using InDstrEncode  = typename InTensor::StaticTileDistribution::DstrEncode;
    using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;

    // type convert
    const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);

    // shuffle
    if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
                 InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
                 InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
                 InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
                 InDstrEncode::NDimY == OutDstrEncode::NDimY)
    {
        detail::shuffle_tile_impl_in_thread(out, in_tmp);
    }
    else
    {
        // NOT implemented
    }
}

} // namespace ck_tile
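For orientation, a sketch (not code from this commit): shuffle_tile converts the input element type to the output type and then transposes the per-thread data, provided both distributions share the same R/H lengths and P mappings and differ only in Y order. acc_dstr and out_dstr below are hypothetical tile_distribution objects with exactly that relationship, and fp16_t is assumed to be the ck_tile half type:

auto acc = make_static_distributed_tensor<float>(acc_dstr);   // accumulator layout
// ... compute into acc ...
auto out = make_static_distributed_tensor<fp16_t>(out_dstr);  // layout expected by the epilogue
shuffle_tile(out, acc);  // float -> fp16_t conversion, then in-register transpose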
include/ck_tile/core/tensor/slice_tile.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using TileWindow = tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>;
    // NOTE: This API will override the origin of the tile window!
    static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
    static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());

    constexpr auto slice_lengths = slice_ends - slice_begins;

    return make_tile_window(tile.get_bottom_tensor_view(),
                            sequence_to_tuple_of_number(slice_lengths),
                            to_multi_index(slice_begins));
}

template <typename DataType_,
          typename StaticTileDistribution_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const static_distributed_tensor<DataType_, StaticTileDistribution_>& tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using DataType     = remove_cvref_t<DataType_>;
    using Distribution = remove_cvref_t<StaticTileDistribution_>;

    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);

    constexpr auto sliced_dstr      = sliced_dstr_yidx_ylen.template at<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();

    auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);

    sliced_tensor.get_thread_buffer() =
        tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);

    return sliced_tensor;
}

template <typename DstDataType_,
          typename DstStaticTileDistribution_,
          typename SrcDataType_,
          typename SrcStaticTileDistribution_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution_>& dst_tile,
               const static_distributed_tensor<SrcDataType_, SrcStaticTileDistribution_>& src_tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;

    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);

    constexpr auto sliced_dstr      = sliced_dstr_yidx_ylen.template at<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();

    static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");

    dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
}

} // namespace ck_tile
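A sketch of the slicing convention (not from this commit), assuming full_tile is a hypothetical static_distributed_tensor whose X lengths are 32 x 128; begins are inclusive and ends are exclusive:

auto left  = get_slice_tile(full_tile, sequence<0, 0>{},  sequence<32, 64>{});
auto right = get_slice_tile(full_tile, sequence<0, 64>{}, sequence<32, 128>{});

// writing a slice back uses the same begin/end pairs
set_slice_tile(full_tile, left, sequence<0, 0>{}, sequence<32, 64>{});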
include/ck_tile/core/tensor/static_distributed_tensor.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"

namespace ck_tile {

template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor
{
    using DataType               = remove_cvref_t<DataType_>;
    using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;

    static_assert(StaticTileDistribution::is_static(),
                  "wrong! StaticTileDistribution should be known at compile tile");

    using ThreadTensorDesc =
        remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;

    static constexpr index_t kThreadElementSpaceSize =
        ThreadTensorDesc{}.get_element_space_size();

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
    {
        return StaticTileDistribution::get_num_of_dimension_x();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
    {
        return StaticTileDistribution::get_lengths();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution()
    {
        return StaticTileDistribution{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
    {
        return StaticTileDistribution::get_distributed_spans();
    }

    CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }

    CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
    {
        return kThreadElementSpaceSize;
    }

    template <index_t... YSliceOrigins, index_t... YSliceLengths>
    CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence<YSliceOrigins...>,
                                                      sequence<YSliceLengths...>) const
    {
        static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
                          sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
                      "wrong!");

        constexpr auto sliced_thread_tensor_desc =
            make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));

        thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
            sliced_thread_data;

        static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
            constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};

            sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
                thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
        });

        return sliced_thread_data;
    }

    template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
    CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence<YSliceOrigins...>,
                                                      sequence<YSliceLengths...>,
                                                      const SlicedThreadData& sliced_thread_data)
    {
        static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
                          sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
                      "wrong!");

        constexpr auto sliced_thread_tensor_desc =
            make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));

        static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
            constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};

            thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
                sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
        });
    }

    template <typename TileDistributedIndices>
    CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
    {
        static_assert(is_static_v<TileDistributedIndices>,
                      "wrong! Tile Distributed Indices should be static");

        constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
            TileDistributedIndices{});

        return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
    }

    template <typename TileDistributedIndices>
    CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
    {
        static_assert(is_static_v<TileDistributedIndices>,
                      "wrong! Tile Distributed Indices should be static");

        constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
            TileDistributedIndices{});

        return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
    }

    //
    thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_;
};

template <typename DataType, typename StaticTileDistribution>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
{
    return static_distributed_tensor<remove_cvref_t<DataType>,
                                     remove_cvref_t<StaticTileDistribution>>{};
}

template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
                                                                  ThreadBuffer&& thread_buffer_)
{
    return static_distributed_tensor<remove_cvref_t<DataType>,
                                     remove_cvref_t<StaticTileDistribution>>{thread_buffer_};
}

// get X indices from tuple of tile_distributed_index<>
template <typename StaticTileDistribution, typename DistributedIndices>
CK_TILE_HOST_DEVICE constexpr auto
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
                                       DistributedIndices distributed_indices)
{
    const auto partition_index = detail::get_partition_index(tile_distribution);
    constexpr auto y_indices =
        tile_distribution.get_y_indices_from_distributed_indices(distributed_indices);

    const auto x_coord = make_tensor_adaptor_coordinate(
        tile_distribution.get_ps_ys_to_xs_adaptor(),
        container_concat(partition_index,
                         to_array<ck_tile::index_t, y_indices.size()>(y_indices)));

    return x_coord.get_bottom_index();
}

template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
CK_TILE_HOST_DEVICE void
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
            DataType value,
            XIndicesPredicate predicate)
{
    constexpr auto out_spans =
        static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();

    sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
        sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
            constexpr auto distributed_indices = make_tuple(idx0, idx1);

            const auto x_indices = get_x_indices_from_distributed_indices(
                StaticTileDistribution{}, distributed_indices);

            if(predicate(x_indices))
            {
                out_tensor(distributed_indices) = value;
            }
        });
    });
}

} // namespace ck_tile
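A sketch of set_tile_if (not from this commit), assuming dstr is a hypothetical 2-D tile_distribution and n_size is the runtime extent of the second X dimension; the predicate receives the global X indices of each element:

auto acc = make_static_distributed_tensor<float>(dstr);
acc.initialize(0.f);
// ... accumulate ...

// zero the columns that fall outside the real problem size before storing
set_tile_if(acc, 0.f, [&](auto x_indices) { return x_indices[number<1>{}] >= n_size; });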
include/ck_tile/core/tensor/store_tile.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using TileDstr = remove_cvref_t<TileDistribution_>;

    static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");

    constexpr auto tile_dstr = TileDstr{};

    auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                        tile_window_tmp.get_window_lengths(),
                                        tile_window_tmp.get_window_origin(),
                                        tile_dstr);

    tile_window.store(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using TileDstr = remove_cvref_t<TileDistribution_>;

    static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");

    constexpr auto tile_dstr = TileDstr{};

    auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                        tile_window_tmp.get_window_lengths(),
                                        tile_window_tmp.get_window_origin(),
                                        tile_dstr);

    tile_window.store_raw(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                 WindowLengths_,
                                                 TileDistribution_,
                                                 NumCoord>& tile_window,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
                                                     WindowLengths_,
                                                     TileDistribution_,
                                                     NumCoord>& tile_window,
               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store_raw(dstr_tensor);
}

} // namespace ck_tile
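The intended call pattern, sketched (not from this commit); c_window is assumed to be a tile window created with make_tile_window over the output tensor view, and c_acc a static_distributed_tensor holding this thread's results:

store_tile(c_window, c_acc);        // re-wraps the window with c_acc's distribution, then store()
// store_tile_raw(c_window, c_acc); // same, but forwards to the window's store_raw() path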
include/ck_tile/core/tensor/sweep_tile.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// sweep over a span of a distribted tile and apply lambda function F
template <typename TileDistributedSpan_, // tile_distributed_span<...>
          typename F                     // signature: F(tile_distributed_index<...>)
          >
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
{
    using DstrSpan = remove_cvref_t<TileDistributedSpan_>;

    static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
        constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);

        f(dstr_idx);
    });
}

} // namespace ck_tile
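Nested sweeps recover a full 2-D traversal; the pattern below mirrors set_tile_if from static_distributed_tensor.hpp in this commit, with t standing for a hypothetical 2-D distributed tensor:

constexpr auto spans = decltype(t)::get_distributed_spans();

sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
    sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
        constexpr auto i = make_tuple(idx0, idx1);
        t(i) = 2 * t[i];   // read via operator[], write via operator()
    });
});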
include/ck_tile/core/tensor/tensor_adaptor.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/numeric/numeric.hpp"

namespace ck_tile {

// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename BottomDimensionHiddenIds,
          typename TopDimensionHiddenIds>
struct tensor_adaptor
{
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_transform()
    {
        return Transforms::size();
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const { return transforms_; }

    CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
    {
        return LowerDimensionHiddenIdss{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
    {
        return UpperDimensionHiddenIdss{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_bottom_dimension_hidden_ids()
    {
        return BottomDimensionHiddenIds{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
    {
        return TopDimensionHiddenIds{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms)
    {
        const auto lengths = generate_tuple(
            [&](auto idim_top) {
                constexpr index_t idim_hidden = TopDimensionHiddenIds::at(idim_top);

                constexpr auto tmp = get_transform_and_its_upper_dimension(number<idim_hidden>{});

                constexpr index_t itran   = tmp[number<0>{}];
                constexpr index_t idim_up = tmp[number<1>{}];
                constexpr bool found      = tmp[number<2>{}];

                static_assert(found == true,
                              "wrong! not found matching transformation and upper-dimension");

                const auto length =
                    transforms[number<itran>{}].get_upper_lengths()[number<idim_up>{}];

                return length;
            },
            number<ndim_top_>{});

        // TODO: make container_reduce support tuple of number and index_t
        return container_reduce(lengths, multiplies{}, number<1>{});
    }

    template <index_t IDimHidden>
    CK_TILE_HOST_DEVICE static constexpr auto
    get_transform_and_its_upper_dimension(number<IDimHidden>)
    {
        // FIXME: length of bottom dimension is not known, since info about lower dim length are not
        // saved in transformation
        static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");

        index_t itran_found   = 0;
        index_t idim_up_found = 0;
        bool found            = false;

        static_for<0, ntransform_, 1>{}([&](auto itran) {
            constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];

            static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) {
                if constexpr(up_dim_ids[idim_up] == IDimHidden)
                {
                    itran_found   = itran;
                    idim_up_found = idim_up;
                    found         = true;
                }
            });
        });

        return make_tuple(itran_found, idim_up_found, found);
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_bottom_dimension()
    {
        return BottomDimensionHiddenIds::size();
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_top_dimension()
    {
        return TopDimensionHiddenIds::size();
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            LowerDimensionHiddenIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            UpperDimensionHiddenIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

        using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
                                                                      less<index_t>,
                                                                      equal<index_t>>::type;

        return unique_sort_all_dim_ids::size();
    }

    constexpr static index_t ntransform_  = get_num_of_transform();
    constexpr static index_t ndim_hidden_ = get_num_of_hidden_dimension();
    constexpr static index_t ndim_bottom_ = get_num_of_bottom_dimension();
    constexpr static index_t ndim_top_    = get_num_of_top_dimension();

    using HiddenIndex = multi_index<ndim_hidden_>;
    using BottomIndex = multi_index<ndim_bottom_>;
    using TopIndex    = multi_index<ndim_top_>;

    // may be index_t or number<>
    using ElementSize = remove_cv_t<decltype(initialize_element_size(Transforms{}))>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms)
        : transforms_{transforms}, element_size_{initialize_element_size(transforms)}
    {
        static_assert(Transforms::size() == ntransform_ &&
                          LowerDimensionHiddenIdss::size() == ntransform_ &&
                          UpperDimensionHiddenIdss::size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    CK_TILE_HOST_DEVICE constexpr auto get_element_size() const { return element_size_; }

    // FIXME: this logic is wrong when getting bottome dimension lengths
    template <index_t IDimHidden>
    CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number<IDimHidden>) const
    {
        static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");

        constexpr auto tmp = get_transform_and_its_upper_dimension(number<IDimHidden>{});

        constexpr index_t itran   = tmp[number<0>{}];
        constexpr index_t idim_up = tmp[number<1>{}];
        constexpr bool found      = tmp[number<2>{}];

        static_assert(found == true,
                      "wrong! not found matching transformation and upper-dimension");

        return transforms_[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
    }

    template <index_t IDimTop>
    CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number<IDimTop> idim_top) const
    {
        return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top));
    }

#if 0
    // FIXME: get_hidden_dimension_length is wrong when getting bottome dimension lengths
    template <index_t IDimBottom>
    CK_TILE_HOST_DEVICE constexpr index_t
    get_bottom_dimension_length(number<IDimBottom> idim_bottom) const
    {
        return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_bottom));
    }
#endif

    CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const
    {
        return generate_tuple([&](auto i) { return get_top_dimension_length(i); },
                              number<ndim_top_>{});
    }

#if 0
    // FIXME: get_hidden_dimension_length is wrong when getting bottome dimension lengths
    CK_TILE_HOST_DEVICE constexpr auto GetBottomDimensionLengths() const
    {
        return generate_tuple([&](auto i) { return get_bottom_dimension_length(i); },
                              number<ndim_bottom_>{});
    }
#endif

    template <typename TopIdx>
    CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const
    {
        static_assert(TopIdx::size() == TopDimensionHiddenIds::size(),
                      "wrong! # of dimension inconsistent");

        constexpr index_t ntransform  = get_num_of_transform();
        constexpr index_t ndim_hidden = get_num_of_hidden_dimension();

        multi_index<ndim_hidden> idx_hidden;

        // initialize uppest index
        set_container_subset(idx_hidden, get_top_dimension_hidden_ids(), idx_top);

        // calculate hidden index
        static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
            auto itran              = itran_p1 - number<1>{};
            const auto& tran        = get_transforms().at(itran);
            constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = get_upper_dimension_hidden_idss().at(itran);

            const auto idx_up = get_container_subset(idx_hidden, dims_up);

            multi_index<dims_low.size()> idx_low;

            tran.calculate_lower_index(idx_low, idx_up);

            set_container_subset(idx_hidden, dims_low, idx_low);
        });

        return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        bool is_known = true;

        static_for<0, Transforms::size(), 1>{}([&](auto i) {
            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::is_known_at_compile_time();
        });

        return is_known && ck_tile::is_known_at_compile_time<ElementSize>::value;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
        const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
        const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
    {
        auto vector_lengths = guaranteed_vector_lengths;
        auto vector_strides = guaranteed_vector_strides;

        static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
            constexpr auto up_dims  = get_upper_dimension_hidden_idss().at(itran);

            const auto up_guaranteed_vector_lengths =
                get_container_subset(guaranteed_vector_lengths, up_dims);
            const auto up_guaranteed_vector_strides =
                get_container_subset(guaranteed_vector_strides, up_dims);

            // only need type of transform
            auto [up_vector_lengths, up_vector_strides] =
                Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides(
                    get_container_subset(vector_lengths, low_dims),
                    get_container_subset(vector_strides, low_dims));

            if constexpr(up_dims.size() > 0)
            {
                for(index_t i = 0; i < up_dims.size(); ++i)
                {
                    up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1)
                                               ? up_guaranteed_vector_lengths[i]
                                               : up_vector_lengths[i];

                    up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1)
                                               ? up_guaranteed_vector_strides[i]
                                               : up_vector_strides[i];
                }
            }

            set_container_subset(vector_lengths, up_dims, up_vector_lengths);
            set_container_subset(vector_strides, up_dims, up_vector_strides);
        });

        constexpr auto top_dims = TopDimensionHiddenIds{};

        return make_tuple(get_container_subset(vector_lengths, top_dims),
                          get_container_subset(vector_strides, top_dims));
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_adaptor{");

        //
        printf("transforms: ");
        print(transforms_);
        printf(", ");

        //
        printf("LowerDimensionHiddenIds: ");
        print(LowerDimensionHiddenIdss{});
        printf(", ");

        //
        printf("UpperDimensionHiddenIds: ");
        print(UpperDimensionHiddenIdss{});
        printf(", ");

        //
        printf("BottomDimensionHiddenIds: ");
        print(BottomDimensionHiddenIds{});
        printf(", ");

        //
        printf("TopDimensionHiddenIds: ");
        print(TopDimensionHiddenIds{});

        printf("}");
    }

    private:
    Transforms transforms_;
    ElementSize element_size_;
};
// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
                                                                    LowerDimensionOldTopIdss,
                                                                    UpperDimensionNewTopIdss)
{
    constexpr index_t ntransform = Transforms::size();

    static_assert(LowerDimensionOldTopIdss::size() == ntransform &&
                      UpperDimensionNewTopIdss::size() == ntransform,
                  "wrong!");

    // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
    constexpr auto all_low_dim_old_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});

    constexpr auto all_up_dim_new_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});

    static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
                      is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
                  "wrong!");

    constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
    constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();

    // low_dim_hidden_idss
    constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};

    // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
    constexpr auto up_dim_hidden_idss = generate_tuple(
        [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number<ndim_old_top>{}; },
        number<ntransform>{});

    // bottom_dim_hidden_ids
    constexpr auto bottom_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};

    // top_dim_hidden_ids
    constexpr auto top_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + number<ndim_old_top>{};

    return tensor_adaptor<remove_cvref_t<Transforms>,
                          remove_cvref_t<decltype(low_dim_hidden_idss)>,
                          remove_cvref_t<decltype(up_dim_hidden_idss)>,
                          remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
                          remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
}
// TODO: How to fix this? It uses an struct instead of lambda because lambda
// doesn't have constructor, and to put it outside the scope where it is used
// (transform_tensor_adaptor) because template cannot be defined inside a function
// template
template <typename NewTransforms>
struct lambda_get_up_dim_num
{
    template <typename I>
    CK_TILE_HOST_DEVICE constexpr auto operator()(I) const
    {
        using Tran = remove_reference_t<decltype(NewTransforms{}.at(I{}))>;
        return number<Tran::get_num_of_upper_dimension()>{};
    }
};
template <typename OldTensorAdaptor,
          typename NewTransforms,
          typename NewLowerDimensionOldTopIdss,
          typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
                         const NewTransforms& new_transforms,
                         NewLowerDimensionOldTopIdss,
                         NewUpperDimensionNewTopIdss)
{
    // sanity check
    {
        static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() &&
                          NewTransforms::size() == NewUpperDimensionNewTopIdss::size(),
                      "wrong! inconsitent number of transform");

        constexpr auto all_old_top_ids = unpack(
            [](auto... xs) { return merge_sequences(xs...); }, NewLowerDimensionOldTopIdss{});

        constexpr auto all_new_top_ids = unpack(
            [](auto... xs) { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});

        static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
                          is_valid_sequence_map<decltype(all_new_top_ids)>::value,
                      "wrong!");
    }

    // lower dimension's hidden idss
    // convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
    // sequences)
    constexpr auto low_dim_hidden_idss = transform_tuples(
        // convert lower dimension top ids (a sequence) to hidden ids (a sequence)
        [](auto low_dim_top_ids) constexpr {
            return transform_sequences(
                // convert lower dimension top id to hidden id
                [](auto low_dim_top_id) constexpr {
                    return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id];
                },
                low_dim_top_ids);
        },
        NewLowerDimensionOldTopIdss{});

    constexpr index_t num_new_transform = NewTransforms::size();

    // upper dimension's hidden idss
    constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension();

    constexpr auto up_dim_numbers =
        generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, number<num_new_transform>{});

    constexpr auto up_dim_numbers_scan = merge_sequences(
        sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));

    constexpr auto up_dim_hidden_idss = generate_tuple(
        [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
            return typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
                                                    old_hidden_dim_number +
                                                        up_dim_numbers_scan[i + 1],
                                                    1>::type{};
        },
        number<num_new_transform>{});

    // new top dimension's hidden ids
    constexpr auto unordered_new_top_dim_hidden_ids = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);

    constexpr auto new_top_dim_unordered2ordered = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});

    constexpr auto new_top_dim_hidden_ids =
        unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered);

    // put everything together
    const auto all_transforms =
        container_concat(old_tensor_adaptor.get_transforms(), new_transforms);

    constexpr auto all_low_dim_hidden_idss =
        container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss);

    constexpr auto all_up_dim_hidden_idss =
        container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss);

    return tensor_adaptor<
        remove_cvref_t<decltype(all_transforms)>,
        remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
        remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
        remove_cvref_t<decltype(OldTensorAdaptor::get_bottom_dimension_hidden_ids())>,
        remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
}
template <typename TensorAdaptor0, typename TensorAdaptor1>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
                                                         const TensorAdaptor1& adaptor1)
{
    static_assert(TensorAdaptor0::get_num_of_top_dimension() ==
                      TensorAdaptor1::get_num_of_bottom_dimension(),
                  "wrong!");

    // all_transforms = transform0 + transform1
    const auto all_transforms =
        container_concat(adaptor0.get_transforms(), adaptor1.get_transforms());

    // shift
    constexpr index_t adaptor0_max_hidden_id = [&]() {
        index_t adaptor0_max_hidden_id_ = numeric<index_t>::min();

        static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension();

            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                adaptor0_max_hidden_id_ =
                    max(adaptor0_max_hidden_id_,
                        TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value);
            });

            constexpr index_t ndim_up =
                TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension();

            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor0_max_hidden_id_ =
                    max(adaptor0_max_hidden_id_,
                        TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value);
            });
        });

        return adaptor0_max_hidden_id_;
    }();

    constexpr index_t adaptor1_min_hidden_id = [&]() {
        index_t adaptor1_min_hidden_id_ = numeric<index_t>::max();

        static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension();

            // get the min of all lower dimenions, but not bottom dimension (because their id will
            // be matched with top id from adaptor0)
            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                constexpr index_t low_dim_hidden_id =
                    TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value;

                bool is_bottom_dim = false;
                static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) {
                    if constexpr(low_dim_hidden_id ==
                                 TensorAdaptor1::get_bottom_dimension_hidden_ids()[i])
                    {
                        is_bottom_dim = true;
                    }
                });

                if(!is_bottom_dim)
                {
                    adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id);
                }
            });

            constexpr index_t ndim_up =
                TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension();

            // get the min of all upper dimensions
            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor1_min_hidden_id_ =
                    min(adaptor1_min_hidden_id_,
                        TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value);
            });
        });

        return adaptor1_min_hidden_id_;
    }();

    constexpr index_t adaptor1_hidden_id_shift =
        adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;

    constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension();

    // all_low_dim_hidden_idss =
    // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1))
    constexpr auto low_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_low_1 =
                TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size();

            constexpr auto low_dim_hidden_ids_1 =
                TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];

            // sequence in, sequence out
            constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
                auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);

                // shift hidden id so every dim id is unique
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
                });

                // match hidden id
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
                        // if this low dim is bottom dim, then do id matching
                        if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
                                     TensorAdaptor1::get_bottom_dimension_hidden_ids()
                                         [idim_bottom_1])
                        {
                            low_dim_hidden_ids_1_mod_(idim_low_1) =
                                TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1];
                        }
                    });
                });

                return low_dim_hidden_ids_1_mod_;
            }();

            return generate_sequence_v2(
                [&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
                number<ndim_low_1>{});
        },
        number<TensorAdaptor1::get_num_of_transform()>{});

    constexpr auto all_low_dim_hidden_idss =
        container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1);

    // all_up_dim_hidden_idss =
    // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1)
    constexpr auto up_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_up_1 =
                TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size();

            constexpr auto up_dim_hidden_ids_1 =
                TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];

            // sequence in, constexpr tuple out
            constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
                auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);

                // shift hidden id
                static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
                    up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
                });

                return up_dim_hidden_ids_1_mod_;
            }();

            // constexpr tuple to sequence
            return generate_sequence_v2(
                [&](auto i) constexpr { return number<up_dim_hidden_ids_1_mod[i]>{}; },
                number<ndim_up_1>{});
        },
        number<TensorAdaptor1::get_num_of_transform()>{});

    constexpr auto all_up_dim_hidden_idss =
        container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1);

    // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
    constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids();

    // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
    constexpr auto top_dim_hidden_ids =
        TensorAdaptor1::get_top_dimension_hidden_ids() + number<adaptor1_hidden_id_shift>{};

    // put everything together
    return tensor_adaptor<remove_cvref_t<decltype(all_transforms)>,
                          remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
                          remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
                          remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
                          remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}

template <typename X,
          typename... Xs,
          typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
    return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}

} // namespace ck_tile
// Macro function
// construct constexpr tensor_adaptor from constexpr encoding
// encoded_tensor_adaptor are Tuple of following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following:
// 1.1 name (coord_transform_enum)
// 1.2 meta data for constructor of the transform
// 1.3 num of lower dimension (index_t)
// 1.4 lower dimension Ids (array of fixed size)
// 1.5 num of up dimension (index_t)
// 1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimension (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimension (index_t)
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
[encoded_tensor_adaptor]() { \
using namespace ck_tile; \
\
constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \
constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \
constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
constexpr auto meta_data = encoded_transforms[i].template at<1>(); \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
\
static_assert(name == coord_transform_enum::pass_through || \
name == coord_transform_enum::pad || \
name == coord_transform_enum::embed || \
name == coord_transform_enum::merge || \
name == coord_transform_enum::unmerge || \
name == coord_transform_enum::replicate, \
""); \
\
if constexpr(name == coord_transform_enum::pass_through) \
{ \
index_t pos = 0; \
auto low_len = meta_data.template pop<index_t>(pos); \
\
return make_pass_through_transform(low_len); \
} \
else if constexpr(name == coord_transform_enum::pad) \
{ \
index_t pos = 0; \
auto low_len = meta_data.template pop<index_t>(pos); \
auto left_pad = meta_data.template pop<index_t>(pos); \
auto right_pad = meta_data.template pop<index_t>(pos); \
\
return make_pad_transform(low_len, left_pad, right_pad); \
} \
else if constexpr(name == coord_transform_enum::embed) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
auto coefficients = \
meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_embed_transform(up_lens, coefficients); \
} \
else if constexpr(name == coord_transform_enum::merge) \
{ \
index_t pos = 0; \
auto low_lens = meta_data.template pop<array<index_t, num_low_dim>>(pos); \
\
return make_merge_transform(low_lens); \
} \
else if constexpr(name == coord_transform_enum::unmerge) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_unmerge_transform(up_lens); \
} \
else if constexpr(name == coord_transform_enum::replicate) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_replicate_transform(up_lens); \
} \
}, \
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto low_dims = encoded_transforms[i].template at<3>(); \
\
return TO_SEQUENCE(low_dims, num_low_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
constexpr auto up_dims = encoded_transforms[i].template at<5>(); \
\
return TO_SEQUENCE(up_dims, num_up_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
\
return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
remove_cvref_t<decltype(low_dim_idss)>, \
remove_cvref_t<decltype(up_dim_idss)>, \
remove_cvref_t<decltype(bottom_dim_ids)>, \
remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
}()
// Macro function
// construct static tensor_adaptor from constexpr encoding
// encoded_tensor_adaptor are Tuple of following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following:
// 1.1 name (coord_transform_enum)
// 1.2 meta data for constructor of the transform
// 1.3 num of lower dimension (index_t)
// 1.4 lower dimension Ids (array of fixed size)
// 1.5 num of up dimension (index_t)
// 1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimension (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimension (index_t)
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
[encoded_tensor_adaptor]() { \
using namespace ck_tile; \
\
constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \
constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \
constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
constexpr auto meta_data = encoded_transforms[i].template at<1>(); \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
\
static_assert(name == coord_transform_enum::pass_through || \
name == coord_transform_enum::pad || \
name == coord_transform_enum::embed || \
name == coord_transform_enum::merge || \
name == coord_transform_enum::unmerge || \
name == coord_transform_enum::replicate, \
""); \
\
if constexpr(name == coord_transform_enum::pass_through) \
{ \
constexpr index_t low_len = meta_data.template get<index_t>(0); \
\
return make_pass_through_transform(number<low_len>{}); \
} \
else if constexpr(name == coord_transform_enum::pad) \
{ \
constexpr index_t low_len = meta_data.template get<index_t>(0); \
\
constexpr index_t left_pad = \
meta_data.template get<index_t>(sizeof(low_len)); \
\
constexpr index_t right_pad = \
meta_data.template pop<index_t>(sizeof(low_len) + sizeof(left_pad)); \
\
return make_pad_transform( \
number<low_len>{}, number<left_pad>{}, number<right_pad>{}); \
} \
else if constexpr(name == coord_transform_enum::embed) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
constexpr auto coefficients = \
meta_data.template get<array<index_t, num_up_dim>>(sizeof(up_lens)); \
\
return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim), \
TO_TUPLE_OF_NUMBER(coefficients, num_up_dim)); \
} \
else if constexpr(name == coord_transform_enum::merge) \
{ \
constexpr auto low_lens = \
meta_data.template get<array<index_t, num_low_dim>>(0); \
\
return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim)); \
} \
else if constexpr(name == coord_transform_enum::unmerge) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
} \
else if constexpr(name == coord_transform_enum::replicate) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
} \
}, \
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto low_dims = encoded_transforms[i].template at<3>(); \
\
return TO_SEQUENCE(low_dims, num_low_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
constexpr auto up_dims = encoded_transforms[i].template at<5>(); \
\
return TO_SEQUENCE(up_dims, num_up_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
\
return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
remove_cvref_t<decltype(low_dim_idss)>, \
remove_cvref_t<decltype(up_dim_idss)>, \
remove_cvref_t<decltype(bottom_dim_ids)>, \
remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
}()
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <index_t NDimHidden, typename BottomDimensionHiddenIds, typename TopDimensionHiddenIds>
struct tensor_adaptor_coordinate
{
    static constexpr index_t ndim_bottom_ = BottomDimensionHiddenIds::size();
    static constexpr index_t ndim_top_    = TopDimensionHiddenIds::size();

    using HiddenIndex = multi_index<NDimHidden>;
    using BottomIndex = multi_index<ndim_bottom_>;
    using TopIndex    = multi_index<ndim_top_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_adaptor_coordinate(const HiddenIndex& idx_hidden)
        : idx_hidden_{idx_hidden}
    {
    }

    CK_TILE_HOST_DEVICE constexpr auto get_top_index() const
    {
        return get_container_subset(idx_hidden_, TopDimensionHiddenIds{});
    }

    CK_TILE_HOST_DEVICE constexpr auto get_bottom_index() const
    {
        return get_container_subset(idx_hidden_, BottomDimensionHiddenIds{});
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const { return idx_hidden_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_hidden_index() { return idx_hidden_; }

    //
    HiddenIndex idx_hidden_;
};

template <typename Adaptor, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  const TopIndex& idx_top)
{
    static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
                  "wrong! # of dimension inconsistent");

    constexpr index_t ntransform  = Adaptor::get_num_of_transform();
    constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
    constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
    constexpr auto top_dim_ids    = Adaptor::get_top_dimension_hidden_ids();

    multi_index<ndim_hidden> idx_hidden;

    // initialize visible index
    set_container_subset(idx_hidden, top_dim_ids, idx_top);

    // calculate hidden index
    static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
        auto itran              = itran_p1 - number<1>{};
        const auto& tran        = adaptor.get_transforms().at(itran);
        constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
        constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

        const auto idx_up = get_container_subset(idx_hidden, dims_up);

        multi_index<dims_low.size()> idx_low;

        tran.calculate_lower_index(idx_low, idx_up);

        set_container_subset(idx_hidden, dims_low, idx_low);
    });

    return tensor_adaptor_coordinate<ndim_hidden,
                                     remove_cvref_t<decltype(bottom_dim_ids)>,
                                     remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
}

template <bool JudgeDoTransforms = true,
          typename Adaptor,
          typename AdaptorCoord,
          typename TopIndex,
          typename BottomIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  AdaptorCoord& coord,
                                                                  const TopIndex& idx_diff_top,
                                                                  BottomIndex& idx_diff_bottom)
{
    constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
    constexpr index_t ndim_top    = Adaptor::get_num_of_top_dimension();
    // constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
    constexpr index_t ntransform = Adaptor::get_num_of_transform();

    // static_assert(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, "");

    // judge whether calculation of lower diff is needed for each transform
    // use index_t for boolean type
    auto do_transforms = make_zero_multi_index<ntransform>();

    if constexpr(JudgeDoTransforms)
    {
        auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();

        // decide do_transform by checkout non-zero index diff components
        multi_index<ndim_top> non_zero_diff_pick_top;

        static_for<0, ndim_top, 1>{}(
            [&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });

        set_container_subset(
            is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top);

        static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
            constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

            const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);

            multi_index<dims_low.size()> non_zero_diff_pick_low;

            // if any of upper index diff components is non-zero, then
            // 1) Need to do this transform
            // 2) all components of lower index diff will assume to be non-zero and need to be
            // computed
            const bool idx_diff_up_has_non_zero = container_reduce(
                non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);

            do_transforms(itran) = idx_diff_up_has_non_zero;

            static_for<0, dims_low.size(), 1>{}(
                [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });

            set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
        });
    }
    else
    {
        static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
    }

    // this is what needs to be calculated
    auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();

    // initialize top index diff
    set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top);

    // this is what needs to be updated
    auto& idx_hidden = coord.get_hidden_index();

    // update top index
    auto idx_hidden_pick_top =
        get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids());

    idx_hidden_pick_top += idx_diff_top;

    set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top);

    // update rest of hidden index
    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
        if(do_transforms[itran])
        {
            const auto& tran        = adaptor.get_transforms().at(itran);
            constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

            const auto idx_up_new  = get_container_subset(idx_hidden, dims_up);
            auto idx_low           = get_container_subset(idx_hidden, dims_low);
            const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);

            multi_index<dims_low.size()> idx_diff_low;

            tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new);

            set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
            set_container_subset(idx_hidden, dims_low, idx_low);
        }
    });

    // set bottom index diff
    idx_diff_bottom =
        get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids());
}

template <bool JudgeDoTransforms = true,
          typename Adaptor,
          typename AdaptorCoord,
          typename TopIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  AdaptorCoord& coord,
                                                                  const TopIndex& idx_diff_top)
{
    constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();

    multi_index<ndim_bottom> tmp;

    move_tensor_adaptor_coordinate<JudgeDoTransforms>(adaptor, coord, idx_diff_top, tmp);
}

template <typename Adaptor, typename AdaptorCoord>
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid_assuming_top_index_is_valid(
    const Adaptor& adaptor, const AdaptorCoord& coord)
{
    bool valid = true;

    constexpr index_t ntransform = Adaptor::get_num_of_transform();

    const auto& idx_hidden = coord.get_hidden_index();

    static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
        const auto tran = adaptor.get_transforms().at(itran);

        // check validity, only if current transformation does not always has a valid mapping
        if constexpr(!decltype(tran)::is_valid_upper_index_always_mapped_to_valid_lower_index())
        {
            const auto idx_up = get_container_subset(
                idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran));

            // Comment: using valid = valid && .. will result in weird control flow in ISA
            valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up);
        }
    });

    return valid;
}

template <typename Adaptor, typename AdpatorCoord>
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
                                                               const AdpatorCoord& coord)
{
    // check top index
    const auto& idx_top = coord.get_top_index();

    bool is_top_index_valid = true;

    static_for<0, Adaptor::get_num_of_dimension(), 1>{}(
        [&is_top_index_valid, &idx_top, &adaptor](auto i) {
            is_top_index_valid =
                is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i));
        });

    // check other hidden index
    return is_top_index_valid &&
           adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
}

} // namespace ck_tile
include/ck_tile/core/tensor/tensor_coordinate.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <index_t NDimHidden, typename TopDimensionHiddenIds>
struct tensor_coordinate
    : public tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>
{
    using Base = tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>;

    // TODO make these private
    static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();

    using HiddenIndex = multi_index<NDimHidden>;
    using TopIndex    = multi_index<ndim_top_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const HiddenIndex& idx_hidden)
        : Base{idx_hidden}
    {
    }

    // construct from TensorAdaptorCoordinte base class
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const Base& adaptor_coord) : Base{adaptor_coord}
    {
    }

    CK_TILE_HOST_DEVICE constexpr auto get_index() const { return Base::get_top_index(); }

    CK_TILE_HOST_DEVICE constexpr index_t get_offset() const
    {
        return Base::get_bottom_index()[number<0>{}];
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const
    {
        return Base::get_hidden_index();
    }

    CK_TILE_HOST_DEVICE auto& get_hidden_index() { return Base::get_hidden_index(); }
};

template <typename TensorDesc, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
                                                          const TopIndex& idx_top)
{
    const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top);

    return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
                             remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
        adaptor_coord};
}

template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
CK_TILE_HOST_DEVICE constexpr void
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
{
    move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
}

template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset_assuming_top_index_is_valid(
    const TensorDesc& tensor_desc, const TensorCoord& coord)
{
    return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
}

template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
                                                               const TensorCoord& coord)
{
    return adaptor_coordinate_is_valid(tensor_desc, coord);
}

} // namespace ck_tile
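Not part of the commit: a minimal host-side sketch of how a tensor_coordinate is created, queried, and moved over a packed 4x8 descriptor. It assumes "ck_tile/core.hpp" aggregates these headers and that an array<index_t, 2> is an acceptable top index (the same pattern tensor_view uses further down in this commit); values are illustrative only.

#include <cstdio>
#include "ck_tile/core.hpp"

int main()
{
    using namespace ck_tile;

    // packed row-major 4x8 descriptor: offset(i, j) = i * 8 + j
    const auto desc = make_naive_tensor_descriptor_packed(make_tuple(number<4>{}, number<8>{}));

    // derive the hidden index once from the top index (1, 2)
    auto coord = make_tensor_coordinate(desc, array<index_t, 2>{1, 2});
    printf("offset of (1,2): %d\n", static_cast<int>(coord.get_offset())); // expected 10

    // step by (0, 3) incrementally instead of recomputing the whole coordinate
    move_tensor_coordinate(desc, coord, array<index_t, 2>{0, 3});
    printf("offset after +(0,3): %d\n", static_cast<int>(coord.get_offset())); // expected 13

    return 0;
}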
include/ck_tile/core/tensor/tensor_descriptor.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<sequence<...>, ...>
// TopDimensionHiddenIds> : sequence<...>
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename TopDimensionHiddenIds,
          typename ElementSpaceSize,
          typename GuaranteedVectorLengths_,
          typename GuaranteedVectorSrides_>
struct tensor_descriptor : public tensor_adaptor<Transforms,
                                                 LowerDimensionHiddenIdss,
                                                 UpperDimensionHiddenIdss,
                                                 sequence<0>,
                                                 TopDimensionHiddenIds>
{
    using Base = tensor_adaptor<Transforms,
                                LowerDimensionHiddenIdss,
                                UpperDimensionHiddenIdss,
                                sequence<0>,
                                TopDimensionHiddenIds>;

    using ElementSpaceSizeType = ElementSpaceSize;

    constexpr static index_t ntransform_  = Base::get_num_of_transform();
    constexpr static index_t ndim_hidden_ = Base::get_num_of_hidden_dimension();
    constexpr static index_t ndim_top_    = Base::get_num_of_top_dimension();

    using GuaranteedVectorLengths = GuaranteedVectorLengths_;
    using GuaranteedVectorStrides = GuaranteedVectorSrides_;

    static_assert(GuaranteedVectorLengths::size() == ndim_hidden_ &&
                      GuaranteedVectorStrides::size() == ndim_hidden_,
                  "wrong! inconsistent # of hidden dimensions");

    using TopIndex    = multi_index<ndim_top_>;
    using HiddenIndex = multi_index<ndim_hidden_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_descriptor() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Transforms& transforms,
                                                    ElementSpaceSize element_space_size)
        : Base{transforms}, element_space_size_{element_space_size}
    {
        static_assert(Transforms::size() == ntransform_ &&
                          LowerDimensionHiddenIdss::size() == ntransform_ &&
                          UpperDimensionHiddenIdss::size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    // construct from tensor_adaptor base class
    CK_TILE_HOST_DEVICE constexpr tensor_descriptor(const Base& adaptor,
                                                    ElementSpaceSize element_space_size)
        : Base{adaptor}, element_space_size_{element_space_size}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
    {
        return Base::get_num_of_top_dimension();
    }

    template <index_t IDim>
    CK_TILE_HOST_DEVICE constexpr auto get_length(number<IDim> idim) const
    {
        return Base::get_top_dimension_length(idim);
    }

    CK_TILE_HOST_DEVICE constexpr auto get_lengths() const
    {
        return Base::get_top_dimension_lengths();
    }

    CK_TILE_HOST_DEVICE constexpr auto get_element_space_size() const
    {
        return element_space_size_;
    }

    template <typename Idx>
    CK_TILE_HOST_DEVICE constexpr index_t calculate_offset(const Idx& idx) const
    {
        return Base::calculate_bottom_index(idx)[number<0>{}];
    }

    // TODO make these private
    CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const
    {
        return Base::get_transforms();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
    {
        return Base::get_lower_dimension_hidden_idss();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
    {
        return Base::get_upper_dimension_hidden_idss();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
    {
        return Base::get_top_dimension_hidden_ids();
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        return Base::is_known_at_compile_time() &&
               ck_tile::is_known_at_compile_time<ElementSpaceSize>::value;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides()
    {
        return Base::get_top_dimension_safe_vector_length_strides(
            to_array<index_t, ndim_hidden_>(GuaranteedVectorLengths{}),
            to_array<index_t, ndim_hidden_>(GuaranteedVectorStrides{}));
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_descriptor{");

        // tensor_adaptor
        Base::print();
        printf(", ");

        // element_space_size_
        printf("element_space_size_: ");
        print(element_space_size_);

        printf("}");
    }

    // TODO make these private
    ElementSpaceSize element_space_size_;
};

template <typename Adaptor, typename ElementSpaceSize>
CK_TILE_HOST_DEVICE constexpr auto
make_tensor_descriptor_from_adaptor(const Adaptor& adaptor,
                                    const ElementSpaceSize& element_space_size)
{
    constexpr index_t NDimHidden = Adaptor::get_num_of_hidden_dimension();

    return tensor_descriptor<remove_cvref_t<decltype(adaptor.get_transforms())>,
                             remove_cvref_t<decltype(adaptor.get_lower_dimension_hidden_idss())>,
                             remove_cvref_t<decltype(adaptor.get_upper_dimension_hidden_idss())>,
                             remove_cvref_t<decltype(adaptor.get_top_dimension_hidden_ids())>,
                             remove_cvref_t<decltype(element_space_size)>,
                             typename uniform_sequence_gen<NDimHidden, -1>::type,
                             typename uniform_sequence_gen<NDimHidden, -1>::type>{
        adaptor, element_space_size};
}

template <typename OldTensorDescriptor,
          typename NewTransforms,
          typename NewLowerDimensionOldTopIdss,
          typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
                            const NewTransforms& new_transforms,
                            NewLowerDimensionOldTopIdss,
                            NewUpperDimensionNewTopIdss)
{
    const auto element_space_size = old_tensor_desc.get_element_space_size();

    const auto new_tensor_adaptor = transform_tensor_adaptor(old_tensor_desc,
                                                             new_transforms,
                                                             NewLowerDimensionOldTopIdss{},
                                                             NewUpperDimensionNewTopIdss{});

    constexpr index_t NDimHiddenOld = OldTensorDescriptor::get_num_of_hidden_dimension();
    constexpr index_t NDimHiddenNew = decltype(new_tensor_adaptor)::get_num_of_hidden_dimension();

    using NewGuaranteedVectorLengths = typename sequence_merge<
        typename OldTensorDescriptor::GuaranteedVectorLengths,
        typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;

    using NewGuaranteedVectorStrides = typename sequence_merge<
        typename OldTensorDescriptor::GuaranteedVectorStrides,
        typename uniform_sequence_gen<NDimHiddenNew - NDimHiddenOld, -1>::type>::type;

    return tensor_descriptor<
        remove_cvref_t<decltype(new_tensor_adaptor.get_transforms())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_lower_dimension_hidden_idss())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_upper_dimension_hidden_idss())>,
        remove_cvref_t<decltype(new_tensor_adaptor.get_top_dimension_hidden_ids())>,
        remove_cvref_t<decltype(element_space_size)>,
        NewGuaranteedVectorLengths,
        NewGuaranteedVectorStrides>{new_tensor_adaptor, element_space_size};
}

namespace detail {

template <typename Lengths, typename Strides, index_t I, typename AccOld>
CK_TILE_HOST_DEVICE constexpr auto calculate_element_space_size_impl(const Lengths& lengths,
                                                                     const Strides& strides,
                                                                     number<I> i,
                                                                     AccOld acc_old)
{
    auto acc_new = acc_old + (lengths[i] - number<1>{}) * strides[i];

    if constexpr(i.value < Lengths::size() - 1)
    {
        return calculate_element_space_size_impl(lengths, strides, i + number<1>{}, acc_new);
    }
    else
    {
        return acc_new;
    }
}

} // namespace detail

/*
 * These functions create naive tensor descriptor
 */

// Lengths..., Strides... could be:
//   1) index_t, which is known at run-time, or
//   2) number<>, which is known at compile-time
// element_space_size could be:
//   1) long_index_t, or
//   2) long_number<>
template <typename... Lengths,
          typename... Strides,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor(const tuple<Lengths...>& lengths,
                             const tuple<Strides...>& strides,
                             number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                             number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    constexpr index_t N = sizeof...(Lengths);

    const auto transforms = make_tuple(make_embed_transform(lengths, strides));

    constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

    const auto element_space_size =
        detail::calculate_element_space_size_impl(lengths, strides, number<0>{}, long_number<1>{});

    using GuaranteedVectorLengths =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorLength>>::type;

    using GuaranteedVectorStrides =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorStride>>::type;

    return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,
                             remove_cv_t<decltype(element_space_size)>,
                             GuaranteedVectorLengths,
                             GuaranteedVectorStrides>{transforms, element_space_size};
}

// tensor descriptor with offset, the offset will not be added into element space size
// only have an information of the starting offset, and will impact on offset calculation
template <typename... Lengths,
          typename... Strides,
          typename offset,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_with_offset(
    const tuple<Lengths...>& lengths,
    const tuple<Strides...>& strides,
    const offset& os,
    number<GuaranteedLastDimensionVectorLength> = number<-1>{},
    number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    const auto desc_0 = [&]() {
        const auto element_space_size = detail::calculate_element_space_size_impl(
            lengths, strides, number<0>{}, long_number<1>{});

        const auto transforms = make_tuple(make_offset_transform(element_space_size, os));

        constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

        constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});

        constexpr auto visible_dim_hidden_ids = sequence<1>{};

        using GuaranteedVectorLengths =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorLength>>::type;

        using GuaranteedVectorStrides =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorStride>>::type;

        return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                                 remove_cv_t<decltype(low_dim_hidden_idss)>,
                                 remove_cv_t<decltype(up_dim_hidden_idss)>,
                                 remove_cv_t<decltype(visible_dim_hidden_ids)>,
                                 remove_cv_t<decltype(element_space_size)>,
                                 GuaranteedVectorLengths,
                                 GuaranteedVectorStrides>{transforms, element_space_size};
    }();

    constexpr index_t N = sizeof...(Lengths);

    return transform_tensor_descriptor(
        desc_0,
        make_tuple(make_embed_transform(lengths, strides)),
        make_tuple(sequence<0>{}),
        make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}

// Lengths... could be:
//   1) index_t, which is known at run-time, or
//   2) number<>, which is known at compile-time
// element_space_size could be:
//   1) long_index_t, or
//   2) long_number<>
template <typename... Lengths, index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_packed(const tuple<Lengths...>& lengths,
                                    number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
    constexpr index_t N = sizeof...(Lengths);

    const auto transforms = make_tuple(make_unmerge_transform(lengths));

    constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

    constexpr auto up_dim_hidden_idss =
        make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{});

    constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{};

    const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});

    using GuaranteedVectorLengths =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type,
                                sequence<GuaranteedLastDimensionVectorLength>>::type;

    using GuaranteedVectorStrides =
        typename sequence_merge<typename uniform_sequence_gen<N, -1>::type, sequence<1>>::type;

    return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                             remove_cv_t<decltype(low_dim_hidden_idss)>,
                             remove_cv_t<decltype(up_dim_hidden_idss)>,
                             remove_cv_t<decltype(visible_dim_hidden_ids)>,
                             remove_cv_t<decltype(element_space_size)>,
                             GuaranteedVectorLengths,
                             GuaranteedVectorStrides>{transforms, element_space_size};
}

template <typename... Lengths,
          typename... Strides,
          typename Offset,
          index_t GuaranteedLastDimensionVectorLength = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor_packed_with_offset(
    const tuple<Lengths...>& lengths,
    const Offset& offset,
    number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
    const auto desc_0 = [&]() {
        const auto element_space_size = container_reduce(lengths, multiplies{}, long_number<1>{});

        const auto transforms = make_tuple(make_offset_transform(element_space_size, offset));

        constexpr auto low_dim_hidden_idss = make_tuple(sequence<0>{});

        constexpr auto up_dim_hidden_idss = make_tuple(sequence<1>{});

        constexpr auto visible_dim_hidden_ids = sequence<1>{};

        using GuaranteedVectorLengths =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type,
                                    sequence<GuaranteedLastDimensionVectorLength>>::type;

        using GuaranteedVectorStrides =
            typename sequence_merge<typename uniform_sequence_gen<1, -1>::type, sequence<1>>::type;

        return tensor_descriptor<remove_cv_t<decltype(transforms)>,
                                 remove_cv_t<decltype(low_dim_hidden_idss)>,
                                 remove_cv_t<decltype(up_dim_hidden_idss)>,
                                 remove_cv_t<decltype(visible_dim_hidden_ids)>,
                                 remove_cv_t<decltype(element_space_size)>,
                                 GuaranteedVectorLengths,
                                 GuaranteedVectorStrides>{transforms, element_space_size};
    }();

    constexpr index_t N = sizeof...(Lengths);

    return transform_tensor_descriptor(
        desc_0,
        make_tuple(make_unmerge_transform(lengths)),
        make_tuple(sequence<0>{}),
        make_tuple(typename arithmetic_sequence_gen<0, N, 1>::type{}));
}

// Lengths... could be:
//   1) index_t, which is known at run-time, or
//   2) number<>, which is known at compile-time
// align could be:
//   1) index_t, or
//   2) number<>
template <typename... Lengths, typename Align>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_descriptor_aligned(const tuple<Lengths...>& lengths, Align align)
{
    constexpr auto I1 = number<1>{};

    constexpr index_t N = sizeof...(Lengths);

    const auto stride_n_minus_2 = integer_least_multiple(lengths[number<N - 1>{}], align);

    auto strides = generate_tuple(
        [&](auto i) {
            if constexpr(i.value == N - 1)
            {
                return I1;
            }
            else if constexpr(i.value == N - 2)
            {
                return number<stride_n_minus_2>{};
            }
            else
            {
                return container_reduce(lengths,
                                        multiplies{},
                                        number<stride_n_minus_2>{},
                                        i + I1,
                                        number<N - 1>{},
                                        I1);
            }
        },
        number<N>{});

    return make_naive_tensor_descriptor(lengths, strides);
}

} // namespace ck_tile
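Not part of the commit: a minimal host-side sketch contrasting the strided and packed naive descriptors above. It assumes "ck_tile/core.hpp" aggregates these headers; the lengths, strides, and indices are illustrative only.

#include <cstdio>
#include "ck_tile/core.hpp"

int main()
{
    using namespace ck_tile;

    // 4x8 view over a row-padded buffer with leading stride 10:
    // offset(i, j) = i * 10 + j, and the element space covers 1 + (4-1)*10 + (8-1)*1 = 38 elements
    const auto strided = make_naive_tensor_descriptor(make_tuple(number<4>{}, number<8>{}),
                                                      make_tuple(number<10>{}, number<1>{}));
    printf("strided offset of (2,3): %d\n",
           static_cast<int>(strided.calculate_offset(array<index_t, 2>{2, 3}))); // expected 23

    // packed variant derives the strides from the lengths (here 8 and 1)
    const auto packed = make_naive_tensor_descriptor_packed(make_tuple(number<4>{}, number<8>{}));
    printf("packed offset of (2,3):  %d\n",
           static_cast<int>(packed.calculate_offset(array<index_t, 2>{2, 3}))); // expected 19

    return 0;
}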
include/ck_tile/core/tensor/tensor_view.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename BufferView_, typename TensorDesc_>
struct tensor_view
{
    using buffer_view = remove_reference_t<BufferView_>;
    using DataType    = typename buffer_view::type;
    using TensorDesc  = remove_cvref_t<TensorDesc_>;
    using TensorIndex = array<index_t, TensorDesc::get_num_of_top_dimension()>;
    using TensorCoord = decltype(make_tensor_coordinate(TensorDesc{}, TensorIndex{}));

    CK_TILE_HOST_DEVICE constexpr tensor_view() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_view(const buffer_view& buffer_view,
                                              const TensorDesc& desc)
        : buf_{buffer_view}, desc_{desc}
    {
    }

    CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()
    {
        return TensorDesc::get_num_of_top_dimension();
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_buffer_view() const { return buf_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_buffer_view() { return buf_; }

#if 0
    CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const
    {
        return buf_.template get<DataType>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
    }

    CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x)
    {
        buf_.template set<DataType>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }
#endif

    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
    get_vectorized_elements(const TensorCoord& coord,
                            bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get<X>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            bool_constant<oob_conditional_check>{});
    }

    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE void
    get_vectorized_elements_raw(remove_cvref_t<X>& dst,
                                const TensorCoord& coord,
                                bool_constant<oob_conditional_check> = {}) const
    {
        return buf_.template get_raw<X, oob_conditional_check>(
            dst,
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
    }

    template <typename X,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    async_get_vectorized_elements(remove_cvref_t<DataType>* smem, const TensorCoord& coord) const
    {
        return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
    }

    // X is vector of DataType.
    // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    set_vectorized_elements(const TensorCoord& coord,
                            const X& x,
                            bool_constant<oob_conditional_check> = {})
    {
        buf_.template set<X, oob_conditional_check>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                 typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                  bool>::type = false>
    CK_TILE_HOST_DEVICE constexpr void
    set_vectorized_elements_raw(const TensorCoord& coord,
                                const X& x,
                                bool_constant<oob_conditional_check> = {})
    {
        buf_.template set_raw<X, oob_conditional_check>(
            coord.get_offset(),
            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
            x);
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_view{");

        // buf_
        printf("buf_: ");
        print(buf_);
        printf(", ");

        // desc_
        printf("desc_: ");
        print(desc_);

        printf("}");
    }

    // member
    buffer_view buf_;
    TensorDesc desc_;
};

// placeholder type if we want to opt-out a tile view parameter
struct null_tensor_view
{
};

template <address_space_enum BufferAddressSpace = address_space_enum::generic,
          typename DataType,
          typename... Ts>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType* p,
                                                    const tensor_descriptor<Ts...>& desc)
{
    auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());

    return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}

template <address_space_enum BufferAddressSpace = address_space_enum::generic,
          typename DataType,
          typename... Lengths,
          typename... Strides,
          index_t GuaranteedLastDimensionVectorLength = -1,
          index_t GuaranteedLastDimensionVectorStride = -1,
          typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view(DataType* p,
                       const tuple<Lengths...>& lengths,
                       const tuple<Strides...>& strides,
                       number<GuaranteedLastDimensionVectorLength> = number<-1>{},
                       number<GuaranteedLastDimensionVectorStride> = number<-1>{})
{
    auto desc = make_naive_tensor_descriptor(lengths,
                                             strides,
                                             number<GuaranteedLastDimensionVectorLength>{},
                                             number<GuaranteedLastDimensionVectorStride>{});

    auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());

    return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}

template <address_space_enum BufferAddressSpace = address_space_enum::generic,
          typename DataType,
          typename... Lengths,
          index_t GuaranteedLastDimensionVectorLength = -1>
CK_TILE_HOST_DEVICE constexpr auto
make_naive_tensor_view_packed(DataType* p,
                              const tuple<Lengths...>& lengths,
                              number<GuaranteedLastDimensionVectorLength> = number<-1>{})
{
    auto desc =
        make_naive_tensor_descriptor_packed(lengths, number<GuaranteedLastDimensionVectorLength>{});

    auto buffer_view = make_buffer_view<BufferAddressSpace>(p, desc.get_element_space_size());

    return tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
}

template <typename OldTensorView,
          typename NewTransforms,
          typename NewLowerDimensionOldVisibleIdss,
          typename NewUpperDimensionNewVisibleIdss>
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView& old_tensor_view,
                                                         const NewTransforms& new_transforms,
                                                         NewLowerDimensionOldVisibleIdss,
                                                         NewUpperDimensionNewVisibleIdss)
{
    auto new_desc = transform_tensor_descriptor(old_tensor_view.desc_,
                                                new_transforms,
                                                NewLowerDimensionOldVisibleIdss{},
                                                NewUpperDimensionNewVisibleIdss{});

    return tensor_view<typename OldTensorView::buffer_view, remove_cvref_t<decltype(new_desc)>>{
        old_tensor_view.buf_, new_desc};
}

template <typename TensorView,
          typename TileLengths, // tuple<...>
          typename DoPads>      // sequence<bool, bool, ...>
CK_TILE_HOST_DEVICE constexpr auto
pad_tensor_view(const TensorView& tensor_view, const TileLengths& tile_lengths, DoPads)
{
    constexpr index_t num_dim = DoPads::size();

    static_assert(num_dim == TileLengths::size() && num_dim == TensorView::get_num_of_dimension(),
                  "wrong! inconsistent # of dimensions");

    // transforms
    const auto transforms = generate_tuple(
        [&](auto idim) {
            const auto old_length = tensor_view.get_tensor_descriptor().get_length(idim);

            const auto tile_length = tile_lengths[idim];

            const auto new_length = integer_divide_ceil(old_length, tile_length) * tile_length;

            const auto pad_length = new_length - old_length;

            constexpr bool DoPad = DoPads::at(idim);

            const auto transform =
                conditional_expr<DoPad>(make_right_pad_transform(old_length, pad_length),
                                        make_pass_through_transform(old_length));

            return transform;
        },
        number<num_dim>{});

    // lower dimension Id
    const auto lower_dimss =
        generate_tuple([&](auto idim) { return sequence<idim.value>{}; }, number<num_dim>{});

    // upper dimension Id
    const auto upper_dimss = lower_dimss;

    return transform_tensor_view(tensor_view, transforms, lower_dimss, upper_dimss);
}

} // namespace ck_tile
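Not part of the commit: a minimal host-side sketch of building a naive tensor_view over host memory and padding it up to tile boundaries with pad_tensor_view. It assumes "ck_tile/core.hpp" aggregates these headers; the buffer shape and the 8x8 tile size are illustrative only.

#include <cstdio>
#include <vector>
#include "ck_tile/core.hpp"

int main()
{
    using namespace ck_tile;

    std::vector<float> data(30 * 17, 0.f);

    // 30x17 row-major view over host memory (generic address space)
    auto view = make_naive_tensor_view(data.data(), make_tuple(30, 17), make_tuple(17, 1));

    // pad both dimensions up to multiples of an 8x8 tile -> logical 32x24 view;
    // out-of-range coordinates become invalid offsets instead of touching memory
    auto padded =
        pad_tensor_view(view, make_tuple(number<8>{}, number<8>{}), sequence<true, true>{});

    printf("padded lengths: %d x %d\n",
           static_cast<int>(padded.get_tensor_descriptor().get_length(number<0>{})),
           static_cast<int>(padded.get_tensor_descriptor().get_length(number<1>{})));

    return 0;
}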
include/ck_tile/core/tensor/tile_distribution.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
// distributed span
template
<
index_t
...
PartialHsLengths
>
struct
tile_distributed_span
{
using
Impl
=
sequence
<
PartialHsLengths
...
>
;
static
constexpr
auto
impl_
=
Impl
{};
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_static
()
{
return
true
;
}
};
// distributed index
template
<
index_t
...
PartialHsIndices
>
struct
tile_distributed_index
{
using
Impl
=
sequence
<
PartialHsIndices
...
>
;
static
constexpr
auto
impl_
=
Impl
{};
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_static
()
{
return
true
;
}
};
namespace
detail
{
template
<
index_t
...
Is
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_tile_distributed_span
(
sequence
<
Is
...
>
)
{
return
tile_distributed_span
<
Is
...
>
{};
}
template
<
index_t
...
Is
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_tile_distributed_index
(
sequence
<
Is
...
>
)
{
return
tile_distributed_index
<
Is
...
>
{};
}
}
// namespace detail
template
<
typename
PsYs2XsAdaptor_
,
typename
Ys2DDescriptor_
,
typename
StaticTileDistributionEncoding_
,
typename
TileDistributionDetail_
>
// FIXME: this is for hold ad-hoc but useful info,
// should be more elegnat
struct
tile_distribution
{
using
PsYs2XsAdaptor
=
remove_cvref_t
<
PsYs2XsAdaptor_
>
;
using
Ys2DDescriptor
=
remove_cvref_t
<
Ys2DDescriptor_
>
;
using
DstrEncode
=
remove_cvref_t
<
StaticTileDistributionEncoding_
>
;
using
DstrDetail
=
remove_cvref_t
<
TileDistributionDetail_
>
;
static_assert
(
PsYs2XsAdaptor
::
is_static
()
&&
Ys2DDescriptor
::
is_static
(),
"wrong! should be static"
);
static
constexpr
index_t
NDimX
=
PsYs2XsAdaptor
::
get_num_of_bottom_dimension
();
static
constexpr
index_t
NDimY
=
Ys2DDescriptor
::
get_num_of_top_dimension
();
static
constexpr
index_t
NDimP
=
PsYs2XsAdaptor
::
get_num_of_top_dimension
()
-
NDimY
;
static
constexpr
index_t
NDimR
=
StaticTileDistributionEncoding_
::
NDimR
;
PsYs2XsAdaptor
ps_ys_to_xs_
;
Ys2DDescriptor
ys_to_d_
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_x
()
{
return
NDimX
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_y
()
{
return
NDimY
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_p
()
{
return
NDimP
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_r
()
{
return
NDimR
;
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_lengths
()
{
#if 0
// FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
ps_ys_to_xs_.GetBottomDimensionLengths();
#else
return
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
index_t
x_length
=
container_reduce
(
typename
DstrEncode
::
HsLengthss
{}[
i
],
multiplies
{},
1
);
return
number
<
x_length
>
{};
},
number
<
NDimX
>
{});
#endif
}
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get_ps_ys_to_xs_adaptor
()
const
{
return
ps_ys_to_xs_
;
}
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get_ys_to_d_descriptor
()
const
{
return
ys_to_d_
;
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_static_tile_distribution_encoding
()
{
return
DstrEncode
{};
}
#if 1
// Calculate Replication index [R0, R1, ...] based on Partion index
// FIXME: very nasty implementation
template
<
typename
PartitionIndex
>
CK_TILE_HOST_DEVICE
auto
calculate_rs_index_from_ps_index
(
const
PartitionIndex
&
ps_idx
)
const
{
static_assert
(
PartitionIndex
::
size
()
==
NDimP
,
"wrong!"
);
const
auto
ps_ys_idx
=
container_concat
(
ps_idx
,
array
<
index_t
,
NDimY
>
{
0
});
const
auto
dummy_adaptor_coord
=
make_tensor_adaptor_coordinate
(
ps_ys_to_xs_
,
ps_ys_idx
);
array
<
index_t
,
NDimR
>
rs_idx
;
static_for
<
0
,
NDimP
,
1
>
{}([
&
](
auto
idim_p
)
{
constexpr
index_t
ndim_low
=
DstrEncode
::
ps_to_rhss_major_
[
idim_p
].
size
();
static_for
<
0
,
ndim_low
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
rh_major
=
DstrEncode
::
ps_to_rhss_major_
[
idim_p
][
i
];
constexpr
index_t
rh_minor
=
DstrEncode
::
ps_to_rhss_minor_
[
idim_p
][
i
];
// 0-th rh_major is the replicate dimension
if
constexpr
(
rh_major
==
0
)
{
constexpr
index_t
adaptor_hidden_id
=
DstrDetail
::
rh_major_minor_to_adaptor_hidden_idss_
[
rh_major
][
rh_minor
];
// fill in
rs_idx
(
rh_minor
)
=
dummy_adaptor_coord
.
get_hidden_index
()[
adaptor_hidden_id
];
}
});
});
return
rs_idx
;
}
#endif
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_distributed_spans
()
{
constexpr
auto
distributed_spans_impl
=
DstrEncode
::
detail
::
distributed_spans_lengthss_
;
constexpr
auto
ndims_spans_minor
=
DstrEncode
::
detail
::
ndims_distributed_spans_minor_
;
return
generate_tuple
(
[
&
](
auto
i
)
{
constexpr
auto
span_impl
=
distributed_spans_impl
[
i
];
constexpr
index_t
ndim_span_minor
=
ndims_spans_minor
[
i
];
constexpr
auto
span
=
TO_SEQUENCE
(
span_impl
,
ndim_span_minor
);
return
detail
::
make_tile_distributed_span
(
span
);
},
number
<
NDimX
>
{});
}
// FIXME: it's hacky to get Y index from Distributed-Index
template
<
typename
DistributedIndices
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_y_indices_from_distributed_indices
(
DistributedIndices
)
{
constexpr
auto
ys_idx_arr
=
[]
{
array
<
index_t
,
NDimY
>
ys_idx
;
static_for
<
0
,
NDimY
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
span_major
=
DstrEncode
::
detail
::
ys_to_span_major_
[
i
];
constexpr
index_t
span_minor
=
DstrEncode
::
detail
::
ys_to_span_minor_
[
i
];
constexpr
auto
dstr_index
=
DistributedIndices
{}[
number
<
span_major
>
{}];
ys_idx
(
i
)
=
dstr_index
.
impl_
[
span_minor
];
});
return
ys_idx
;
}();
constexpr
index_t
ndim_y
=
NDimY
;
return
TO_SEQUENCE
(
ys_idx_arr
,
ndim_y
);
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_static
()
{
return
PsYs2XsAdaptor
::
is_static
()
&&
Ys2DDescriptor
::
is_static
();
}
CK_TILE_HOST_DEVICE
void
print
()
const
{
printf
(
"tile_distribution{"
);
//
printf
(
"tile_distribution_encoding: "
);
print
(
DstrEncode
{});
printf
(
", "
);
//
printf
(
"ps_ys_to_xs_: "
);
print
(
ps_ys_to_xs_
);
printf
(
", "
);
//
printf
(
"ys_to_d_: "
);
print
(
ys_to_d_
);
//
printf
(
"}"
);
}
};
namespace
detail
{
template
<
index_t
NDimMax
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_sequential_index
(
index_t
ibegin
,
index_t
iend
)
{
array
<
index_t
,
NDimMax
>
arr
{
0
};
for
(
index_t
i
=
0
;
i
<
iend
-
ibegin
;
++
i
)
{
arr
(
i
)
=
ibegin
+
i
;
}
return
arr
;
}
// this returns a constexpr encoding of tile_distribution
template
<
typename
StaticTileDistributionEncoding_
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_adaptor_encoding_for_tile_distribution
(
StaticTileDistributionEncoding_
)
{
using
RsLengths
=
typename
StaticTileDistributionEncoding_
::
RsLengths
;
using
HsLengthss
=
typename
StaticTileDistributionEncoding_
::
HsLengthss
;
using
Ps2RHssMajor
=
typename
StaticTileDistributionEncoding_
::
Ps2RHssMajor
;
using
Ps2RHssMinor
=
typename
StaticTileDistributionEncoding_
::
Ps2RHssMinor
;
using
Ys2RHsMajor
=
typename
StaticTileDistributionEncoding_
::
Ys2RHsMajor
;
using
Ys2RHsMinor
=
typename
StaticTileDistributionEncoding_
::
Ys2RHsMinor
;
// FIXME: increase max value if fail
constexpr
index_t
kMaxNumTransforms
=
20
;
constexpr
index_t
kMaxMetaDataSize
=
128
;
constexpr
index_t
kMaxNumDim
=
10
;
using
Name
=
coord_transform_enum
;
using
MetaData
=
meta_data_buffer
<
kMaxMetaDataSize
>
;
using
NumDim
=
index_t
;
using
Dims
=
array
<
index_t
,
kMaxNumDim
>
;
using
Lengths
=
array
<
index_t
,
kMaxNumDim
>
;
// Tile Adaptor
// bottom dims [x0, x1, x2, ...]
// top dims [p0, p1, ..., y0, y1, ...]
constexpr
index_t
ndim_x
=
HsLengthss
::
size
();
// Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden]
array
<
array
<
index_t
,
kMaxNumDim
>
,
ndim_x
+
1
>
rh_major_minor_to_hidden_ids
;
array
<
array
<
index_t
,
kMaxNumDim
>
,
ndim_x
+
1
>
rh_major_minor_to_hidden_lengths
;
auto
trans
=
array
<
tuple
<
Name
,
MetaData
,
NumDim
,
Dims
,
NumDim
,
Dims
>
,
kMaxNumTransforms
>
{};
index_t
num_tran
=
0
;
index_t
hidden_dim_cnt
=
ndim_x
;
// this is replicate transform
{
constexpr
index_t
ndim_r_minor
=
RsLengths
::
size
();
constexpr
auto
r_minor_lengths
=
RsLengths
{};
trans
(
num_tran
++
)
=
{
coord_transform_enum
::
replicate
,
MetaData
{
to_array
<
index_t
,
ndim_r_minor
>
(
r_minor_lengths
)},
NumDim
{
0
},
Dims
{},
NumDim
{
ndim_r_minor
},
make_sequential_index
<
kMaxNumDim
>
(
hidden_dim_cnt
,
hidden_dim_cnt
+
ndim_r_minor
)};
for
(
index_t
i
=
0
;
i
<
ndim_r_minor
;
++
i
)
{
rh_major_minor_to_hidden_ids
(
0
)(
i
)
=
hidden_dim_cnt
;
rh_major_minor_to_hidden_lengths
(
0
)(
i
)
=
r_minor_lengths
[
i
];
hidden_dim_cnt
++
;
}
};
// these are Unmerge transforms for X dimesions
static_for
<
0
,
ndim_x
,
1
>
{}([
&
trans
,
&
num_tran
,
&
hidden_dim_cnt
,
&
rh_major_minor_to_hidden_ids
,
&
rh_major_minor_to_hidden_lengths
](
auto
idim_x
)
{
// typename HsLengthss::base{}.foo();
constexpr
auto
h_minor_lengths
=
HsLengthss
{}.
get
(
idim_x
);
// std::tuple_element_t<idim_x, HsLengthss>{};
// constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
constexpr
index_t
ndim_h_minor
=
h_minor_lengths
.
size
();
trans
(
num_tran
++
)
=
{
coord_transform_enum
::
unmerge
,
MetaData
{
to_array
<
index_t
,
ndim_h_minor
>
(
h_minor_lengths
)},
NumDim
{
1
},
Dims
{
idim_x
},
NumDim
{
ndim_h_minor
},
make_sequential_index
<
kMaxNumDim
>
(
hidden_dim_cnt
,
hidden_dim_cnt
+
ndim_h_minor
)};
for
(
index_t
i
=
0
;
i
<
ndim_h_minor
;
++
i
)
{
rh_major_minor_to_hidden_ids
(
idim_x
+
1
)(
i
)
=
hidden_dim_cnt
;
rh_major_minor_to_hidden_lengths
(
idim_x
+
1
)(
i
)
=
h_minor_lengths
[
i
];
hidden_dim_cnt
++
;
}
});
// transform: P dimensions
constexpr
index_t
ndim_p
=
Ps2RHssMajor
::
size
();
Dims
hidden_dim_id_ps
;
static_for
<
0
,
ndim_p
,
1
>
{}([
&
](
auto
iDimP
)
{
//
index_t
hidden_dim_id_p
=
hidden_dim_cnt
++
;
hidden_dim_id_ps
(
iDimP
)
=
hidden_dim_id_p
;
constexpr
auto
p2RHsMajor
=
Ps2RHssMajor
{}[
iDimP
];
constexpr
auto
p2RHsMinor
=
Ps2RHssMinor
{}[
iDimP
];
static_assert
(
p2RHsMajor
.
size
()
==
p2RHsMinor
.
size
(),
"wrong!"
);
constexpr
index_t
ndim_low
=
p2RHsMajor
.
size
();
Dims
low_dims
;
Lengths
low_lengths
;
for
(
index_t
i
=
0
;
i
<
ndim_low
;
++
i
)
{
index_t
rh_major
=
p2RHsMajor
[
i
];
index_t
rh_minor
=
p2RHsMinor
[
i
];
low_dims
(
i
)
=
rh_major_minor_to_hidden_ids
[
rh_major
][
rh_minor
];
low_lengths
(
i
)
=
rh_major_minor_to_hidden_lengths
[
rh_major
][
rh_minor
];
}
trans
(
num_tran
++
)
=
{
coord_transform_enum
::
merge
,
MetaData
{
to_array
<
index_t
,
ndim_low
>
(
low_lengths
)},
NumDim
{
ndim_low
},
low_dims
,
NumDim
{
1
},
Dims
{
hidden_dim_id_p
}};
});
constexpr
index_t
ndim_bottom
=
ndim_x
;
constexpr
auto
bottom_dim_ids
=
make_sequential_index
<
kMaxNumDim
>
(
0
,
ndim_bottom
);
constexpr
auto
ys_to_rhs_major
=
Ys2RHsMajor
{};
constexpr
auto
ys_to_rhs_minor
=
Ys2RHsMinor
{};
constexpr
index_t
ndim_y
=
Ys2RHsMajor
::
size
();
constexpr
index_t
ndim_top
=
ndim_p
+
ndim_y
;
auto
top_dim_ids
=
hidden_dim_id_ps
;
{
for
(
index_t
i
=
0
;
i
<
ndim_y
;
++
i
)
{
index_t
rh_major
=
ys_to_rhs_major
[
i
];
index_t
rh_minor
=
ys_to_rhs_minor
[
i
];
top_dim_ids
(
ndim_p
+
i
)
=
rh_major_minor_to_hidden_ids
[
rh_major
][
rh_minor
];
}
}
//
const
auto
ps_ys_to_xs_adaptor_encoding
=
make_tuple
(
trans
,
num_tran
,
bottom_dim_ids
,
ndim_bottom
,
top_dim_ids
,
ndim_top
);
// descriptor: [y0, y1, ...] to [d]
Lengths
y_lengths
;
index_t
d_length
=
1
;
for
(
index_t
i
=
0
;
i
<
ndim_y
;
++
i
)
{
index_t
rh_major
=
ys_to_rhs_major
[
i
];
index_t
rh_minor
=
ys_to_rhs_minor
[
i
];
index_t
y_length
=
rh_major_minor_to_hidden_lengths
[
rh_major
][
rh_minor
];
y_lengths
(
i
)
=
y_length
;
d_length
*=
y_length
;
}
auto
tran
=
make_tuple
(
coord_transform_enum
::
unmerge
,
MetaData
{
to_array
<
index_t
,
ndim_y
>
(
y_lengths
)},
NumDim
{
1
},
Dims
{
0
},
NumDim
{
ndim_y
},
make_sequential_index
<
kMaxNumDim
>
(
1
,
ndim_y
+
1
));
const
auto
ys_to_d_adaptor_encoding
=
make_tuple
(
make_tuple
(
tran
),
1
,
Dims
{
0
},
1
,
make_sequential_index
<
kMaxNumDim
>
(
1
,
ndim_y
+
1
),
ndim_y
);
return
make_tuple
(
ps_ys_to_xs_adaptor_encoding
,
ys_to_d_adaptor_encoding
,
d_length
,
rh_major_minor_to_hidden_ids
);
}
// FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
template
<
typename
RhMajorMinor2AdaptorHiddenIdss
>
// tuple<sequence<...>, ...>
struct
tile_distribution_detail
{
static
constexpr
auto
rh_major_minor_to_adaptor_hidden_idss_
=
to_array_of_array
(
RhMajorMinor2AdaptorHiddenIdss
{});
};
}
// namespace detail
// this returns a constexpr tile_distribution
template
<
typename
StaticTileDistributionEncoding_
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_tile_distribution
(
StaticTileDistributionEncoding_
)
{
using
DstrEncode
=
remove_cvref_t
<
StaticTileDistributionEncoding_
>
;
constexpr
auto
adaptor_impl
=
detail
::
make_adaptor_encoding_for_tile_distribution
(
StaticTileDistributionEncoding_
{});
constexpr
auto
ps_ys_to_xs_adaptor_impl
=
adaptor_impl
.
template
at
<
0
>();
constexpr
auto
ys_to_d_adaptor_impl
=
adaptor_impl
.
template
at
<
1
>();
constexpr
index_t
d_length
=
adaptor_impl
.
template
at
<
2
>();
constexpr
auto
rh_major_minor_to_hidden_ids_impl
=
adaptor_impl
.
template
at
<
3
>();
constexpr
auto
ps_ys_to_xs_adaptor
=
CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING
(
ps_ys_to_xs_adaptor_impl
);
constexpr
auto
ys_to_d_adaptor
=
CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING
(
ys_to_d_adaptor_impl
);
constexpr
auto
ys_to_d_descriptor
=
make_tensor_descriptor_from_adaptor
(
ys_to_d_adaptor
,
d_length
);
//
constexpr
index_t
ndim_rh_major
=
DstrEncode
::
detail
::
ndim_rh_major_
;
constexpr
auto
ndims_rhs_minor
=
DstrEncode
::
detail
::
ndims_rhs_minor_
;
constexpr
auto
rh_major_minor_to_hidden_ids
=
TO_TUPLE_OF_SEQUENCE
(
rh_major_minor_to_hidden_ids_impl
,
ndim_rh_major
,
ndims_rhs_minor
);
return
tile_distribution
<
remove_cvref_t
<
decltype
(
ps_ys_to_xs_adaptor
)
>
,
remove_cvref_t
<
decltype
(
ys_to_d_descriptor
)
>
,
remove_cvref_t
<
DstrEncode
>
,
detail
::
tile_distribution_detail
<
remove_cvref_t
<
decltype
(
rh_major_minor_to_hidden_ids
)
>>>
{
ps_ys_to_xs_adaptor
,
ys_to_d_descriptor
};
}
// this returns a static tile_distribution
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
{
    using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;

    constexpr auto adaptor_impl =
        detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});

    constexpr auto ps_ys_to_xs_adaptor_impl          = adaptor_impl.template at<0>();
    constexpr auto ys_to_d_adaptor_impl              = adaptor_impl.template at<1>();
    constexpr index_t d_length                       = adaptor_impl.template at<2>();
    constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();

    constexpr auto ps_ys_to_xs_adaptor =
        CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
    constexpr auto ys_to_d_adaptor =
        CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);

    constexpr auto ys_to_d_descriptor =
        make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, number<d_length>{});

    //
    constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
    constexpr auto ndims_rhs_minor  = DstrEncode::detail::ndims_rhs_minor_;

    constexpr auto rh_major_minor_to_hidden_ids =
        TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);

    return tile_distribution<
        remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
        remove_cvref_t<decltype(ys_to_d_descriptor)>,
        remove_cvref_t<DstrEncode>,
        detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
        ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
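
A minimal usage sketch of the two factory functions above. The encoding values below are made-up example numbers (a single-wavefront tile: one P dimension covering 8x8 lanes, Y dimensions describing a 4x2 per-thread tile), not values taken from this commit:

// illustrative only: a 32x16 warp tile, 64 lanes, 4x2 elements per lane
using ExampleEncoding = tile_distribution_encoding<
    sequence<1>,                           // R: no replication
    tuple<sequence<4, 8>, sequence<2, 8>>, // H lengths per X dim
    tuple<sequence<1, 2>>,                 // P -> rh_major
    tuple<sequence<1, 1>>,                 // P -> rh_minor (owns H0[1]=8, H1[1]=8 -> 64 lanes)
    sequence<1, 2>,                        // Y -> rh_major
    sequence<0, 0>>;                       // Y -> rh_minor (owns H0[0]=4, H1[0]=2)

constexpr auto dstr = make_static_tile_distribution(ExampleEncoding{});
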
//***********************************************************************************
namespace detail {

template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
    // only support warp-tile and block-tile
    static_assert(Distribution::NDimP == 1 or Distribution::NDimP == 2, "wrong!");

    if constexpr(Distribution::NDimP == 1)
    {
        return array<index_t, 1>{get_lane_id()};
    }
    else if constexpr(Distribution::NDimP == 2)
    {
        return array<index_t, 2>{get_warp_id(), get_lane_id()};
    }
}
template <typename, typename, typename, index_t>
struct reverse_slice_sequence_impl;

template <index_t x,
          index_t... xs,
          index_t m,
          index_t... ms,
          index_t id,
          index_t... ids,
          index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
                                   sequence<m, ms...>,
                                   sequence<id, ids...>,
                                   SliceSize>
{
    using old_scan =
        reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;

    static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;
    static constexpr auto slice_length =
        std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;

    using dim_lengths =
        typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;
    using dim_slices =
        typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;
    using remaining_slice_sizes = typename sequence_merge<
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>,
        typename old_scan::remaining_slice_sizes>::type;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.front().value == 1;
    static constexpr index_t _split_flag =
        std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t _split_idx =
        std::conditional_t<_split_flag, number<id>, number<0>>::value;

    static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
    static constexpr index_t split_idx  = std::
        conditional_t<old_scan::split_flag, number<old_scan::split_idx>, number<_split_idx>>::value;
};
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
    static constexpr auto slice_size = SliceSize;
    static constexpr auto slice_length =
        std::conditional_t<m, number<gcd(x, slice_size)>, number<x>>::value;

    using dim_lengths = sequence<slice_length>;
    using dim_slices  = sequence<x / slice_length>;
    using remaining_slice_sizes =
        std::conditional_t<m, sequence<slice_size / slice_length>, sequence<slice_size>>;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        slice_length != x && remaining_slice_sizes{}.front().value == 1;
    static constexpr index_t split_flag = std::conditional_t<m, number<_flag>, number<0>>::value;
    static constexpr index_t split_idx =
        std::conditional_t<split_flag, number<id>, number<0>>::value;
};
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// clang-format on
template <typename Seq,
          index_t SliceSize,
          typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto reverse_slice_sequence(Seq,
                                      number<SliceSize>,
                                      Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
    static_assert(Seq::size() == Mask::size());
    using sliced_type =
        reverse_slice_sequence_impl<Seq,
                                    Mask,
                                    typename arithmetic_sequence_gen<0, Seq::size(), 1>::type,
                                    SliceSize>;
    static_assert(sliced_type::remaining_slice_sizes::front().value == 1,
                  "can not evenly divide this sequence, please check");
    return make_tuple(typename sliced_type::dim_lengths{},
                      typename sliced_type::dim_slices{},
                      number<sliced_type::split_idx>{});
}
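
For instance, the <4, 2, 8> / 8 row of the table above corresponds to:

constexpr auto r = reverse_slice_sequence(sequence<4, 2, 8>{}, number<8>{});
// r == make_tuple(sequence<1, 1, 8>{},  // slice_lengths: lengths of each slice
//                 sequence<4, 2, 1>{},  // slice_nums: number of slices per dim
//                 number<1>{});         // slice_idx (see table above)
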
//
// slice tensor from x_dim, result in split in y_dim, not p_dim.
// We don't support slice cross p_dim (aka, slice different threads)
// also, sliced along y_dim need be the first dim of current dim.
// Multiply Y dim before sliced dim does not make sense
//
// e.g
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
// |--> slice along this Y dim, is the first dim of X1, totally 4 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
// |--> slice along this Y dim, the P dim is 1 in the left, so is OK
// totally 16 slices
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
// |--> slice along this P dim, will split threads, not supported
//
// X0 X1
// <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
// Y P P Y P Y P Y
// => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
// |--> slice along this Y dim, but this Y sim need to split into 2
// subdime
// the P dim in the left is 1, means actually not crossing P
//
template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
    Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
{
    // NOTE: this function need to be called under constexpr context,
    // due to https://wg21.link/p2280r0 we have to use non-reference type for distribution
    using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());

    static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));

    constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;

    constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
    constexpr auto src_y_info       = Encoding::detail::get_sorted_y_info();
    constexpr auto src_y_dims       = src_y_info[number<0>{}];
    constexpr auto src_y_maps       = src_y_info[number<1>{}];
    constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];

    constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr {
        auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
        auto y_slice_lengths        = Encoding::detail::ys_lengths_;

        // This lambda will modify some value outside, so c++ will not treat return value as
        // constexpr
        // TODO: ugly
        auto new_h_lengths = transform_tuples(
            [&](auto h_len, auto id) {
                constexpr auto sliced_h =
                    reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});

                constexpr auto sliced_h_lens  = sliced_h[number<0>{}];
                constexpr auto sliced_h_index = sliced_h[number<2>{}];

                // update y_slice_lengths
                constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
                constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);

                static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
                              "not sliced at y dim, please check");

                static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
                    y_slice_lengths(src_y_maps[found_y_index - i]) =
                        sliced_h_lens[sliced_h_index - i];
                });
                // TODO: add validations not across p dim

                // NOTE: this y_origin is for all dims, not only current dim
                // will later use pick to select target dim
                constexpr auto y_origin = [&]() {
                    constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
                    auto h_origin_         = make_zero_multi_index<h_trans.NDimLow>();
                    h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});

                    auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
                    static_for<0, sliced_h_index + 1, 1>{}([&](auto i) {
                        y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
                    });
                    return y_origin_;
                }();

                constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
                                                                          src_y_prefix_sum[id + 1],
                                                                          1>::type{};

                set_container_subset(
                    y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));

                return sliced_h_lens;
            },
            typename Encoding::HsLengthss{},
            typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});

        auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);

        return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
    }();

    constexpr auto sliced_h_lengths       = sliced_hlen_yidx_ylen[number<0>{}];
    constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
    constexpr auto sliced_y_origins_size  = sliced_y_origins_array.size();
    constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
    constexpr auto sliced_y_lengths_size  = sliced_y_lengths_array.size();

    constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
    constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);

    return make_tuple(
        make_static_tile_distribution(
            tile_distribution_encoding<typename Encoding::RsLengths,
                                       decltype(sliced_h_lengths), // only need to change the
                                                                   // h_lengths type
                                       typename Encoding::Ps2RHssMajor,
                                       typename Encoding::Ps2RHssMinor,
                                       typename Encoding::Ys2RHsMajor,
                                       typename Encoding::Ys2RHsMinor>{}),
        sliced_y_origins,
        sliced_y_lengths);
}
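
A sketch of how this helper is typically invoked. SomeDistribution stands for any static tile_distribution type; the numbers mirror the first example in the comment block above (slice everything along X0, the first 32 columns along X1):

constexpr auto sliced = detail::slice_distribution_from_x(
    SomeDistribution{}, sequence<0, 0>{}, sequence<0, 32>{});

constexpr auto sliced_dstr      = sliced[number<0>{}]; // new static tile_distribution
constexpr auto sliced_y_origins = sliced[number<1>{}]; // where the slice starts in Y space
constexpr auto sliced_y_lengths = sliced[number<2>{}]; // Y lengths covered by one slice
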
} // namespace detail
} // namespace ck_tile
include/ck_tile/core/tensor/tile_distribution_encoding.hpp
0 → 100644
View file @
5a9c4962
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {

template <typename RsLengths_,    // sequence<...>
          typename HsLengthss_,   // tuple<sequence<...>, ...>
          typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
          typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
          typename Ys2RHsMajor_,  // sequence<...>
          typename Ys2RHsMinor_>  // sequence<...>
struct tile_distribution_encoding
{
    using RsLengths    = remove_cvref_t<RsLengths_>;
    using HsLengthss   = remove_cvref_t<HsLengthss_>;
    using Ps2RHssMajor = remove_cvref_t<Ps2RHssMajor_>;
    using Ps2RHssMinor = remove_cvref_t<Ps2RHssMinor_>;
    using Ys2RHsMajor  = remove_cvref_t<Ys2RHsMajor_>;
    using Ys2RHsMinor  = remove_cvref_t<Ys2RHsMinor_>;

    static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
    static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");

    static constexpr index_t NDimX = HsLengthss::size();
    static constexpr index_t NDimP = Ps2RHssMajor::size();
    static constexpr index_t NDimY = Ys2RHsMajor::size();
    static constexpr index_t NDimR = RsLengths::size();

    // FIXME: move into detail
    static constexpr auto rs_lengths_       = RsLengths{};
    static constexpr auto hs_lengthss_      = HsLengthss{};
    static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
    static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
    static constexpr auto ys_to_rhs_major_  = Ys2RHsMajor{};
    static constexpr auto ys_to_rhs_minor_  = Ys2RHsMinor{};
// redundant but useful info
// TODO: really bad code, should be over-hauled
struct
detail
{
// ndim_rh_major_, ndim_span_mainor_
static
constexpr
index_t
ndim_rh_major_
=
NDimX
+
1
;
static
constexpr
index_t
ndim_span_major_
=
NDimX
;
// ndims_rhs_minor_[ndim_rh_major_]
static
constexpr
auto
ndims_rhs_minor_
=
generate_array
(
[](
auto
i
)
{
if
constexpr
(
i
.
value
==
0
)
{
return
rs_lengths_
.
size
();
}
else
{
return
hs_lengthss_
[
i
-
number
<
1
>
{}].
size
();
}
},
number
<
ndim_rh_major_
>
{});
// max_ndim_rh_minor_
static
constexpr
index_t
max_ndim_rh_minor_
=
container_reduce
(
ndims_rhs_minor_
,
maximize
<
index_t
>
{},
0
);
// rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
static
constexpr
auto
rhs_lengthss_
=
to_array_of_array
(
container_concat
(
make_tuple
(
rs_lengths_
),
hs_lengthss_
));
// ys_lengths_
static
constexpr
auto
ys_lengths_
=
[]
{
array
<
index_t
,
NDimY
>
ys_lengths_tmp
{
-
1
};
for
(
index_t
i
=
0
;
i
<
NDimY
;
i
++
)
{
index_t
rh_major
=
ys_to_rhs_major_
[
i
];
index_t
rh_minor
=
ys_to_rhs_minor_
[
i
];
ys_lengths_tmp
(
i
)
=
rhs_lengthss_
[
rh_major
][
rh_minor
];
}
return
ys_lengths_tmp
;
}();
// rhs_major_minor_to_ys_[ndim_rh_majpr_][max_ndim_rh_minor_]
static
constexpr
auto
rhs_major_minor_to_ys_
=
[]
{
array
<
array
<
index_t
,
max_ndim_rh_minor_
>
,
NDimX
+
1
>
rhs_major_minor_to_ys_tmp
{{
-
1
}};
static_for
<
0
,
NDimY
,
1
>
{}([
&
](
auto
i
)
{
constexpr
index_t
rh_major
=
ys_to_rhs_major_
[
i
];
constexpr
index_t
rh_minor
=
ys_to_rhs_minor_
[
i
];
rhs_major_minor_to_ys_tmp
(
rh_major
)(
rh_minor
)
=
i
;
});
return
rhs_major_minor_to_ys_tmp
;
}();
// ndims_span_minor_[NDimY]
static
constexpr
auto
ndims_span_minor_
=
[]
{
array
<
index_t
,
NDimX
>
ndims_span_minor
{
0
};
for
(
index_t
i
=
0
;
i
<
NDimY
;
i
++
)
{
const
index_t
span_major
=
ys_to_rhs_major_
[
i
]
-
1
;
ndims_span_minor
(
span_major
)
++
;
}
return
ndims_span_minor
;
}();
// max_ndim_span_minor_
static
constexpr
index_t
max_ndim_span_minor_
=
container_reduce
(
ndims_span_minor_
,
maximize
<
index_t
>
{},
0
);
// rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
static
constexpr
auto
rhs_major_minor_to_span_minor_
=
[]
{
array
<
array
<
index_t
,
max_ndim_rh_minor_
>
,
ndim_rh_major_
>
rhs_major_minor_to_span_minor
{
{
-
1
}};
static_for
<
0
,
ndim_rh_major_
,
1
>
{}([
&
](
auto
rh_major
)
{
constexpr
index_t
ndim_rh_minor
=
ndims_rhs_minor_
[
rh_major
];
index_t
cnt_ndim_span_minor
=
0
;
static_for
<
0
,
ndim_rh_minor
,
1
>
{}([
&
](
auto
rh_minor
)
{
constexpr
index_t
idim_y
=
rhs_major_minor_to_ys_
[
rh_major
][
rh_minor
];
if
(
idim_y
>=
0
)
{
rhs_major_minor_to_span_minor
(
rh_major
)(
rh_minor
)
=
cnt_ndim_span_minor
;
cnt_ndim_span_minor
++
;
}
});
});
return
rhs_major_minor_to_span_minor
;
}();
// ys_to_span_major_[NDimY]
static
constexpr
auto
ys_to_span_major_
=
generate_array
([](
auto
i
)
{
return
ys_to_rhs_major_
[
i
]
-
1
;
},
number
<
NDimY
>
{});
// ys_to_span_minor_[NDimY]
static
constexpr
auto
ys_to_span_minor_
=
generate_array
(
[](
auto
i
)
{
return
rhs_major_minor_to_span_minor_
[
ys_to_rhs_major_
[
i
]][
ys_to_rhs_minor_
[
i
]];
},
number
<
NDimY
>
{});
// distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
static
constexpr
auto
distributed_spans_lengthss_
=
[]
{
array
<
array
<
index_t
,
max_ndim_span_minor_
>
,
ndim_span_major_
>
distributed_spans_lengthss
{{
-
1
}};
static_for
<
0
,
NDimY
,
1
>
{}([
&
](
auto
i
)
{
const
index_t
rh_major
=
ys_to_rhs_major_
[
i
];
const
index_t
rh_minor
=
ys_to_rhs_minor_
[
i
];
const
index_t
h_length
=
hs_lengthss_
[
number
<
rh_major
-
1
>
{}][
rh_minor
];
const
index_t
span_major
=
rh_major
-
1
;
const
index_t
span_minor
=
rhs_major_minor_to_span_minor_
[
rh_major
][
rh_minor
];
distributed_spans_lengthss
(
span_major
)(
span_minor
)
=
h_length
;
});
return
distributed_spans_lengthss
;
}();
// ndims_distributed_spans_minor_[ndim_span_major_]
static
constexpr
auto
ndims_distributed_spans_minor_
=
[]
{
array
<
index_t
,
ndim_span_major_
>
ndims_distributed_spans_minor
{
0
};
static_for
<
0
,
NDimY
,
1
>
{}([
&
](
auto
i
)
{
const
index_t
span_major
=
ys_to_rhs_major_
[
i
]
-
1
;
ndims_distributed_spans_minor
(
span_major
)
++
;
});
return
ndims_distributed_spans_minor
;
}();
// does_p_own_r_[NDimP][NDimR]
static
constexpr
auto
does_p_own_r_
=
[]
{
if
constexpr
(
NDimR
>
0
)
{
array
<
array
<
bool
,
NDimR
>
,
NDimP
>
does_p_own_r
{{
false
}};
static_for
<
0
,
NDimP
,
1
>
{}([
&
](
auto
idim_p
)
{
constexpr
index_t
ndim_low
=
ps_to_rhss_major_
[
idim_p
].
size
();
static_for
<
0
,
ndim_low
,
1
>
{}([
&
](
auto
idim_low
)
{
constexpr
index_t
rh_major
=
ps_to_rhss_major_
[
idim_p
][
idim_low
];
constexpr
index_t
rh_minor
=
ps_to_rhss_minor_
[
idim_p
][
idim_low
];
if
constexpr
(
rh_major
==
0
)
{
does_p_own_r
(
idim_p
)(
rh_minor
)
=
true
;
}
});
});
return
does_p_own_r
;
}
else
{
return
array
<
array
<
bool
,
NDimR
>
,
NDimP
>
{};
}
}();
// ps_over_rs_derivative_[NDimP][NDimR]
static
constexpr
auto
ps_over_rs_derivative_
=
[]
{
if
constexpr
(
NDimR
>
0
)
{
array
<
array
<
index_t
,
NDimR
>
,
NDimP
>
ps_over_rs_derivative
{{
0
}};
static_for
<
0
,
NDimP
,
1
>
{}([
&
](
auto
idim_p
)
{
constexpr
index_t
ndim_low
=
ps_to_rhss_major_
[
idim_p
].
size
();
index_t
p_over_rh_derivative
=
1
;
static_for
<
ndim_low
-
1
,
-
1
,
-
1
>
{}([
&
](
auto
idim_low
)
{
constexpr
index_t
rh_major
=
ps_to_rhss_major_
[
idim_p
][
idim_low
];
constexpr
index_t
rh_minor
=
ps_to_rhss_minor_
[
idim_p
][
idim_low
];
constexpr
index_t
rh_length
=
rhs_lengthss_
[
rh_major
][
rh_minor
];
if
constexpr
(
rh_major
==
0
)
{
ps_over_rs_derivative
(
idim_p
)(
rh_minor
)
=
p_over_rh_derivative
;
}
p_over_rh_derivative
*=
rh_length
;
});
});
return
ps_over_rs_derivative
;
}
else
{
return
array
<
array
<
index_t
,
NDimR
>
,
NDimP
>
{};
}
}();
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_h_dim_lengths_prefix_sum
()
{
// <len_d0, len_d1, ...>
// e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
constexpr
auto
uniformed_h_dim_lengths
=
generate_sequence_v2
(
[
&
](
auto
i
)
{
constexpr
index_t
size
=
HsLengthss
{}[
i
].
size
();
return
number
<
size
>
{};
},
number
<
NDimX
>
{});
// <0, len_d0, len_d0+len_d1, ...>
// e.g. seq<3, 5> --> seq<0, 3, 8>
constexpr
auto
h_dim_prefix_sum
=
prefix_sum_sequence
(
uniformed_h_dim_lengths
);
return
h_dim_prefix_sum
;
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_uniformed_idx_y_to_h
()
{
constexpr
auto
all_ys_2_rhss
=
transform_sequences
(
[](
auto
major
,
auto
minor
)
constexpr
{
// <0, 0, len_d0, len_d0+len_d1, ...>
constexpr
auto
x_dim_prefix_sum
=
merge_sequences
(
sequence
<
0
>
{}
/*for R dims*/
,
get_h_dim_lengths_prefix_sum
());
return
x_dim_prefix_sum
.
at
(
major
)
+
minor
;
},
Ys2RHsMajor
{},
Ys2RHsMinor
{});
return
all_ys_2_rhss
;
}
// return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
template
<
typename
IdxSeq
,
typename
PrefixSumSeq
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_sorted_info
(
IdxSeq
,
PrefixSumSeq
)
{
using
sorted_idx
=
sequence_unique_sort
<
IdxSeq
,
less
<
index_t
>
,
equal
<
index_t
>>
;
constexpr
auto
sorted_dims
=
typename
sorted_idx
::
type
{};
constexpr
auto
sorted_maps
=
typename
sorted_idx
::
sorted2unsorted_map
{};
constexpr
auto
sorted_histogram
=
histogram_sorted_sequence
(
sorted_dims
,
PrefixSumSeq
{});
constexpr
auto
sorted_prefix_sum
=
prefix_sum_sequence
(
sorted_histogram
);
return
make_tuple
(
sorted_dims
,
sorted_maps
,
sorted_prefix_sum
);
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_sorted_y_info
()
{
return
get_sorted_info
(
get_uniformed_idx_y_to_h
(),
get_h_dim_lengths_prefix_sum
());
}
CK_TILE_HOST_DEVICE
void
print
()
const
{
printf
(
"tile_distribution_encoding::detail{"
);
//
printf
(
"ndim_rh_major_: "
);
print
(
ndim_rh_major_
);
printf
(
", "
);
//
printf
(
"ndim_span_major_: "
);
print
(
ndim_span_major_
);
printf
(
", "
);
//
printf
(
"ndims_rhs_minor_: "
);
print
(
ndims_rhs_minor_
);
printf
(
", "
);
//
printf
(
"ndim_rh_major_: "
);
print
(
ndim_rh_major_
);
printf
(
", "
);
//
printf
(
"max_ndim_rh_minor_: "
);
print
(
max_ndim_rh_minor_
);
printf
(
", "
);
//
printf
(
"rhs_lengthss_: "
);
print
(
rhs_lengthss_
);
printf
(
", "
);
//
printf
(
"ys_lengths_: "
);
print
(
ys_lengths_
);
printf
(
", "
);
//
printf
(
"rhs_major_minor_to_ys_: "
);
print
(
rhs_major_minor_to_ys_
);
printf
(
", "
);
//
printf
(
"ndims_span_minor_: "
);
print
(
ndims_span_minor_
);
printf
(
", "
);
//
printf
(
"max_ndim_span_minor_: "
);
print
(
max_ndim_span_minor_
);
printf
(
", "
);
//
printf
(
"ys_to_span_major_: "
);
print
(
ys_to_span_major_
);
printf
(
", "
);
//
printf
(
"ys_to_span_minor_: "
);
print
(
ys_to_span_minor_
);
printf
(
", "
);
//
printf
(
"distributed_spans_lengthss_: "
);
print
(
distributed_spans_lengthss_
);
printf
(
", "
);
//
printf
(
"ndims_distributed_spans_minor_: "
);
print
(
ndims_distributed_spans_minor_
);
printf
(
", "
);
//
printf
(
"ps_over_rs_derivative_: "
);
print
(
ps_over_rs_derivative_
);
//
printf
(
"}"
);
}
};
CK_TILE_HOST_DEVICE
void
print
()
const
{
printf
(
"tile_distribution_encoding{"
);
//
printf
(
"NDimX: %d, NDimP: %d, NDimY: %d, "
,
NDimX
,
NDimP
,
NDimY
);
//
printf
(
"rs_lengths_: "
);
print
(
rs_lengths_
);
printf
(
", "
);
//
printf
(
"hs_lengthss_: "
);
print
(
hs_lengthss_
);
printf
(
", "
);
//
printf
(
"ps_to_rhss_major_: "
);
print
(
ps_to_rhss_major_
);
printf
(
", "
);
//
printf
(
"ps_to_rhss_minor_: "
);
print
(
ps_to_rhss_minor_
);
printf
(
", "
);
//
printf
(
"ys_to_rhs_major_: "
);
print
(
ys_to_rhs_major_
);
printf
(
", "
);
//
printf
(
"ys_to_rhs_minor_: "
);
print
(
ys_to_rhs_minor_
);
printf
(
", "
);
//
printf
(
"detail: "
);
print
(
detail
{});
//
printf
(
"}"
);
}
};
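
To make the parameter roles concrete, here is a small hand-written encoding (values are illustrative, not taken from this commit). rh_major 0 refers to the R dimension and rh_major i >= 1 to the H sequence of X dim i-1:

// illustrative only: 4 warps x 32 lanes, 2x4 elements per lane over a 16x64 tile
using BlockEnc = tile_distribution_encoding<
    sequence<1>,                                 // R
    tuple<sequence<2, 2, 4>, sequence<4, 2, 8>>, // H for X0, X1
    tuple<sequence<1, 2>, sequence<1, 2>>,       // P0, P1 -> rh_major
    tuple<sequence<1, 1>, sequence<2, 2>>,       // P0, P1 -> rh_minor
    sequence<1, 2>,                              // Y -> rh_major
    sequence<0, 0>>;                             // Y -> rh_minor

static_assert(BlockEnc::NDimX == 2 && BlockEnc::NDimP == 2 && BlockEnc::NDimY == 2);
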
namespace
detail
{
template
<
typename
OuterDstr
,
typename
InnerDstr
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_embed_tile_distribution_encoding
(
OuterDstr
,
InnerDstr
)
{
static_assert
(
OuterDstr
::
NDimX
==
InnerDstr
::
NDimX
,
"wrong!"
);
constexpr
index_t
NDimHMajor
=
OuterDstr
::
NDimX
;
using
RsLengths
=
sequence_merge_t
<
typename
OuterDstr
::
RsLengths
,
typename
InnerDstr
::
RsLengths
>
;
constexpr
auto
hs_lengthss
=
generate_tuple
(
[
&
](
auto
i
)
{
return
merge_sequences
(
typename
OuterDstr
::
HsLengthss
{}[
i
],
typename
InnerDstr
::
HsLengthss
{}[
i
]);
},
number
<
NDimHMajor
>
{});
//
constexpr
auto
rhs_major_2_ndim_outer_rhs_minor
=
[
&
]()
{
array
<
index_t
,
NDimHMajor
+
1
>
rhs_major_2_ndim_outer_rhs_minor_
;
// R dimension
rhs_major_2_ndim_outer_rhs_minor_
(
0
)
=
OuterDstr
::
RsLengths
::
size
();
// Hs dimensions
static_for
<
0
,
NDimHMajor
,
1
>
{}([
&
](
auto
i
)
{
rhs_major_2_ndim_outer_rhs_minor_
(
i
+
1
)
=
typename
OuterDstr
::
HsLengthss
{}[
i
].
size
();
});
return
rhs_major_2_ndim_outer_rhs_minor_
;
}();
// Ps2RHssMinor
constexpr
auto
updated_inner_ps_2_rhss_minor
=
generate_tuple
(
[
&
](
auto
p
)
{
constexpr
auto
inner_p_2_rhss_major
=
typename
InnerDstr
::
Ps2RHssMajor
{}[
p
];
constexpr
auto
inner_p_2_rhss_minor
=
typename
InnerDstr
::
Ps2RHssMinor
{}[
p
];
constexpr
index_t
ndim_tmp
=
inner_p_2_rhss_minor
.
size
();
constexpr
auto
updated_inner_p_2_rhss_minor
=
[
&
]()
{
array
<
index_t
,
ndim_tmp
>
updated_inner_p_2_rhss_minor_
;
for
(
index_t
i
=
0
;
i
<
ndim_tmp
;
i
++
)
{
index_t
rh_major
=
inner_p_2_rhss_major
[
i
];
index_t
ndim_outer_h_minor
=
rhs_major_2_ndim_outer_rhs_minor
[
rh_major
];
updated_inner_p_2_rhss_minor_
(
i
)
=
inner_p_2_rhss_minor
[
i
]
+
ndim_outer_h_minor
;
}
return
updated_inner_p_2_rhss_minor_
;
}();
return
TO_SEQUENCE
(
updated_inner_p_2_rhss_minor
,
ndim_tmp
);
},
number
<
InnerDstr
::
NDimP
>
{});
// Ys2RHsMinor
constexpr
auto
updated_inner_ys_2_rhs_minor
=
[
&
]()
{
constexpr
auto
inner_ys_2_rhs_major
=
typename
InnerDstr
::
Ys2RHsMajor
{};
constexpr
auto
inner_ys_2_rhs_minor
=
typename
InnerDstr
::
Ys2RHsMinor
{};
constexpr
index_t
ndim_tmp
=
inner_ys_2_rhs_minor
.
size
();
constexpr
auto
updated_inner_ys_2_rhs_minor_
=
[
&
]()
{
array
<
index_t
,
ndim_tmp
>
updated_inner_ys_2_rhs_minor__
;
for
(
index_t
i
=
0
;
i
<
ndim_tmp
;
i
++
)
{
index_t
rh_major
=
inner_ys_2_rhs_major
[
i
];
index_t
ndim_outer_h_minor
=
rhs_major_2_ndim_outer_rhs_minor
[
rh_major
];
updated_inner_ys_2_rhs_minor__
(
i
)
=
inner_ys_2_rhs_minor
[
i
]
+
ndim_outer_h_minor
;
}
return
updated_inner_ys_2_rhs_minor__
;
}();
return
TO_SEQUENCE
(
updated_inner_ys_2_rhs_minor_
,
ndim_tmp
);
}();
//
constexpr
auto
ps_2_rhss_major
=
container_concat
(
typename
OuterDstr
::
Ps2RHssMajor
{},
typename
InnerDstr
::
Ps2RHssMajor
{});
constexpr
auto
ps_2_rhss_minor
=
container_concat
(
typename
OuterDstr
::
Ps2RHssMinor
{},
updated_inner_ps_2_rhss_minor
);
//
constexpr
auto
ys_2_rhs_major
=
merge_sequences
(
typename
OuterDstr
::
Ys2RHsMajor
{},
typename
InnerDstr
::
Ys2RHsMajor
{});
constexpr
auto
ys_2_rhs_minor
=
merge_sequences
(
typename
OuterDstr
::
Ys2RHsMinor
{},
updated_inner_ys_2_rhs_minor
);
return
tile_distribution_encoding
<
RsLengths
,
remove_cvref_t
<
decltype
(
hs_lengthss
)
>
,
remove_cvref_t
<
decltype
(
ps_2_rhss_major
)
>
,
remove_cvref_t
<
decltype
(
ps_2_rhss_minor
)
>
,
remove_cvref_t
<
decltype
(
ys_2_rhs_major
)
>
,
remove_cvref_t
<
decltype
(
ys_2_rhs_minor
)
>>
{};
}
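
A hedged sketch of how the embed helper combines two encodings; OuterEnc and InnerEnc are hypothetical encodings that must agree on NDimX (enforced by the static_assert above):

// outer tile repeated over an inner per-thread tile; both share the same NDimX
constexpr auto embedded_enc =
    detail::make_embed_tile_distribution_encoding(OuterEnc{}, InnerEnc{});
// R lengths are concatenated; per X dim the outer H sequence is prepended to the
// inner one, and the inner P/Y minor indices are shifted past the outer entries.
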
template
<
typename
InDstr
,
index_t
...
InReduceDimXs
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_reduce_tile_distribution_encoding_impl
(
InDstr
,
sequence
<
InReduceDimXs
...
>
reduce_dim_xs_in
)
{
constexpr
auto
I1
=
number
<
1
>
{};
// FIXME: increase if fail
constexpr
index_t
max_ndim_r_out
=
20
;
constexpr
index_t
max_ndim_y_out
=
20
;
//
constexpr
index_t
ndim_p
=
InDstr
::
NDimP
;
constexpr
index_t
ndim_x_in
=
InDstr
::
NDimX
;
constexpr
index_t
ndim_y_in
=
InDstr
::
NDimY
;
constexpr
index_t
ndim_rh_major_in
=
InDstr
::
NDimX
+
1
;
constexpr
index_t
ndim_x_out
=
ndim_x_in
-
sizeof
...(
InReduceDimXs
);
constexpr
index_t
max_ndim_rh_minor_in
=
InDstr
::
detail
::
max_ndim_rh_minor_
;
// ndims_ps_low
constexpr
auto
ndims_ps_low
=
generate_array
(
[
&
](
auto
i
)
{
return
InDstr
::
ps_to_rhss_major_
[
i
].
size
();
},
number
<
ndim_p
>
{});
// is_rh_major_in_for_reduce
array
<
bool
,
ndim_rh_major_in
>
is_rh_major_in_for_reduce
{
false
};
for
(
index_t
i
=
0
;
i
<
reduce_dim_xs_in
.
size
();
i
++
)
{
index_t
rh_major
=
reduce_dim_xs_in
[
i
]
+
1
;
is_rh_major_in_for_reduce
(
rh_major
)
=
true
;
}
// is_y_in_for_reduce
array
<
bool
,
ndim_y_in
>
is_y_in_for_reduce
{
false
};
for
(
index_t
i
=
0
;
i
<
ndim_y_in
;
i
++
)
{
index_t
rh_major
=
InDstr
::
ys_to_rhs_major_
[
i
];
if
(
is_rh_major_in_for_reduce
[
rh_major
])
{
is_y_in_for_reduce
(
i
)
=
true
;
}
}
// is_rh_minor_in_for_y_reduce
array
<
array
<
bool
,
max_ndim_rh_minor_in
>
,
ndim_rh_major_in
>
is_rh_minor_in_for_y_reduce
{{
false
}};
static_for
<
0
,
ndim_y_in
,
1
>
{}([
&
](
auto
i
)
{
index_t
rh_major
=
InDstr
::
ys_to_rhs_major_
[
i
];
index_t
rh_minor
=
InDstr
::
ys_to_rhs_minor_
[
i
];
if
(
is_y_in_for_reduce
[
i
])
{
is_rh_minor_in_for_y_reduce
(
rh_major
)(
rh_minor
)
=
true
;
}
});
// in2out_rh_major
array
<
index_t
,
ndim_rh_major_in
>
in2out_rh_major
{
-
1
};
index_t
cnt_ndim_rh_major_out
=
0
;
for
(
index_t
i
=
0
;
i
<
ndim_rh_major_in
;
i
++
)
{
if
(
is_rh_major_in_for_reduce
[
i
])
{
in2out_rh_major
(
i
)
=
0
;
}
else
{
in2out_rh_major
(
i
)
=
cnt_ndim_rh_major_out
;
cnt_ndim_rh_major_out
++
;
}
}
// rs_lengths_out, in2out_rh_minor
array
<
index_t
,
max_ndim_r_out
>
rs_lengths_out
{
-
1
};
array
<
array
<
index_t
,
max_ndim_rh_minor_in
>
,
ndim_rh_major_in
>
in2out_rh_minor
{{
-
1
}};
// loop over input R dim
for
(
index_t
i
=
0
;
i
<
InDstr
::
rs_lengths_
.
size
();
i
++
)
{
// rs_lengths_out
rs_lengths_out
(
i
)
=
InDstr
::
rs_lengths_
[
i
];
// in2out_rh_minor
in2out_rh_minor
(
0
)(
i
)
=
i
;
}
// loop over input H Dim
index_t
cnt_ndim_r_out
=
InDstr
::
rs_lengths_
.
size
();
static_for
<
1
,
ndim_rh_major_in
,
1
>
{}([
&
](
auto
rh_major_in
)
{
constexpr
auto
h_major_in
=
rh_major_in
-
I1
;
constexpr
index_t
ndim_rh_minor_in
=
InDstr
::
hs_lengthss_
[
h_major_in
].
size
();
if
(
is_rh_major_in_for_reduce
[
rh_major_in
])
{
for
(
index_t
rh_minor_in
=
0
;
rh_minor_in
<
ndim_rh_minor_in
;
rh_minor_in
++
)
{
if
(
not
is_rh_minor_in_for_y_reduce
[
rh_major_in
][
rh_minor_in
])
{
// rs_lengths_out
rs_lengths_out
(
cnt_ndim_r_out
)
=
InDstr
::
hs_lengthss_
[
h_major_in
][
rh_minor_in
];
// in2out_rh_minor
in2out_rh_minor
(
rh_major_in
)(
rh_minor_in
)
=
cnt_ndim_r_out
;
cnt_ndim_r_out
++
;
}
}
}
else
{
for
(
index_t
rh_minor_in
=
0
;
rh_minor_in
<
ndim_rh_minor_in
;
rh_minor_in
++
)
{
// in2out_rh_minor
in2out_rh_minor
(
rh_major_in
)(
rh_minor_in
)
=
rh_minor_in
;
}
}
});
// ndim_r_out
const
index_t
ndim_r_out
=
cnt_ndim_r_out
;
// ndims_hs_minor_out, hs_lengthss_out
array
<
index_t
,
ndim_x_out
>
ndims_hs_minor_out
{
-
1
};
array
<
array
<
index_t
,
max_ndim_rh_minor_in
>
,
ndim_x_out
>
hs_lengthss_out
{{
-
1
}};
index_t
cnt_ndim_x_out
=
0
;
static_for
<
0
,
ndim_x_in
,
1
>
{}([
&
](
auto
i
)
{
if
(
not
is_rh_major_in_for_reduce
[
i
+
I1
])
{
// ndims_hs_minor_out
ndims_hs_minor_out
(
cnt_ndim_x_out
)
=
InDstr
::
hs_lengthss_
[
i
].
size
();
// hs_lengthss_out
static_for
<
0
,
InDstr
::
hs_lengthss_
[
i
].
size
(),
1
>
{}(
[
&
](
auto
j
)
{
hs_lengthss_out
(
cnt_ndim_x_out
)(
j
)
=
InDstr
::
hs_lengthss_
[
i
][
j
];
});
cnt_ndim_x_out
++
;
}
});
// ps_to_rhss_major_out, ps_to_rhss_minor_out
array
<
array
<
index_t
,
max_ndim_rh_minor_in
>
,
ndim_p
>
ps_to_rhss_major_out
{{
-
1
}};
array
<
array
<
index_t
,
max_ndim_rh_minor_in
>
,
ndim_p
>
ps_to_rhss_minor_out
{{
-
1
}};
static_for
<
0
,
ndim_p
,
1
>
{}([
&
](
auto
idim_p
)
{
static_for
<
0
,
InDstr
::
ps_to_rhss_major_
[
idim_p
].
size
(),
1
>
{}([
&
](
auto
idim_low
)
{
index_t
rh_major_in
=
InDstr
::
ps_to_rhss_major_
[
idim_p
][
idim_low
];
index_t
rh_minor_in
=
InDstr
::
ps_to_rhss_minor_
[
idim_p
][
idim_low
];
ps_to_rhss_major_out
(
idim_p
)(
idim_low
)
=
in2out_rh_major
[
rh_major_in
];
ps_to_rhss_minor_out
(
idim_p
)(
idim_low
)
=
in2out_rh_minor
[
rh_major_in
][
rh_minor_in
];
});
});
// ys_to_rhs_major_out, ys_to_rhs_minor_out
array
<
index_t
,
max_ndim_y_out
>
ys_to_rhs_major_out
{
-
1
};
array
<
index_t
,
max_ndim_y_out
>
ys_to_rhs_minor_out
{
-
1
};
index_t
cnt_ndim_y_out
=
0
;
static_for
<
0
,
ndim_y_in
,
1
>
{}([
&
](
auto
i
)
{
if
(
not
is_y_in_for_reduce
[
i
])
{
index_t
rh_major_in
=
InDstr
::
ys_to_rhs_major_
[
i
];
index_t
rh_minor_in
=
InDstr
::
ys_to_rhs_minor_
[
i
];
ys_to_rhs_major_out
(
cnt_ndim_y_out
)
=
in2out_rh_major
[
rh_major_in
];
ys_to_rhs_minor_out
(
cnt_ndim_y_out
)
=
in2out_rh_minor
[
rh_major_in
][
rh_minor_in
];
cnt_ndim_y_out
++
;
}
});
// ndim_y_out
const
index_t
ndim_y_out
=
cnt_ndim_y_out
;
//
return
make_tuple
(
ndim_x_out
,
ndim_p
,
ndim_y_out
,
ndim_r_out
,
ndims_hs_minor_out
,
ndims_ps_low
,
rs_lengths_out
,
hs_lengthss_out
,
ps_to_rhss_major_out
,
ps_to_rhss_minor_out
,
ys_to_rhs_major_out
,
ys_to_rhs_minor_out
);
}
template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
    constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);

    constexpr index_t ndim_x              = impl.template at<0>();
    constexpr index_t ndim_p              = impl.template at<1>();
    constexpr index_t ndim_y              = impl.template at<2>();
    constexpr index_t ndim_r              = impl.template at<3>();
    constexpr auto ndims_hs_minor         = impl.template at<4>();
    constexpr auto ndims_ps_low           = impl.template at<5>();
    constexpr auto rs_lengths_impl        = impl.template at<6>();
    constexpr auto hs_lengthss_impl       = impl.template at<7>();
    constexpr auto ps_to_rhss_major_impl  = impl.template at<8>();
    constexpr auto ps_to_rhss_minor_impl  = impl.template at<9>();
    constexpr auto ys_to_rhs_major_impl   = impl.template at<10>();
    constexpr auto ys_to_rhs_minor_impl   = impl.template at<11>();

    constexpr auto rs_lengths  = TO_SEQUENCE(rs_lengths_impl, ndim_r);
    constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
    constexpr auto ps_to_rhss_major =
        TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
    constexpr auto ps_to_rhss_minor =
        TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
    constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
    constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);

    return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
                                      remove_cvref_t<decltype(hs_lengthss)>,
                                      remove_cvref_t<decltype(ps_to_rhss_major)>,
                                      remove_cvref_t<decltype(ps_to_rhss_minor)>,
                                      remove_cvref_t<decltype(ys_to_rhs_major)>,
                                      remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
}
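
A sketch of the reduce helper; InEnc is a hypothetical encoding and the X index names the dimension being reduced away:

// drop X dim 1 from the encoding (e.g. the dimension a block-level reduction folds away)
constexpr auto reduced_enc =
    detail::make_reduce_tile_distribution_encoding(InEnc{}, sequence<1>{});
// H components of the reduced X dim that are not covered by this thread's Y dims
// are moved into the R (replication) lengths of the resulting encoding.
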
} // namespace detail
} // namespace ck_tile
include/ck_tile/core/tensor/tile_elementwise.hpp
0 → 100644
View file @
5a9c4962
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
// TODO: support tensors with different distribution
template <typename InOutElementFunc,
          typename... InOutDstrTensors,
          typename = std::enable_if_t<std::conjunction_v<
              std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
                                           InOutDstrTensors&... inout_dstr_tensors)
{
    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);
    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();

    static_for<0, thread_buffer_size, 1>{}(
        [&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
}
template <typename InElementFunc,
          typename... InTensor,
          typename = std::enable_if_t<
              std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
                                        const InTensor&... in_dstr_tensors)
{
    using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));

    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);
    constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();

    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InTensor...>::get_thread_buffer_size();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
        out_dstr_tensor.get_thread_buffer()(i) =
            in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
    });

    return out_dstr_tensor;
}
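
Usage sketch for the two element-wise helpers above; a_tile, b_tile and acc_tile stand for static distributed tensors that share one distribution:

// out-of-place: returns a new distributed tensor, element type deduced from the lambda
auto scaled = tile_elementwise_in([](auto x) { return 2.f * x; }, a_tile);

// in-place: mutates the tensors bound to the reference parameters
tile_elementwise_inout([](auto& acc, auto b) { acc += b; }, acc_tile, b_tile);
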
template <typename DstrTensors, typename T>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
{
    tile_elementwise_inout(
        [&value](auto& x) {
            x = type_convert<typename DstrTensors::DataType, remove_cvref_t<T>>(value);
        },
        dstr_tensor);
}

template <typename T>
CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
{
}
// TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
// sub-dword tensor...
template <typename DstrTensors, index_t v>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
{
    constexpr index_t tensor_bytes =
        DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);

    if constexpr(v == 0 && tensor_bytes % 4 == 0)
    {
        using dvec_t = array<index_t, tensor_bytes / 4>;
        auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
        for(auto i = 0; i < tensor.size(); i++)
            tensor.get(i) = v;
    }
    else
    {
        tile_elementwise_inout(
            [](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
            dstr_tensor);
    }
}

template <index_t v>
CK_TILE_DEVICE void set_tile(null_tensor&, number<v>)
{
}

template <typename DstrTensors>
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
{
    set_tile(dstr_tensor, 0);
}
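
For example (acc is a hypothetical distributed accumulator tile):

clear_tile(acc);            // same as set_tile(acc, 0), goes through type_convert
set_tile(acc, number<0>{}); // compile-time zero; may take the whole-dword fast path above
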
namespace
impl
{
// TODO: this is ugly
template
<
typename
OutDataType
,
typename
InTensor
>
CK_TILE_DEVICE
auto
cast_tile_pk_fp8x4
(
const
InTensor
&
in_dstr_tensors
)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// This API is designed to use the _pk_ serious of function
constexpr
auto
in_tile_dstr
=
InTensor
::
get_tile_distribution
();
constexpr
index_t
thread_buffer_size
=
InTensor
::
get_thread_buffer_size
();
static_assert
(
thread_buffer_size
%
4
==
0
);
constexpr
index_t
thread_buffer_size_pk
=
thread_buffer_size
/
4
;
auto
out_dstr_tensor
=
make_static_distributed_tensor
<
OutDataType
>
(
in_tile_dstr
);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
// __builtin_amdgcn_cvt_pk_fp8_f32() this builtin require the old value, and
// will generate a v_mov_b32 vxxx [old] before cvt, which result in unwanted ISA
// so we prepare an uninitialized variable purposely, and turn off the warning
int
dummy_old
;
static_for
<
0
,
thread_buffer_size_pk
,
1
>
{}([
&
](
auto
i
)
{
uint32_t
x
=
__builtin_amdgcn_cvt_pk_fp8_f32
(
in_dstr_tensors
.
get_thread_buffer
()[
number
<
4
*
i
+
0
>
{}],
in_dstr_tensors
.
get_thread_buffer
()[
number
<
4
*
i
+
1
>
{}],
dummy_old
,
false
);
// false -> WORD0
uint32_t
y
=
__builtin_amdgcn_cvt_pk_fp8_f32
(
in_dstr_tensors
.
get_thread_buffer
()[
number
<
4
*
i
+
2
>
{}],
in_dstr_tensors
.
get_thread_buffer
()[
number
<
4
*
i
+
3
>
{}],
dummy_old
,
false
);
// false -> WORD0
constexpr
int32_t
m0
=
0x05040100
;
using
vec_t
=
array
<
OutDataType
,
4
>
;
vec_t
d
=
bit_cast
<
vec_t
>
(
__builtin_amdgcn_perm
(
y
,
x
,
m0
));
out_dstr_tensor
.
get_thread_buffer
().
template
set_as
<
vec_t
>(
number
<
i
>
{},
d
);
});
#pragma clang diagnostic pop
return
out_dstr_tensor
;
#else
// fallback
return
tile_elementwise_in
(
type_convert
<
OutDataType
,
typename
InTensor
::
DataType
>
,
in_dstr_tensors
);
#endif
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assume either src or dst (or both) date type is under 1 dword
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
template
<
typename
OutDataType
,
typename
InTensor
>
CK_TILE_DEVICE
auto
cast_tile_opt_subdword
(
const
InTensor
&
in_dstr_tensors
)
{
constexpr
auto
in_tile_dstr
=
InTensor
::
get_tile_distribution
();
auto
out_dstr_tensor
=
make_static_distributed_tensor
<
OutDataType
>
(
in_tile_dstr
);
using
i_type
=
remove_cvref_t
<
typename
InTensor
::
DataType
>
;
using
o_type
=
remove_cvref_t
<
OutDataType
>
;
constexpr
index_t
i_elem_bytes
=
sizeof
(
i_type
);
constexpr
index_t
o_elem_bytes
=
sizeof
(
o_type
);
static_assert
(
i_elem_bytes
<
4
||
o_elem_bytes
<
4
);
constexpr
index_t
bulk_size
=
(
i_elem_bytes
>=
o_elem_bytes
)
?
(
4
/
o_elem_bytes
)
:
(
4
/
i_elem_bytes
);
static_assert
(
bulk_size
!=
0
);
using
o_bulk_type
=
std
::
conditional_t
<
i_elem_bytes
>=
o_elem_bytes
,
float
,
array
<
o_type
,
bulk_size
>>
;
constexpr
index_t
thread_buffer_size
=
InTensor
::
get_thread_buffer_size
();
constexpr
index_t
iters
=
thread_buffer_size
/
bulk_size
;
constexpr
index_t
rems
=
thread_buffer_size
%
bulk_size
;
// cast the sequence per-bulk
static_for
<
0
,
iters
,
1
>
{}([
&
](
auto
i
)
{
union
bulk_wrapper
{
o_bulk_type
bulk
{};
o_type
data
[
bulk_size
];
}
o_bulk
;
// TODO: should use below function, but somehow will result in spill (same as c-forloop)
static_for
<
0
,
bulk_size
,
1
>
{}([
&
o_bulk
,
&
in_dstr_tensors
,
&
i
](
auto
ib
)
{
o_bulk
.
data
[
ib
.
value
]
=
static_cast
<
o_type
>
(
in_dstr_tensors
.
get_thread_buffer
()
.
template
get_as
<
i_type
>()[
number
<
bulk_size
*
i
.
value
+
ib
.
value
>
{}]);
});
// TODO: fixme, should use above!
// static_assert(sizeof(i_type) / sizeof(o_type) == 2);
// o_bulk.data[0] = static_cast<o_type>(
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
// o_bulk.data[1] = static_cast<o_type>(
// in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);
out_dstr_tensor
.
get_thread_buffer
().
template
set_as
<
o_bulk_type
>(
i
,
o_bulk
.
bulk
);
});
static_for
<
0
,
rems
,
1
>
{}([
&
](
auto
r
)
{
// TODO: introducing local scratch pad?
auto
idx
=
number
<
iters
*
bulk_size
+
r
>
{};
out_dstr_tensor
.
get_thread_buffer
().
at
(
idx
)
=
static_cast
<
o_type
>
(
in_dstr_tensors
.
get_thread_buffer
().
at
(
idx
));
});
return
out_dstr_tensor
;
}
#endif
} // namespace impl

template <typename DstType, typename SrcTensor>
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
    if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
                 std::is_same_v<typename SrcTensor::DataType, float> &&
                 (SrcTensor::get_thread_buffer_size() % 4 == 0))
    {
        return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
    }
#if CK_TILE_USE_SUBDWORD_TILE_CAST
    else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
    {
        return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
    }
#endif
    else
        return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>,
                                   src_tensor);
}
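
A usage sketch, with c_fp32 standing for a float-valued distributed tensor and fp16_t assumed to be the usual ck_tile half-precision alias:

// per-element convert; the packed-fp8 path is taken only for float -> fp8/bf8
// when the thread buffer size is a multiple of 4 elements
auto c_fp16 = cast_tile<fp16_t>(c_fp32);
auto c_fp8  = cast_tile<fp8_t>(c_fp32);
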
// no-op function for null_tensor arguments
template <typename InOutElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
{
}

// no-op function for null_tensor arguments
template <typename InElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
{
    return null_tensor{};
}

} // namespace ck_tile
include/ck_tile/core/tensor/tile_window.hpp
0 → 100644
View file @
5a9c4962
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord>
struct tile_window_with_static_distribution
{
    using BottomTensorView = remove_reference_t<BottomTensorView_>;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;
    using TileDstr         = remove_cvref_t<StaticTileDistribution_>;

    using WindowAdaptor    = typename TileDstr::PsYs2XsAdaptor;
    using BottomTensorDesc = typename BottomTensorView::TensorDesc;

    using DataType = remove_cvref_t<typename BottomTensorView::DataType>;

    static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
    static constexpr index_t NDimBottomTensor     = BottomTensorDesc::get_num_of_dimension();

    static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
    static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    // TODO: check WindowLengths and StaticTileDistribution are consistent
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");
    static_assert(TileDstr::is_static(), "wrong!");
    static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
                  "wrong! inconsistent # of diemsnions");

    using AdaptorTopIndex   = array<index_t, NDimWindowAdaptorTop>;
    using BottomTensorIndex = array<index_t, NDimBottomTensor>;

    using WindowAdaptorCoord =
        decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
    using BottomTensorCoord =
        decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
struct
load_store_traits
{
private:
static
constexpr
auto
get_vector_dim_y_scalar_per_vector
()
{
const
auto
[
ys_vector_lengths
,
ys_vector_strides
]
=
tile_window_with_static_distribution
::
get_window_adaptor_ys_safe_vector_length_strides
();
index_t
VectorDimY_
=
0
;
index_t
ScalarPerVector_
=
1
;
for
(
index_t
i
=
0
;
i
<
NDimY
;
++
i
)
{
if
(
ys_vector_strides
[
i
]
==
1
&&
ys_vector_lengths
[
i
]
>
ScalarPerVector_
)
{
ScalarPerVector_
=
ys_vector_lengths
[
i
];
VectorDimY_
=
i
;
}
}
return
make_tuple
(
VectorDimY_
,
ScalarPerVector_
);
}
public:
static
constexpr
index_t
VectorDimY
=
get_vector_dim_y_scalar_per_vector
().
template
at
<
0
>();
static
constexpr
index_t
ScalarPerVector
=
get_vector_dim_y_scalar_per_vector
().
template
at
<
1
>();
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type;
using
vector_t
=
thread_buffer
<
DataType
,
ScalarPerVector
>
;
private:
static
constexpr
auto
scalars_per_access_
=
[]
{
constexpr
auto
scalars_per_access_arr
=
generate_array
(
[
&
](
auto
i
)
{
return
(
i
==
VectorDimY
)
?
ScalarPerVector
:
1
;
},
number
<
NDimY
>
{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr
auto
NDimY_
=
NDimY
;
return
TO_SEQUENCE
(
scalars_per_access_arr
,
NDimY_
);
}();
static
constexpr
auto
get_space_filling_curve
()
{
constexpr
auto
tile_dstr
=
TileDstr
{};
constexpr
auto
thread_tensor_lengths_ys
=
to_sequence
(
tile_dstr
.
get_ys_to_d_descriptor
().
get_lengths
());
// FIXME: need logic to judge dim access order
using
DimAccessOrder
=
typename
arithmetic_sequence_gen
<
0
,
NDimY
,
1
>::
type
;
return
space_filling_curve
<
decltype
(
thread_tensor_lengths_ys
),
DimAccessOrder
,
decltype
(
scalars_per_access_
)
>
{};
}
public:
using
SFC_Ys
=
decltype
(
get_space_filling_curve
());
static
constexpr
index_t
NumAccess
=
SFC_Ys
::
get_num_of_access
();
static_assert
(
0
<
NumAccess
,
"Wrong! NumAccess should be larger than 0"
);
static_assert
(
NumAccess
%
NumCoord
==
0
,
"wrong! # of access is not divisible by NumCoord"
);
};
static
constexpr
index_t
NumAccessPerCoord
=
load_store_traits
::
NumAccess
/
NumCoord
;
CK_TILE_DEVICE
constexpr
tile_window_with_static_distribution
()
=
default
;
CK_TILE_DEVICE
constexpr
tile_window_with_static_distribution
(
const
BottomTensorView
&
bottom_tensor_view
,
const
WindowLengths
&
window_lengths
,
const
BottomTensorIndex
&
window_origin
,
const
TileDstr
&
tile_distribution
)
:
bottom_tensor_view_
{
bottom_tensor_view
},
window_lengths_
{
window_lengths
},
window_origin_
{
window_origin
},
tile_dstr_
{
tile_distribution
},
pre_computed_coords_
{}
{
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this use less register for FA, but more register for GEMM
// need investigation
const
auto
window_adaptor_thread_coord_tmp
=
make_tensor_adaptor_coordinate
(
tile_distribution
.
get_ps_ys_to_xs_adaptor
(),
container_concat
(
detail
::
get_partition_index
(
tile_distribution
),
array
<
index_t
,
NDimY
>
{
0
}));
#endif
BottomTensorIndex
bottom_tensor_thread_origin_idx_tmp
=
window_origin
+
window_adaptor_thread_coord_tmp
.
get_bottom_index
();
const
auto
bottom_tensor_thread_coord_tmp
=
make_tensor_coordinate
(
bottom_tensor_view_
.
get_tensor_descriptor
(),
bottom_tensor_thread_origin_idx_tmp
);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using
Traits
=
load_store_traits
;
using
SFC_Ys
=
typename
Traits
::
SFC_Ys
;
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
auto
window_adaptor_thread_coord
=
window_adaptor_thread_coord_tmp
;
auto
bottom_tensor_thread_coord
=
bottom_tensor_thread_coord_tmp
;
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_step_between
(
number
<
0
>
{},
number
<
iCoord
*
NumAccessPerCoord
>
{});
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
pre_computed_coords_
(
iCoord
)
=
make_tuple
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
);
});
}
    CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }

    CK_TILE_DEVICE static constexpr bool has_static_tile_distribution()
    {
        return TileDstr::is_static();
    }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
CK_TILE_DEVICE
void
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
WindowAdaptorCoord
&
window_adaptor_thread_coord
,
BottomTensorCoord
&
bottom_tensor_thread_coord
,
const
AdaptorTopIndex
&
idx_diff_adaptor_top
)
const
{
array
<
index_t
,
NDimBottomTensor
>
idx_diff_adaptor_bottom
;
move_tensor_adaptor_coordinate
(
tile_dstr_
.
get_ps_ys_to_xs_adaptor
(),
window_adaptor_thread_coord
,
idx_diff_adaptor_top
,
idx_diff_adaptor_bottom
);
move_tensor_coordinate
(
bottom_tensor_view_
.
get_tensor_descriptor
(),
bottom_tensor_thread_coord
,
idx_diff_adaptor_bottom
);
}
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE
static
constexpr
auto
get_window_adaptor_ys_safe_vector_length_strides
()
{
// bottom tensor top dimension vector lengths and strides
const
auto
[
bottom_tensor_top_dim_vector_lengths
,
bottom_tensor_top_dim_vector_strides
]
=
BottomTensorDesc
::
get_top_dimension_safe_vector_length_strides
();
// window vector lengths/strides
const
auto
window_adaptor_bottom_dim_vector_lengths
=
bottom_tensor_top_dim_vector_lengths
;
const
auto
window_adaptor_bottom_dim_vector_strides
=
bottom_tensor_top_dim_vector_strides
;
// window adaptor [p0, p1, ..., y0, y1, ...]
array
<
index_t
,
WindowAdaptor
::
get_num_of_hidden_dimension
()
>
window_adaptor_vector_lengths
{
-
1
};
array
<
index_t
,
WindowAdaptor
::
get_num_of_hidden_dimension
()
>
window_adaptor_vector_strides
{
-
1
};
constexpr
auto
window_adaptor_bottom_dims
=
WindowAdaptor
::
get_bottom_dimension_hidden_ids
();
set_container_subset
(
window_adaptor_vector_lengths
,
window_adaptor_bottom_dims
,
window_adaptor_bottom_dim_vector_lengths
);
set_container_subset
(
window_adaptor_vector_strides
,
window_adaptor_bottom_dims
,
window_adaptor_bottom_dim_vector_strides
);
const
auto
[
window_adaptor_ps_ys_vector_lengths
,
window_adaptor_ps_ys_vector_strides
]
=
WindowAdaptor
{}.
get_top_dimension_safe_vector_length_strides
(
window_adaptor_vector_lengths
,
window_adaptor_vector_strides
);
// [y0, y1, ...]
constexpr
auto
y_dims
=
typename
arithmetic_sequence_gen
<
TileDstr
::
get_num_of_dimension_p
(),
NDimWindowAdaptorTop
,
1
>::
type
{};
return
make_tuple
(
get_container_subset
(
window_adaptor_ps_ys_vector_lengths
,
y_dims
),
get_container_subset
(
window_adaptor_ps_ys_vector_strides
,
y_dims
));
}
    CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }
    template <bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
    {
        using Traits   = load_store_traits;
        using vector_t = typename Traits::vector_t;
        using SFC_Ys   = typename Traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

                // read from bottom tensor
                const vector_t vec_value =
                    get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                        bottom_tensor_thread_coord, bool_constant<oob_conditional_check>{});
#if 1
                // write into distributed tensor
                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<NDimY>{});

                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                    dst_tensor.get_thread_buffer().template at<d>() =
                        vec_value.template get_as<DataType>()[j];
                });
#else
                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);

                static_assert(d % Traits::ScalarPerVector == 0);

                dst_tensor.get_thread_buffer().template get_as<vector_t>()(
                    number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });

        return dst_tensor;
    }
    template <typename DstTile, bool oob_conditional_check = true>
    CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
                                 bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;
        // using vector_type_t = typename Traits::vector_type_t;
        using vector_t = typename Traits::vector_t;
        using SFC_Ys   = typename Traits::SFC_Ys;

        static constexpr index_t YElementSize =
            TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
        static_assert(YElementSize % Traits::ScalarPerVector == 0);
        using vectorized_tbuf = array<vector_t, YElementSize / Traits::ScalarPerVector>;
        // StaticBuffer<address_space_enum::vgpr,
        //              vector_t,
        //              YElementSize / Traits::ScalarPerVector,
        //              true>;

        constexpr auto tile_dstr = TileDstr{};

        auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
                constexpr index_t d =
                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
                static_assert(d % Traits::ScalarPerVector == 0);

                get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                    dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
                    bottom_tensor_thread_coord,
                    bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }
    // TODO: currently async load only implemented in inline asm
    template <typename LdsTileWindow_, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
                                   bool_constant<oob_conditional_check> = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
        using LdsDataType = typename LdsTileWindow::DataType;
        // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;

        // issues * warps * lanes
        static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded

        const index_t size_per_buf =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<0>{}, number<0>{})) *
            sizeof(LdsDataType);

        const index_t size_per_wave =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<1>{}, number<0>{})) *
                sizeof(LdsDataType) -
            size_per_buf;

        const index_t size_per_issue =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<1>{}, number<0>{}, number<0>{})) *
                sizeof(LdsDataType) -
            size_per_buf;

        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
        m0_set_with_memory(m0_init_value); // This should be wave independent

        using Traits = load_store_traits;
        // using vector_type_t = typename Traits::vector_type_t;
        using vector_t = typename Traits::vector_t;
        using SFC_Ys   = typename Traits::SFC_Ys;

        LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            // TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // read from bottom tensor
                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                    smem, bottom_tensor_thread_coord);

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);

                    m0_inc_with_memory(size_per_issue);
                }
            });
        });
    }
    template <bool oob_conditional_check = true>
    CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                              bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;
        // using vector_type_t = typename Traits::vector_type_t;
        using vector_t = typename Traits::vector_t;
        using SFC_Ys   = typename Traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

                // read from distributed tensor
                // vector_type_t vec;
                vector_t vec_value;

                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<NDimY>{});

                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                    vec_value.template get_as<DataType>()(j) =
                        dstr_tensor.get_thread_buffer().template at<d>();
                });

                // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();

                // write into bottom tensor
                get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord, vec_value, bool_constant<oob_conditional_check>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }
    CK_TILE_DEVICE void
    store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
    {
        using Traits   = load_store_traits;
        using vector_t = typename Traits::vector_t;
        using SFC_Ys   = typename Traits::SFC_Ys;

        constexpr auto tile_dstr                    = TileDstr{};
        static constexpr bool oob_conditional_check = true;

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

                // read from distributed tensor
                vector_t vec_value;
                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<NDimY>{});
                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
                    vec_value.template get_as<DataType>()(j) =
                        dstr_tensor.get_thread_buffer().template at<d>();
                });

                // write into bottom tensor
                get_bottom_tensor_view()
                    .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
                        bottom_tensor_thread_coord, vec_value);

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }
    // move thread's bottom tensor coordinate
    // [x0', x1', ...] ==> [offset]
    // also move window-origin
    CK_TILE_DEVICE void move(const BottomTensorIndex& step)
    {
        window_origin_ += step;

        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                                   pre_computed_coords_(iCoord)(I1),
                                   step);
        });
    }

    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;

    //
    WindowLengths window_lengths_;

    // origin ([x0', x1', ...]) of window on bottom tensor
    BottomTensorIndex window_origin_;

    // Tile tensor distribution, which contains:
    //   1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
    //   2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
    TileDstr tile_dstr_;

    // this contains:
    //   per-thread coordinate for window adaptor
    //   per-thread coordinate for bottom tensor
    array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
};
// TODO: use strategy
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord = 1>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
                 const WindowLengths_& window_lengths,
                 const multi_index<TensorView_::get_num_of_dimension()>& origin,
                 const StaticTileDistribution_& tile_distribution,
                 number<NumCoord> = {})
{
    return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
                                                remove_cvref_t<WindowLengths_>,
                                                remove_cvref_t<StaticTileDistribution_>,
                                                NumCoord>{
        tensor_view, window_lengths, origin, tile_distribution};
}
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord>
CK_TILE_DEVICE void move_tile_window(
    tile_window_with_static_distribution<TensorView_,
                                         WindowLengths_,
                                         StaticTileDistribution_,
                                         NumCoord>& window,
    const typename tile_window_with_static_distribution<TensorView_,
                                                         WindowLengths_,
                                                         StaticTileDistribution_,
                                                         NumCoord>::BottomTensorIndex& step)
{
    window.move(step);
}
template <typename BottomTensorView_, typename WindowLengths_>
struct tile_window_with_static_lengths
{
    using BottomTensorView = remove_reference_t<BottomTensorView_>;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;
    using BottomTensorDesc = typename BottomTensorView::TensorDesc;
    using DataType         = typename BottomTensorView::DataType;

    static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();

    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    using BottomTensorIndex = array<index_t, NDimBottomTensor>;

    CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default;

    CK_TILE_DEVICE constexpr tile_window_with_static_lengths(
        const BottomTensorView& bottom_tensor_view,
        const WindowLengths& window_lengths,
        const BottomTensorIndex& window_origin)
        : bottom_tensor_view_{bottom_tensor_view},
          window_lengths_{window_lengths},
          window_origin_{window_origin}
    {
    }

    CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }

    // move window-origin
    CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }

    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;

    //
    WindowLengths window_lengths_;

    // origin ([x0', x1', ...]) of window on bottom tensor
    BottomTensorIndex window_origin_;
};
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
                 const WindowLengths_& window_lengths,
                 const multi_index<TensorView_::get_num_of_dimension()>& origin)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths_>::value,
                  "wrong! lengths should be static");

    return tile_window_with_static_lengths<remove_cvref_t<TensorView_>,
                                           remove_cvref_t<WindowLengths_>>{
        tensor_view, window_lengths, origin};
}
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void
move_tile_window(tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
                 const typename tile_window_with_static_lengths<TensorView_, WindowLengths_>::
                     BottomTensorIndex& step)
{
    window.move(step);
}

} // namespace ck_tile
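// ---------------------------------------------------------------------------
// Example (editor sketch, not part of this header): the intended usage pattern
// of the tile-window API above, with the tensor views and the tile distribution
// supplied by the caller. The 64x64 window shape and the names a_view, b_view,
// dstr and copy_one_tile are illustrative placeholders only.
template <typename AView, typename BView, typename Distribution>
CK_TILE_DEVICE void copy_one_tile(const AView& a_view,
                                  BView& b_view,
                                  ck_tile::index_t i_m,
                                  const Distribution& dstr)
{
    using ck_tile::number;

    // read-window over the source and write-window over the destination,
    // both 64 x 64 and anchored at row i_m, column 0
    auto a_window = ck_tile::make_tile_window(
        a_view, ck_tile::make_tuple(number<64>{}, number<64>{}), {i_m, 0}, dstr);
    auto b_window = ck_tile::make_tile_window(
        b_view, ck_tile::make_tuple(number<64>{}, number<64>{}), {i_m, 0}, dstr);

    auto tile = a_window.load(); // per-thread distributed register tile
    b_window.store(tile);        // write it back out

    // slide both windows 64 columns to the right for a following iteration
    ck_tile::move_tile_window(a_window, {0, 64});
    ck_tile::move_tile_window(b_window, {0, 64});
}
// ---------------------------------------------------------------------------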
include/ck_tile/core/utility/bit_cast.hpp
0 → 100644
View file @
5a9c4962
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"

namespace ck_tile {

template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
{
    static_assert(__has_builtin(__builtin_bit_cast), "");
    static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");

    return __builtin_bit_cast(Y, x);
}

} // namespace ck_tile
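// Example (editor sketch, not part of this header): round-tripping a float
// through its 32-bit integer representation with ck_tile::bit_cast. The
// function name below is illustrative only.
#include <cstdint>
inline bool bit_cast_roundtrip_ok(float x)
{
    const auto bits = ck_tile::bit_cast<uint32_t>(x); // reinterpret the 4 bytes of x
    const auto y    = ck_tile::bit_cast<float>(bits); // and back again
    return ck_tile::bit_cast<uint32_t>(y) == bits;    // the bit pattern is preserved exactly
}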
include/ck_tile/core/utility/functional.hpp
0 → 100644
View file @
5a9c4962
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/container/sequence.hpp"

#include <stdint.h>
#include <utility>

namespace ck_tile {

namespace detail {

struct swallow
{
    template <typename... Ts>
    CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
    {
    }
};

template <class>
struct static_for_impl;

template <index_t... Is>
struct static_for_impl<sequence<Is...>>
{
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        swallow{(f(number<Is>{}), 0)...};
    }
};

} // namespace detail
// F signature: F(number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
    CK_TILE_HOST_DEVICE constexpr static_for()
    {
        static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
                      "wrong! should satisfy (NEnd - NBegin) % Increment == 0");
        static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
                      "wrong! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
                      "NBegin >= NEnd)");
    }

    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
            f);
    }
};
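// Example (editor sketch, not part of this header): static_for unrolls the
// body at compile time and hands each iteration index to the functor as
// number<I>, so the index can be used in constant expressions. The function
// name below is illustrative only.
inline int static_for_sum_demo()
{
    int acc = 0;
    ck_tile::static_for<0, 4, 1>{}([&](auto I) {
        acc += I.value; // I is number<0> ... number<3>
    });
    return acc; // 0 + 1 + 2 + 3 == 6
}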
struct identity
{
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
    {
        return std::forward<T>(arg);
    }
};
namespace detail {

// RemainLengths: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class Orders>
struct static_ford_impl
{
    CK_TILE_HOST_DEVICE constexpr static_ford_impl()
    {
        static_assert(RemainLengths::size() > 0, "wrong! should not get here");
    }

    // F signature: F(sequence<...>)
    // CurrentOrderedId: sequence<...>
    template <class F, class CurrentOrderedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
    {
        static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
            static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
                f, CurrentOrderedId::push_back(I));
        });
    }
};

template <class Orders>
struct static_ford_impl<sequence<>, Orders>
{
    // F signature: F(sequence<...>)
    // OrderedId: sequence<...>
    template <class F, class OrderedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
    {
        // retrieve unordered Id
        f(OrderedId::reorder_old_to_new(Orders{}));
    }
};

} // namespace detail
// Lengths is sequence<...>, it is the length of each dimension of the
// N-dimensional loop
// Orders is sequence<...>, it is the order of dimensions in which static_ford
// will loop over each dimension
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_ford
{
    CK_TILE_HOST_DEVICE constexpr static_ford()
    {
        static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
    }

    // F signature: F(sequence<...> multi_id)
    // multi_id is the unordered multi-index
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});

        detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
    }
};
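// Example (editor sketch, not part of this header): static_ford walks an
// N-dimensional compile-time index space and passes the multi-index to the
// functor as a sequence<...>. Here it visits all 2 x 3 = 6 index pairs; the
// function name below is illustrative only.
inline int static_ford_count_demo()
{
    int n = 0;
    ck_tile::static_ford<ck_tile::sequence<2, 3>>{}([&](auto multi_id) {
        // multi_id is sequence<i, j> with 0 <= i < 2 and 0 <= j < 3
        (void)multi_id;
        ++n;
    });
    return n; // 6
}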
namespace detail {

template <typename Indices>
struct unpack_impl;

template <index_t... Is>
struct unpack_impl<sequence<Is...>>
{
    template <typename F, typename X>
    CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
    {
#if 0
        return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
#else
        return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
#endif
    }
};

template <typename Seq0, typename Seq1>
struct unpack2_impl;

// TODO: remove this, after properly implementing unpack that takes any number of containers
template <index_t... Is, index_t... Js>
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
{
    template <typename F, typename X, typename Y>
    CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
    {
#if 0
        return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
                                  std::forward<Y>(y).at(number<Js>{})...);
#else
        return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
                                  std::forward<Y>(y).template at<Js>()...);
#endif
    }
};

} // namespace detail
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
{
    using X_ = remove_reference_t<X>;

    return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
        std::forward<F>(f), std::forward<X>(x));
}

// TODO: properly implement unpack that takes any number of containers
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
{
    using X_ = remove_reference_t<X>;
    using Y_ = remove_reference_t<Y>;

    return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
                                typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
        std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
}
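// Example (editor sketch, not part of this header, assuming the caller has
// also included "ck_tile/core/container/tuple.hpp"): unpack forwards every
// element of a tuple-like container to the functor as a separate argument.
// The function name below is illustrative only.
inline int unpack_sum_demo()
{
    const auto t = ck_tile::make_tuple(1, 2, 3);
    return ck_tile::unpack([](int a, int b, int c) { return a + b + c; }, t); // 6
}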
// z = predicate ? x : y
template <bool predicate, typename X, typename Y>
constexpr auto conditional_expr(X&& x, Y&& y)
{
    if constexpr(predicate)
    {
        return std::forward<X>(x);
    }
    else
    {
        return std::forward<Y>(y);
    }
}

} // namespace ck_tile
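// Example (editor sketch, not part of this header): unlike the runtime ternary
// operator, conditional_expr takes the predicate as a template parameter, so
// the two branches may even have different types and the dead branch is
// discarded at compile time. The function name below is illustrative only.
inline auto conditional_expr_demo()
{
    constexpr bool use_wide = sizeof(void*) == 8;
    // returns a long long when use_wide is true, otherwise an int
    return ck_tile::conditional_expr<use_wide>(1LL, 1);
}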
include/ck_tile/core/utility/ignore.hpp
0 → 100644
View file @
5a9c4962
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// https://en.cppreference.com/w/cpp/utility/tuple/ignore

namespace ck_tile {

namespace detail {

struct ignore_t
{
    template <typename T>
    constexpr void operator=(T&&) const noexcept
    {
    }
};

} // namespace detail

inline constexpr detail::ignore_t ignore;

} // namespace ck_tile
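// Example (editor sketch, not part of this header): ck_tile::ignore behaves
// like std::ignore -- anything assigned to it is discarded, which is handy for
// silencing unused-value warnings in templated device code. The function name
// below is illustrative only.
inline void ignore_demo(int unused_argument)
{
    ck_tile::ignore = unused_argument; // the value is thrown away, no warning emitted
}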
include/ck_tile/core/utility/magic_div.hpp
0 → 100644
View file @
5a9c4962
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

#include <stdint.h>

namespace ck_tile {

// magic number division
// Caution:
//   1. For uint32_t as dividend: the magic number division implementation being used produces
//      the correct result only if the dividend is uint32_t and its value is within the 31-bit
//      value range.
//   2. For int32_t as dividend: magic number division for an int32_t dividend has not been
//      implemented; the int32_t dividend is bit-wise interpreted as uint32_t and the magic number
//      division implementation for uint32_t is then used. Therefore, the dividend value needs to
//      be non-negative.
// TODO:
//   1. Implement magic number division for int32_t
//   2. Implement magic number division for uint32_t with the full 32-bit value range
struct magic_division32_bit_range
{
    // uint32_t
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
    {
        // WARNING: magic division is only valid for division inside this range.
        // assert(divisor >= 1 && divisor <= INT32_MAX);
        uint32_t shift_u32 = 0;

        while((1U << shift_u32) < divisor)
        {
            shift_u32++;
        };

        uint64_t tmp_u64        = ((1UL << shift_u32) - divisor) << 32;
        uint32_t multiplier_u32 = tmp_u64 / divisor + 1;

        return make_tuple(multiplier_u32, shift_u32);
    }

    template <auto Divisor, typename = std::enable_if_t<(0 < Divisor)>>
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
    {
        constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});

        constexpr uint32_t multiplier = tmp[number<0>{}];
        constexpr uint32_t shift      = tmp[number<1>{}];

        return make_tuple(constant<multiplier>{}, constant<shift>{});
    }
    // magic division for uint32_t
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = __umulhi(dividend, multiplier);
        return (tmp + dividend) >> shift;
    }

    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
        return (tmp + dividend) >> shift;
    }

    // magic division for int32_t
    // HACK: use dividend_i32 as if it's uint32_t; dividend_i32 needs to be
    // non-negative for the result to be correct
    // TODO: figure out how to do magic number division for int32_t as dividend
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = __umulhi(dividend_u32, multiplier);
        return (tmp + dividend_u32) >> shift;
    }

    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
        return (tmp + dividend_u32) >> shift;
    }
};
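// Example (editor sketch, not part of this header): on the host the magic
// numbers can be precomputed once and then reused to replace many divisions by
// the same divisor with a multiply-and-shift. The divisor 7 and the function
// name below are illustrative only.
inline bool magic_division_demo()
{
    constexpr uint32_t divisor = 7;
    const auto magic = ck_tile::magic_division32_bit_range::calculate_magic_numbers(divisor);

    const uint32_t mul   = magic[ck_tile::number<0>{}];
    const uint32_t shift = magic[ck_tile::number<1>{}];

    // 100 / 7 == 14, computed without an integer divide instruction
    return ck_tile::magic_division32_bit_range::do_magic_division(100u, mul, shift) == 14u;
}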
// magic number division
// This version only works for divisor and dividend within [0, 1 << 16]
struct magic_division16_bit_range
{
    // uint32_t
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
    {
        // WARNING: magic division is only valid for division inside this range.
        // assert(divisor >= 1 && divisor <= (1U << 16));
        uint32_t shift_u32 = 0;

        while((1U << shift_u32) < divisor)
        {
            shift_u32++;
        };

        uint32_t one            = 1;
        uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1;

        return make_tuple(multiplier_u32, shift_u32);
    }

    // integral_constant<uint32_t, .>
    template <auto Divisor>
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
    {
        constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});

        constexpr uint32_t multiplier = tmp[number<0>{}];
        constexpr uint32_t shift      = tmp[number<1>{}];

        return make_tuple(constant<multiplier>{}, constant<shift>{});
    }
    // magic division for uint32_t
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (dividend * multiplier) >> 16;
        return (tmp + dividend) >> shift;
    }

    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (dividend * multiplier) >> 16;
        return (tmp + dividend) >> shift;
    }

    // magic division for int32_t
    // HACK: use dividend_i32 as if it's uint32_t; dividend_i32 needs to be
    // non-negative for the result to be correct
    // TODO: figure out how to do magic number division for int32_t as dividend
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (dividend_u32 * multiplier) >> 16;
        return (tmp + dividend_u32) >> shift;
    }

    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (dividend_u32 * multiplier) >> 16;
        return (tmp + dividend_u32) >> shift;
    }
};
// use 32bit version
using magic_division = magic_division32_bit_range;

struct mdiv
{
    // 1 dword -> 3 dword storage
    uint32_t divisor;
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer construct on host
    CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_)
    {
        auto tmp = magic_division::calculate_magic_numbers(divisor_);

        multiplier = tmp[number<0>{}];
        shift      = tmp[number<1>{}];
    }

    CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {}

    CK_TILE_HOST_DEVICE void update(uint32_t divisor_)
    {
        divisor = divisor_;

        auto tmp = magic_division::calculate_magic_numbers(divisor_);

        multiplier = tmp[number<0>{}];
        shift      = tmp[number<1>{}];
    }

    CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
    {
        return magic_division::do_magic_division(dividend_, multiplier, shift);
    }

    CK_TILE_HOST_DEVICE void
    divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor);
    }

    CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; }
};
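// Example (editor sketch, not part of this header): a kernel that must split a
// linear work-item index into (row, col) can construct one mdiv up front and
// reuse it, avoiding a hardware integer divide per thread. The divisor 80 and
// the function name below are illustrative only.
inline void mdiv_demo()
{
    const ck_tile::mdiv cols(80); // magic numbers computed once for divisor 80

    uint32_t row = 0, col = 0;
    cols.divmod(417, row, col); // 417 = 5 * 80 + 17, so row == 5 and col == 17
    (void)row;
    (void)col;
}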
struct mdiv2
{
    // 1 dword -> 2 dword storage; the divisor is not stored and must be supplied at runtime
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer construct on host
    CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_)
    {
        auto tmp = magic_division::calculate_magic_numbers(divisor_);

        multiplier = tmp[number<0>{}];
        shift      = tmp[number<1>{}];
    }

    CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {}

    CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
    {
        return magic_division::do_magic_division(dividend_, multiplier, shift);
    }

    CK_TILE_HOST_DEVICE void divmod(uint32_t dividend_,
                                    uint32_t divisor_,
                                    uint32_t& quotient_,
                                    uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor_);
    }
};

} // namespace ck_tile
include/ck_tile/core/utility/random.hpp
0 → 100644
View file @
5a9c4962
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"

#include <stdint.h>
#include <tuple>
#include <type_traits>

namespace ck_tile {

// return 0 if data is not fp16 or fp32
template <typename T, uint32_t seed_>
struct prand_generator_t
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int, T, uint32_t = seed_) { return 0; }
};
// version for fp32
template <uint32_t seed_>
struct prand_generator_t<float, seed_>
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
    {
        uint32_t x         = *(reinterpret_cast<uint32_t*>(&val));
        uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
        drop_bits ^= x >> 16;
        drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
        drop_bits *= 0x7000149;

        // NOTE: If id is in 64 bit, we are only using the lower 32 bit.
        //       So, it can have the effect of using the same id for multiple elements when the id
        //       is very large!
        uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
        return rng;
    }
};
// version for fp16
template <uint32_t seed_>
struct prand_generator_t<half_t, seed_>
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
    {
        uint16_t x         = *(reinterpret_cast<uint16_t*>(&val));
        uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
        drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
        drop_bits *= 0x7000149;

        // NOTE: If id is in 64 bit, we are only using the lower 32 bit.
        //       So, it can have the effect of using the same id for multiple elements when the id
        //       is very large!
        uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
        return rng;
    }
};

} // namespace ck_tile
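// Example (editor sketch, not part of this header): prand_generator_t hashes
// the element index, the element value and a compile-time seed into a 32-bit
// pseudo-random word; the generic fallback simply returns 0 for element types
// other than fp16 and fp32. The seed 0xCAFE and the function name below are
// illustrative only.
inline void prand_generator_demo()
{
    ck_tile::prand_generator_t<float, 0xCAFEu> rng_f32{};
    const uint32_t r0 = rng_f32(/*id=*/0, 1.5f); // deterministic for a fixed (id, value, seed)
    const uint32_t r1 = rng_f32(/*id=*/1, 1.5f); // a different id almost always hashes differently
    (void)r0;
    (void)r1;

    ck_tile::prand_generator_t<double, 0xCAFEu> rng_f64{};
    (void)rng_f64(0, 3.14); // double is neither fp16 nor fp32, so the fallback returns 0
}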