gaoqiong / composable_kernel_ROCM

Commit 20ddaeba, authored Apr 22, 2024 by Jun Liu

    Merge branch 'develop' into amd-develop

Parents: c5f1cdf7, 43879b89
Changes: 236 files; showing the first 20 changed files, with 4339 additions and 0 deletions (+4339, -0)
include/ck_tile/core/tensor/tile_distribution_encoding.hpp   +760  -0
include/ck_tile/core/tensor/tile_elementwise.hpp             +263  -0
include/ck_tile/core/tensor/tile_window.hpp                  +740  -0
include/ck_tile/core/utility/bit_cast.hpp                    +19   -0
include/ck_tile/core/utility/functional.hpp                  +208  -0
include/ck_tile/core/utility/ignore.hpp                      +22   -0
include/ck_tile/core/utility/magic_div.hpp                   +240  -0
include/ck_tile/core/utility/random.hpp                      +58   -0
include/ck_tile/core/utility/to_sequence.hpp                 +73   -0
include/ck_tile/core/utility/transpose_vectors.hpp           +125  -0
include/ck_tile/core/utility/type_traits.hpp                 +95   -0
include/ck_tile/core/utility/unary_element_function.hpp      +67   -0
include/ck_tile/host.hpp                                     +22   -0
include/ck_tile/host/arg_parser.hpp                          +184  -0
include/ck_tile/host/check_err.hpp                           +394  -0
include/ck_tile/host/device_memory.hpp                       +112  -0
include/ck_tile/host/fill.hpp                                +232  -0
include/ck_tile/host/hip_check_error.hpp                     +36   -0
include/ck_tile/host/host_tensor.hpp                         +523  -0
include/ck_tile/host/kernel_launch.hpp                       +166  -0
include/ck_tile/core/tensor/tile_distribution_encoding.hpp (new file, mode 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename RsLengths_,    // sequence<...>
          typename HsLengthss_,   // tuple<sequence<...>, ...>
          typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
          typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
          typename Ys2RHsMajor_,  // sequence<...>
          typename Ys2RHsMinor_>  // sequence<...>
struct tile_distribution_encoding
{
    using RsLengths    = remove_cvref_t<RsLengths_>;
    using HsLengthss   = remove_cvref_t<HsLengthss_>;
    using Ps2RHssMajor = remove_cvref_t<Ps2RHssMajor_>;
    using Ps2RHssMinor = remove_cvref_t<Ps2RHssMinor_>;
    using Ys2RHsMajor  = remove_cvref_t<Ys2RHsMajor_>;
    using Ys2RHsMinor  = remove_cvref_t<Ys2RHsMinor_>;

    static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
    static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");

    static constexpr index_t NDimX = HsLengthss::size();
    static constexpr index_t NDimP = Ps2RHssMajor::size();
    static constexpr index_t NDimY = Ys2RHsMajor::size();
    static constexpr index_t NDimR = RsLengths::size();

    // FIXME: move into detail
    static constexpr auto rs_lengths_       = RsLengths{};
    static constexpr auto hs_lengthss_      = HsLengthss{};
    static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
    static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
    static constexpr auto ys_to_rhs_major_  = Ys2RHsMajor{};
    static constexpr auto ys_to_rhs_minor_  = Ys2RHsMinor{};

    // redundant but useful info
    // TODO: really bad code, should be over-hauled
    struct detail
    {
        // ndim_rh_major_, ndim_span_major_
        static constexpr index_t ndim_rh_major_   = NDimX + 1;
        static constexpr index_t ndim_span_major_ = NDimX;

        // ndims_rhs_minor_[ndim_rh_major_]
        static constexpr auto ndims_rhs_minor_ = generate_array(
            [](auto i) {
                if constexpr(i.value == 0)
                {
                    return rs_lengths_.size();
                }
                else
                {
                    return hs_lengthss_[i - number<1>{}].size();
                }
            },
            number<ndim_rh_major_>{});

        // max_ndim_rh_minor_
        static constexpr index_t max_ndim_rh_minor_ =
            container_reduce(ndims_rhs_minor_, maximize<index_t>{}, 0);

        // rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
        static constexpr auto rhs_lengthss_ =
            to_array_of_array(container_concat(make_tuple(rs_lengths_), hs_lengthss_));

        // ys_lengths_
        static constexpr auto ys_lengths_ = [] {
            array<index_t, NDimY> ys_lengths_tmp{-1};

            for(index_t i = 0; i < NDimY; i++)
            {
                index_t rh_major = ys_to_rhs_major_[i];
                index_t rh_minor = ys_to_rhs_minor_[i];

                ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
            }

            return ys_lengths_tmp;
        }();

        // rhs_major_minor_to_ys_[ndim_rh_major_][max_ndim_rh_minor_]
        static constexpr auto rhs_major_minor_to_ys_ = [] {
            array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};

            static_for<0, NDimY, 1>{}([&](auto i) {
                constexpr index_t rh_major = ys_to_rhs_major_[i];
                constexpr index_t rh_minor = ys_to_rhs_minor_[i];

                rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
            });

            return rhs_major_minor_to_ys_tmp;
        }();

        // ndims_span_minor_[NDimY]
        static constexpr auto ndims_span_minor_ = [] {
            array<index_t, NDimX> ndims_span_minor{0};

            for(index_t i = 0; i < NDimY; i++)
            {
                const index_t span_major = ys_to_rhs_major_[i] - 1;

                ndims_span_minor(span_major)++;
            }

            return ndims_span_minor;
        }();

        // max_ndim_span_minor_
        static constexpr index_t max_ndim_span_minor_ =
            container_reduce(ndims_span_minor_, maximize<index_t>{}, 0);

        // rhs_major_minor_to_span_minor_[ndim_rh_major_][max_ndim_rh_minor_]
        static constexpr auto rhs_major_minor_to_span_minor_ = [] {
            array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_>
                rhs_major_minor_to_span_minor{{-1}};

            static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
                constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];

                index_t cnt_ndim_span_minor = 0;

                static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
                    constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];

                    if(idim_y >= 0)
                    {
                        rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;

                        cnt_ndim_span_minor++;
                    }
                });
            });

            return rhs_major_minor_to_span_minor;
        }();

        // ys_to_span_major_[NDimY]
        static constexpr auto ys_to_span_major_ =
            generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});

        // ys_to_span_minor_[NDimY]
        static constexpr auto ys_to_span_minor_ = generate_array(
            [](auto i) {
                return rhs_major_minor_to_span_minor_[ys_to_rhs_major_[i]][ys_to_rhs_minor_[i]];
            },
            number<NDimY>{});

        // distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
        static constexpr auto distributed_spans_lengthss_ = [] {
            array<array<index_t, max_ndim_span_minor_>, ndim_span_major_>
                distributed_spans_lengthss{{-1}};

            static_for<0, NDimY, 1>{}([&](auto i) {
                const index_t rh_major = ys_to_rhs_major_[i];
                const index_t rh_minor = ys_to_rhs_minor_[i];

                const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];

                const index_t span_major = rh_major - 1;
                const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];

                distributed_spans_lengthss(span_major)(span_minor) = h_length;
            });

            return distributed_spans_lengthss;
        }();

        // ndims_distributed_spans_minor_[ndim_span_major_]
        static constexpr auto ndims_distributed_spans_minor_ = [] {
            array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};

            static_for<0, NDimY, 1>{}([&](auto i) {
                const index_t span_major = ys_to_rhs_major_[i] - 1;

                ndims_distributed_spans_minor(span_major)++;
            });

            return ndims_distributed_spans_minor;
        }();

        // does_p_own_r_[NDimP][NDimR]
        static constexpr auto does_p_own_r_ = [] {
            if constexpr(NDimR > 0)
            {
                array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};

                static_for<0, NDimP, 1>{}([&](auto idim_p) {
                    constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();

                    static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                        constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
                        constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];

                        if constexpr(rh_major == 0)
                        {
                            does_p_own_r(idim_p)(rh_minor) = true;
                        }
                    });
                });

                return does_p_own_r;
            }
            else
            {
                return array<array<bool, NDimR>, NDimP>{};
            }
        }();

        // ps_over_rs_derivative_[NDimP][NDimR]
        static constexpr auto ps_over_rs_derivative_ = [] {
            if constexpr(NDimR > 0)
            {
                array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};

                static_for<0, NDimP, 1>{}([&](auto idim_p) {
                    constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();

                    index_t p_over_rh_derivative = 1;

                    static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
                        constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
                        constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];

                        constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];

                        if constexpr(rh_major == 0)
                        {
                            ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
                        }

                        p_over_rh_derivative *= rh_length;
                    });
                });

                return ps_over_rs_derivative;
            }
            else
            {
                return array<array<index_t, NDimR>, NDimP>{};
            }
        }();

        // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
        CK_TILE_HOST_DEVICE static constexpr auto get_h_dim_lengths_prefix_sum()
        {
            // <len_d0, len_d1, ...>
            // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
            constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
                [&](auto i) {
                    constexpr index_t size = HsLengthss{}[i].size();
                    return number<size>{};
                },
                number<NDimX>{});

            // <0, len_d0, len_d0+len_d1, ...>
            // e.g. seq<3, 5> --> seq<0, 3, 8>
            constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);

            return h_dim_prefix_sum;
        }

        CK_TILE_HOST_DEVICE static constexpr auto get_uniformed_idx_y_to_h()
        {
            constexpr auto all_ys_2_rhss = transform_sequences(
                [](auto major, auto minor) constexpr {
                    // <0, 0, len_d0, len_d0+len_d1, ...>
                    constexpr auto x_dim_prefix_sum = merge_sequences(
                        sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());

                    return x_dim_prefix_sum.at(major) + minor;
                },
                Ys2RHsMajor{},
                Ys2RHsMinor{});

            return all_ys_2_rhss;
        }

        // return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
        template <typename IdxSeq, typename PrefixSumSeq>
        CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
        {
            using sorted_idx = sequence_unique_sort<IdxSeq, less<index_t>, equal<index_t>>;

            constexpr auto sorted_dims = typename sorted_idx::type{};
            constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};

            constexpr auto sorted_histogram =
                histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
            constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);

            return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
        }

        CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
        {
            return get_sorted_info(get_uniformed_idx_y_to_h(), get_h_dim_lengths_prefix_sum());
        }

        CK_TILE_HOST_DEVICE void print() const
        {
            printf("tile_distribution_encoding::detail{");
            //
            printf("ndim_rh_major_: ");
            print(ndim_rh_major_);
            printf(", ");
            //
            printf("ndim_span_major_: ");
            print(ndim_span_major_);
            printf(", ");
            //
            printf("ndims_rhs_minor_: ");
            print(ndims_rhs_minor_);
            printf(", ");
            //
            printf("ndim_rh_major_: ");
            print(ndim_rh_major_);
            printf(", ");
            //
            printf("max_ndim_rh_minor_: ");
            print(max_ndim_rh_minor_);
            printf(", ");
            //
            printf("rhs_lengthss_: ");
            print(rhs_lengthss_);
            printf(", ");
            //
            printf("ys_lengths_: ");
            print(ys_lengths_);
            printf(", ");
            //
            printf("rhs_major_minor_to_ys_: ");
            print(rhs_major_minor_to_ys_);
            printf(", ");
            //
            printf("ndims_span_minor_: ");
            print(ndims_span_minor_);
            printf(", ");
            //
            printf("max_ndim_span_minor_: ");
            print(max_ndim_span_minor_);
            printf(", ");
            //
            printf("ys_to_span_major_: ");
            print(ys_to_span_major_);
            printf(", ");
            //
            printf("ys_to_span_minor_: ");
            print(ys_to_span_minor_);
            printf(", ");
            //
            printf("distributed_spans_lengthss_: ");
            print(distributed_spans_lengthss_);
            printf(", ");
            //
            printf("ndims_distributed_spans_minor_: ");
            print(ndims_distributed_spans_minor_);
            printf(", ");
            //
            printf("ps_over_rs_derivative_: ");
            print(ps_over_rs_derivative_);
            //
            printf("}");
        }
    };

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tile_distribution_encoding{");
        //
        printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
        //
        printf("rs_lengths_: ");
        print(rs_lengths_);
        printf(", ");
        //
        printf("hs_lengthss_: ");
        print(hs_lengthss_);
        printf(", ");
        //
        printf("ps_to_rhss_major_: ");
        print(ps_to_rhss_major_);
        printf(", ");
        //
        printf("ps_to_rhss_minor_: ");
        print(ps_to_rhss_minor_);
        printf(", ");
        //
        printf("ys_to_rhs_major_: ");
        print(ys_to_rhs_major_);
        printf(", ");
        //
        printf("ys_to_rhs_minor_: ");
        print(ys_to_rhs_minor_);
        printf(", ");
        //
        printf("detail: ");
        print(detail{});
        //
        printf("}");
    }
};
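To make the six template parameters concrete, here is a minimal illustrative instantiation; the lengths and mappings below are invented for this sketch and are not taken from this commit. It encodes one R (replication) dimension, two X dimensions each factored into three H sub-lengths, two P (thread-partition) dimensions, and two Y (per-thread) dimensions.

    // Illustrative only: made-up lengths/mappings exercising all six parameters.
    using Enc = ck_tile::tile_distribution_encoding<
        ck_tile::sequence<4>,                       // RsLengths: one replication dim, length 4
        ck_tile::tuple<ck_tile::sequence<2, 4, 8>,  // HsLengthss: X0 factored as 2*4*8
                       ck_tile::sequence<4, 8, 2>>, //             X1 factored as 4*8*2
        ck_tile::tuple<ck_tile::sequence<0, 1>,     // Ps2RHssMajor: P0 <- (R, H0)
                       ck_tile::sequence<2>>,       //               P1 <- (H1)
        ck_tile::tuple<ck_tile::sequence<0, 1>,     // Ps2RHssMinor: P0 <- R[0], H0[1]
                       ck_tile::sequence<1>>,       //               P1 <- H1[1]
        ck_tile::sequence<1, 2>,                    // Ys2RHsMajor: Y0 <- H0, Y1 <- H1
        ck_tile::sequence<2, 2>>;                   // Ys2RHsMinor: Y0 <- H0[2], Y1 <- H1[2]

    static_assert(Enc::NDimX == 2 && Enc::NDimP == 2 && Enc::NDimY == 2 && Enc::NDimR == 1);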
namespace detail {

template <typename OuterDstr, typename InnerDstr>
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
{
    static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");

    constexpr index_t NDimHMajor = OuterDstr::NDimX;

    using RsLengths =
        sequence_merge_t<typename OuterDstr::RsLengths, typename InnerDstr::RsLengths>;

    constexpr auto hs_lengthss = generate_tuple(
        [&](auto i) {
            return merge_sequences(typename OuterDstr::HsLengthss{}[i],
                                   typename InnerDstr::HsLengthss{}[i]);
        },
        number<NDimHMajor>{});

    //
    constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
        array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;

        // R dimension
        rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();

        // Hs dimensions
        static_for<0, NDimHMajor, 1>{}([&](auto i) {
            rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
        });

        return rhs_major_2_ndim_outer_rhs_minor_;
    }();

    // Ps2RHssMinor
    constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
        [&](auto p) {
            constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
            constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];

            constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();

            constexpr auto updated_inner_p_2_rhss_minor = [&]() {
                array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;

                for(index_t i = 0; i < ndim_tmp; i++)
                {
                    index_t rh_major = inner_p_2_rhss_major[i];

                    index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];

                    updated_inner_p_2_rhss_minor_(i) =
                        inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
                }

                return updated_inner_p_2_rhss_minor_;
            }();

            return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
        },
        number<InnerDstr::NDimP>{});

    // Ys2RHsMinor
    constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
        constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
        constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};

        constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();

        constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
            array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;

            for(index_t i = 0; i < ndim_tmp; i++)
            {
                index_t rh_major = inner_ys_2_rhs_major[i];

                index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];

                updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
            }

            return updated_inner_ys_2_rhs_minor__;
        }();

        return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
    }();

    //
    constexpr auto ps_2_rhss_major =
        container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});

    constexpr auto ps_2_rhss_minor =
        container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);

    //
    constexpr auto ys_2_rhs_major =
        merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});

    constexpr auto ys_2_rhs_minor =
        merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);

    return tile_distribution_encoding<RsLengths,
                                      remove_cvref_t<decltype(hs_lengthss)>,
                                      remove_cvref_t<decltype(ps_2_rhss_major)>,
                                      remove_cvref_t<decltype(ps_2_rhss_minor)>,
                                      remove_cvref_t<decltype(ys_2_rhs_major)>,
                                      remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
}

template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding_impl(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
    constexpr auto I1 = number<1>{};

    // FIXME: increase if fail
    constexpr index_t max_ndim_r_out = 20;
    constexpr index_t max_ndim_y_out = 20;

    //
    constexpr index_t ndim_p               = InDstr::NDimP;
    constexpr index_t ndim_x_in            = InDstr::NDimX;
    constexpr index_t ndim_y_in            = InDstr::NDimY;
    constexpr index_t ndim_rh_major_in     = InDstr::NDimX + 1;
    constexpr index_t ndim_x_out           = ndim_x_in - sizeof...(InReduceDimXs);
    constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;

    // ndims_ps_low
    constexpr auto ndims_ps_low = generate_array(
        [&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});

    // is_rh_major_in_for_reduce
    array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};

    for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
    {
        index_t rh_major = reduce_dim_xs_in[i] + 1;

        is_rh_major_in_for_reduce(rh_major) = true;
    }

    // is_y_in_for_reduce
    array<bool, ndim_y_in> is_y_in_for_reduce{false};

    for(index_t i = 0; i < ndim_y_in; i++)
    {
        index_t rh_major = InDstr::ys_to_rhs_major_[i];

        if(is_rh_major_in_for_reduce[rh_major])
        {
            is_y_in_for_reduce(i) = true;
        }
    }

    // is_rh_minor_in_for_y_reduce
    array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{
        {false}};

    static_for<0, ndim_y_in, 1>{}([&](auto i) {
        index_t rh_major = InDstr::ys_to_rhs_major_[i];
        index_t rh_minor = InDstr::ys_to_rhs_minor_[i];

        if(is_y_in_for_reduce[i])
        {
            is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
        }
    });

    // in2out_rh_major
    array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
    index_t cnt_ndim_rh_major_out = 0;

    for(index_t i = 0; i < ndim_rh_major_in; i++)
    {
        if(is_rh_major_in_for_reduce[i])
        {
            in2out_rh_major(i) = 0;
        }
        else
        {
            in2out_rh_major(i) = cnt_ndim_rh_major_out;

            cnt_ndim_rh_major_out++;
        }
    }

    // rs_lengths_out, in2out_rh_minor
    array<index_t, max_ndim_r_out> rs_lengths_out{-1};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};

    // loop over input R dim
    for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
    {
        // rs_lengths_out
        rs_lengths_out(i) = InDstr::rs_lengths_[i];

        // in2out_rh_minor
        in2out_rh_minor(0)(i) = i;
    }

    // loop over input H dim
    index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();

    static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
        constexpr auto h_major_in = rh_major_in - I1;

        constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();

        if(is_rh_major_in_for_reduce[rh_major_in])
        {
            for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
            {
                if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
                {
                    // rs_lengths_out
                    rs_lengths_out(cnt_ndim_r_out) =
                        InDstr::hs_lengthss_[h_major_in][rh_minor_in];

                    // in2out_rh_minor
                    in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;

                    cnt_ndim_r_out++;
                }
            }
        }
        else
        {
            for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
            {
                // in2out_rh_minor
                in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
            }
        }
    });

    // ndim_r_out
    const index_t ndim_r_out = cnt_ndim_r_out;

    // ndims_hs_minor_out, hs_lengthss_out
    array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};

    index_t cnt_ndim_x_out = 0;

    static_for<0, ndim_x_in, 1>{}([&](auto i) {
        if(not is_rh_major_in_for_reduce[i + I1])
        {
            // ndims_hs_minor_out
            ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();

            // hs_lengthss_out
            static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
                [&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });

            cnt_ndim_x_out++;
        }
    });

    // ps_to_rhss_major_out, ps_to_rhss_minor_out
    array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
    array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};

    static_for<0, ndim_p, 1>{}([&](auto idim_p) {
        static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
            index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
            index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];

            ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
            ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
        });
    });

    // ys_to_rhs_major_out, ys_to_rhs_minor_out
    array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
    array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};

    index_t cnt_ndim_y_out = 0;

    static_for<0, ndim_y_in, 1>{}([&](auto i) {
        if(not is_y_in_for_reduce[i])
        {
            index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
            index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];

            ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
            ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];

            cnt_ndim_y_out++;
        }
    });

    // ndim_y_out
    const index_t ndim_y_out = cnt_ndim_y_out;

    //
    return make_tuple(ndim_x_out,
                      ndim_p,
                      ndim_y_out,
                      ndim_r_out,
                      ndims_hs_minor_out,
                      ndims_ps_low,
                      rs_lengths_out,
                      hs_lengthss_out,
                      ps_to_rhss_major_out,
                      ps_to_rhss_minor_out,
                      ys_to_rhs_major_out,
                      ys_to_rhs_minor_out);
}

template <typename InDstr, index_t... InReduceDimXs>
CK_TILE_HOST_DEVICE constexpr auto
make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...> reduce_dim_xs_in)
{
    constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);

    constexpr index_t ndim_x             = impl.template at<0>();
    constexpr index_t ndim_p             = impl.template at<1>();
    constexpr index_t ndim_y             = impl.template at<2>();
    constexpr index_t ndim_r             = impl.template at<3>();
    constexpr auto ndims_hs_minor        = impl.template at<4>();
    constexpr auto ndims_ps_low          = impl.template at<5>();
    constexpr auto rs_lengths_impl       = impl.template at<6>();
    constexpr auto hs_lengthss_impl      = impl.template at<7>();
    constexpr auto ps_to_rhss_major_impl = impl.template at<8>();
    constexpr auto ps_to_rhss_minor_impl = impl.template at<9>();
    constexpr auto ys_to_rhs_major_impl  = impl.template at<10>();
    constexpr auto ys_to_rhs_minor_impl  = impl.template at<11>();

    constexpr auto rs_lengths  = TO_SEQUENCE(rs_lengths_impl, ndim_r);
    constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
    constexpr auto ps_to_rhss_major =
        TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
    constexpr auto ps_to_rhss_minor =
        TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
    constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
    constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);

    return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
                                      remove_cvref_t<decltype(hs_lengthss)>,
                                      remove_cvref_t<decltype(ps_to_rhss_major)>,
                                      remove_cvref_t<decltype(ps_to_rhss_minor)>,
                                      remove_cvref_t<decltype(ys_to_rhs_major)>,
                                      remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
}

} // namespace detail
} // namespace ck_tile
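Both detail helpers are pure compile-time transforms over encodings. A hedged sketch of invoking them, reusing the illustrative Enc defined after the struct above; the dimension choices are again invented:

    // Illustrative: drop X dimension 1 (the Y entries mapping into it become
    // new R entries), as needed after a reduction across X1.
    constexpr auto reduced_enc =
        ck_tile::detail::make_reduce_tile_distribution_encoding(Enc{}, ck_tile::sequence<1>{});

    // Illustrative: nest an inner encoding inside an outer one with the same
    // NDimX; inner H minor indices are shifted past the outer ones per X dim.
    constexpr auto embedded_enc =
        ck_tile::detail::make_embed_tile_distribution_encoding(Enc{}, Enc{});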
include/ck_tile/core/tensor/tile_elementwise.hpp (new file, mode 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// TODO: support tensors with different distribution
template <typename InOutElementFunc,
          typename... InOutDstrTensors,
          typename = std::enable_if_t<std::conjunction_v<
              std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
                                           InOutDstrTensors&... inout_dstr_tensors)
{
    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);

    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();

    static_for<0, thread_buffer_size, 1>{}(
        [&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
}

template <typename InElementFunc,
          typename... InTensor,
          typename = std::enable_if_t<
              std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
                                        const InTensor&... in_dstr_tensors)
{
    using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));

    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);

    constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();

    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InTensor...>::get_thread_buffer_size();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
        out_dstr_tensor.get_thread_buffer()(i) =
            in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
    });

    return out_dstr_tensor;
}

template <typename DstrTensors, typename T>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
{
    tile_elementwise_inout(
        [&value](auto& x) {
            x = type_convert<typename DstrTensors::DataType, remove_cvref_t<T>>(value);
        },
        dstr_tensor);
}

template <typename T>
CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
{
}

// TODO: prefer to set a tensor using per-dword values, in case the compiler does not
// handle sub-dword tensors well...
template <typename DstrTensors, index_t v>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
{
    constexpr index_t tensor_bytes =
        DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);

    if constexpr(v == 0 && tensor_bytes % 4 == 0)
    {
        using dvec_t = array<index_t, tensor_bytes / 4>;
        auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
        for(auto i = 0; i < tensor.size(); i++)
            tensor.get(i) = v;
    }
    else
    {
        tile_elementwise_inout(
            [](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
            dstr_tensor);
    }
}

template <index_t v>
CK_TILE_DEVICE void set_tile(null_tensor&, number<v>)
{
}

template <typename DstrTensors>
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
{
    set_tile(dstr_tensor, 0);
}
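A short device-side sketch of these entry points; `a` and `b` stand for float-typed static_distributed_tensor values with identical distributions produced elsewhere in a kernel, and the lambda is arbitrary:

    // Illustrative device code; `a` and `b` are static_distributed_tensor<float, ...>.
    auto c = ck_tile::tile_elementwise_in(
        [](const float& x, const float& y) { return x + 2.0f * y; }, a, b);

    ck_tile::tile_elementwise_inout([](auto& x) { x *= 0.5f; }, c); // in-place update

    ck_tile::set_tile(c, 1.0f);                 // fill with a runtime scalar
    ck_tile::set_tile(c, ck_tile::number<0>{}); // compile-time 0: per-dword path if sizes allow
    ck_tile::clear_tile(c);                     // shorthand for zero-filling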
namespace impl {

// TODO: this is ugly
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
    // This API is designed to use the _pk_ series of functions
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 4 == 0);

    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
    // __builtin_amdgcn_cvt_pk_fp8_f32() requires the old value and will generate a
    // v_mov_b32 vxxx [old] before the cvt, which results in unwanted ISA,
    // so we purposely prepare an uninitialized variable and turn off the warning
    int dummy_old;
    static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) {
        uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}],
            dummy_old,
            false); // false -> WORD0

        uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
            dummy_old,
            false); // false -> WORD0

        constexpr int32_t m0 = 0x05040100;

        using vec_t = array<OutDataType, 4>;

        vec_t d = bit_cast<vec_t>(__builtin_amdgcn_perm(y, x, m0));
        out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
    });
#pragma clang diagnostic pop

    return out_dstr_tensor;
#else
    // fallback
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}

#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assumes either the src or dst (or both) data type is under 1 dword;
// we pack subdword values into 1 dword to avoid the compiler's default subdword
// behavior (which is buggy)
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
{
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    using i_type = remove_cvref_t<typename InTensor::DataType>;
    using o_type = remove_cvref_t<OutDataType>;

    constexpr index_t i_elem_bytes = sizeof(i_type);
    constexpr index_t o_elem_bytes = sizeof(o_type);
    static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);

    constexpr index_t bulk_size =
        (i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
    static_assert(bulk_size != 0);

    using o_bulk_type =
        std::conditional_t<i_elem_bytes >= o_elem_bytes, float, array<o_type, bulk_size>>;

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();

    constexpr index_t iters = thread_buffer_size / bulk_size;
    constexpr index_t rems  = thread_buffer_size % bulk_size;

    // cast the sequence per-bulk
    static_for<0, iters, 1>{}([&](auto i) {
        union bulk_wrapper
        {
            o_bulk_type bulk{};
            o_type data[bulk_size];
        } o_bulk;

        // TODO: should use below function, but somehow will result in spill (same as c-forloop)
        static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
            o_bulk.data[ib.value] = static_cast<o_type>(
                in_dstr_tensors.get_thread_buffer()
                    .template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
        });

        // TODO: fixme, should use above!
        // static_assert(sizeof(i_type) / sizeof(o_type) == 2);
        // o_bulk.data[0] = static_cast<o_type>(
        //     in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
        // o_bulk.data[1] = static_cast<o_type>(
        //     in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);

        out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
    });

    static_for<0, rems, 1>{}([&](auto r) {
        // TODO: introducing local scratch pad?
        auto idx = number<iters * bulk_size + r>{};
        out_dstr_tensor.get_thread_buffer().at(idx) =
            static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
    });

    return out_dstr_tensor;
}
#endif
} // namespace impl

template <typename DstType, typename SrcTensor>
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
    if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
                 std::is_same_v<typename SrcTensor::DataType, float> &&
                 (SrcTensor::get_thread_buffer_size() % 4 == 0))
    {
        return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
    }
#if CK_TILE_USE_SUBDWORD_TILE_CAST
    else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
    {
        return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
    }
#endif
    else
        return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>,
                                   src_tensor);
}

// no-op function for null_tensor arguments
template <typename InOutElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
{
}

// no-op function for null_tensor arguments
template <typename InElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
{
    return null_tensor{};
}

} // namespace ck_tile
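Callers go through cast_tile, and the dispatch above stays internal. A hedged sketch of the common store-out pattern, with `acc` standing for a float accumulator tile produced elsewhere:

    // Illustrative: down-cast a float accumulator tile before writing it out.
    // float -> fp8_t with a thread buffer size divisible by 4 takes the packed
    // __builtin_amdgcn_cvt_pk_fp8_f32 path on gfx940/941/942; other small types
    // take the subdword path (when enabled) or the plain elementwise convert.
    auto out_fp8  = ck_tile::cast_tile<ck_tile::fp8_t>(acc);
    auto out_half = ck_tile::cast_tile<ck_tile::fp16_t>(acc);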
include/ck_tile/core/tensor/tile_window.hpp (new file, mode 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
StaticTileDistribution_
,
index_t
NumCoord
>
struct
tile_window_with_static_distribution
{
using
BottomTensorView
=
remove_reference_t
<
BottomTensorView_
>
;
using
WindowLengths
=
remove_cvref_t
<
WindowLengths_
>
;
using
TileDstr
=
remove_cvref_t
<
StaticTileDistribution_
>
;
using
WindowAdaptor
=
typename
TileDstr
::
PsYs2XsAdaptor
;
using
BottomTensorDesc
=
typename
BottomTensorView
::
TensorDesc
;
using
DataType
=
remove_cvref_t
<
typename
BottomTensorView
::
DataType
>
;
static
constexpr
index_t
NDimWindowAdaptorTop
=
WindowAdaptor
::
get_num_of_top_dimension
();
static
constexpr
index_t
NDimBottomTensor
=
BottomTensorDesc
::
get_num_of_dimension
();
static
constexpr
index_t
NDimP
=
TileDstr
::
get_num_of_dimension_p
();
static
constexpr
index_t
NDimY
=
TileDstr
::
get_num_of_dimension_y
();
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert
(
ck_tile
::
is_known_at_compile_time
<
WindowLengths
>::
value
,
"wrong! lengths should be static"
);
static_assert
(
TileDstr
::
is_static
(),
"wrong!"
);
static_assert
(
NDimBottomTensor
==
WindowAdaptor
::
get_num_of_bottom_dimension
(),
"wrong! inconsistent # of diemsnions"
);
using
AdaptorTopIndex
=
array
<
index_t
,
NDimWindowAdaptorTop
>
;
using
BottomTensorIndex
=
array
<
index_t
,
NDimBottomTensor
>
;
using
WindowAdaptorCoord
=
decltype
(
make_tensor_adaptor_coordinate
(
WindowAdaptor
{},
AdaptorTopIndex
{}));
using
BottomTensorCoord
=
decltype
(
make_tensor_coordinate
(
BottomTensorDesc
{},
BottomTensorIndex
{}));
struct
load_store_traits
{
private:
static
constexpr
auto
get_vector_dim_y_scalar_per_vector
()
{
const
auto
[
ys_vector_lengths
,
ys_vector_strides
]
=
tile_window_with_static_distribution
::
get_window_adaptor_ys_safe_vector_length_strides
();
index_t
VectorDimY_
=
0
;
index_t
ScalarPerVector_
=
1
;
for
(
index_t
i
=
0
;
i
<
NDimY
;
++
i
)
{
if
(
ys_vector_strides
[
i
]
==
1
&&
ys_vector_lengths
[
i
]
>
ScalarPerVector_
)
{
ScalarPerVector_
=
ys_vector_lengths
[
i
];
VectorDimY_
=
i
;
}
}
return
make_tuple
(
VectorDimY_
,
ScalarPerVector_
);
}
public:
static
constexpr
index_t
VectorDimY
=
get_vector_dim_y_scalar_per_vector
().
template
at
<
0
>();
static
constexpr
index_t
ScalarPerVector
=
get_vector_dim_y_scalar_per_vector
().
template
at
<
1
>();
// using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
// using vector_t = typename vector_type_t::type;
using
vector_t
=
thread_buffer
<
DataType
,
ScalarPerVector
>
;
private:
static
constexpr
auto
scalars_per_access_
=
[]
{
constexpr
auto
scalars_per_access_arr
=
generate_array
(
[
&
](
auto
i
)
{
return
(
i
==
VectorDimY
)
?
ScalarPerVector
:
1
;
},
number
<
NDimY
>
{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr
auto
NDimY_
=
NDimY
;
return
TO_SEQUENCE
(
scalars_per_access_arr
,
NDimY_
);
}();
static
constexpr
auto
get_space_filling_curve
()
{
constexpr
auto
tile_dstr
=
TileDstr
{};
constexpr
auto
thread_tensor_lengths_ys
=
to_sequence
(
tile_dstr
.
get_ys_to_d_descriptor
().
get_lengths
());
// FIXME: need logic to judge dim access order
using
DimAccessOrder
=
typename
arithmetic_sequence_gen
<
0
,
NDimY
,
1
>::
type
;
return
space_filling_curve
<
decltype
(
thread_tensor_lengths_ys
),
DimAccessOrder
,
decltype
(
scalars_per_access_
)
>
{};
}
public:
using
SFC_Ys
=
decltype
(
get_space_filling_curve
());
static
constexpr
index_t
NumAccess
=
SFC_Ys
::
get_num_of_access
();
static_assert
(
0
<
NumAccess
,
"Wrong! NumAccess should be larger than 0"
);
static_assert
(
NumAccess
%
NumCoord
==
0
,
"wrong! # of access is not divisible by NumCoord"
);
};
static
constexpr
index_t
NumAccessPerCoord
=
load_store_traits
::
NumAccess
/
NumCoord
;
CK_TILE_DEVICE
constexpr
tile_window_with_static_distribution
()
=
default
;
CK_TILE_DEVICE
constexpr
tile_window_with_static_distribution
(
const
BottomTensorView
&
bottom_tensor_view
,
const
WindowLengths
&
window_lengths
,
const
BottomTensorIndex
&
window_origin
,
const
TileDstr
&
tile_distribution
)
:
bottom_tensor_view_
{
bottom_tensor_view
},
window_lengths_
{
window_lengths
},
window_origin_
{
window_origin
},
tile_dstr_
{
tile_distribution
},
pre_computed_coords_
{}
{
#if 0 // debug
// TODO: this use more register for FA, but less register for GEMM
// need investigation
// only support warp-tile and block-tile
static_assert(NDimP == 1 or NDimP == 2, "wrong!");
WindowAdaptorCoord window_adaptor_thread_coord_tmp;
if constexpr(NDimP == 1)
{
window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
}
else if constexpr(NDimP == 2)
{
window_adaptor_thread_coord_tmp =
make_tensor_adaptor_coordinate(tile_distribution.get_ps_ys_to_xs_adaptor(),
AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
}
#else
// TODO: this use less register for FA, but more register for GEMM
// need investigation
const
auto
window_adaptor_thread_coord_tmp
=
make_tensor_adaptor_coordinate
(
tile_distribution
.
get_ps_ys_to_xs_adaptor
(),
container_concat
(
detail
::
get_partition_index
(
tile_distribution
),
array
<
index_t
,
NDimY
>
{
0
}));
#endif
BottomTensorIndex
bottom_tensor_thread_origin_idx_tmp
=
window_origin
+
window_adaptor_thread_coord_tmp
.
get_bottom_index
();
const
auto
bottom_tensor_thread_coord_tmp
=
make_tensor_coordinate
(
bottom_tensor_view_
.
get_tensor_descriptor
(),
bottom_tensor_thread_origin_idx_tmp
);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
using
Traits
=
load_store_traits
;
using
SFC_Ys
=
typename
Traits
::
SFC_Ys
;
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
auto
window_adaptor_thread_coord
=
window_adaptor_thread_coord_tmp
;
auto
bottom_tensor_thread_coord
=
bottom_tensor_thread_coord_tmp
;
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_step_between
(
number
<
0
>
{},
number
<
iCoord
*
NumAccessPerCoord
>
{});
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
pre_computed_coords_
(
iCoord
)
=
make_tuple
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
);
});
}
CK_TILE_DEVICE
static
constexpr
index_t
get_num_of_dimension
()
{
return
NDimBottomTensor
;
}
CK_TILE_DEVICE
static
constexpr
bool
has_static_tile_distribution
()
{
return
TileDstr
::
is_static
();
}
CK_TILE_DEVICE
constexpr
auto
get_window_lengths
()
const
{
return
window_lengths_
;
}
CK_TILE_DEVICE
constexpr
auto
get_tile_distribution
()
const
{
return
tile_dstr_
;
}
CK_TILE_DEVICE
constexpr
auto
get_bottom_tensor_view
()
const
{
return
bottom_tensor_view_
;
}
CK_TILE_DEVICE
constexpr
auto
get_window_origin
()
const
{
return
window_origin_
;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
CK_TILE_DEVICE
void
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
WindowAdaptorCoord
&
window_adaptor_thread_coord
,
BottomTensorCoord
&
bottom_tensor_thread_coord
,
const
AdaptorTopIndex
&
idx_diff_adaptor_top
)
const
{
array
<
index_t
,
NDimBottomTensor
>
idx_diff_adaptor_bottom
;
move_tensor_adaptor_coordinate
(
tile_dstr_
.
get_ps_ys_to_xs_adaptor
(),
window_adaptor_thread_coord
,
idx_diff_adaptor_top
,
idx_diff_adaptor_bottom
);
move_tensor_coordinate
(
bottom_tensor_view_
.
get_tensor_descriptor
(),
bottom_tensor_thread_coord
,
idx_diff_adaptor_bottom
);
}
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE
static
constexpr
auto
get_window_adaptor_ys_safe_vector_length_strides
()
{
// bottom tensor top dimension vector lengths and strides
const
auto
[
bottom_tensor_top_dim_vector_lengths
,
bottom_tensor_top_dim_vector_strides
]
=
BottomTensorDesc
::
get_top_dimension_safe_vector_length_strides
();
// window vector lengths/strides
const
auto
window_adaptor_bottom_dim_vector_lengths
=
bottom_tensor_top_dim_vector_lengths
;
const
auto
window_adaptor_bottom_dim_vector_strides
=
bottom_tensor_top_dim_vector_strides
;
// window adaptor [p0, p1, ..., y0, y1, ...]
array
<
index_t
,
WindowAdaptor
::
get_num_of_hidden_dimension
()
>
window_adaptor_vector_lengths
{
-
1
};
array
<
index_t
,
WindowAdaptor
::
get_num_of_hidden_dimension
()
>
window_adaptor_vector_strides
{
-
1
};
constexpr
auto
window_adaptor_bottom_dims
=
WindowAdaptor
::
get_bottom_dimension_hidden_ids
();
set_container_subset
(
window_adaptor_vector_lengths
,
window_adaptor_bottom_dims
,
window_adaptor_bottom_dim_vector_lengths
);
set_container_subset
(
window_adaptor_vector_strides
,
window_adaptor_bottom_dims
,
window_adaptor_bottom_dim_vector_strides
);
const
auto
[
window_adaptor_ps_ys_vector_lengths
,
window_adaptor_ps_ys_vector_strides
]
=
WindowAdaptor
{}.
get_top_dimension_safe_vector_length_strides
(
window_adaptor_vector_lengths
,
window_adaptor_vector_strides
);
// [y0, y1, ...]
constexpr
auto
y_dims
=
typename
arithmetic_sequence_gen
<
TileDstr
::
get_num_of_dimension_p
(),
NDimWindowAdaptorTop
,
1
>::
type
{};
return
make_tuple
(
get_container_subset
(
window_adaptor_ps_ys_vector_lengths
,
y_dims
),
get_container_subset
(
window_adaptor_ps_ys_vector_strides
,
y_dims
));
}
CK_TILE_DEVICE
constexpr
auto
get_num_access
()
const
{
return
load_store_traits
::
NumAccess
;
}
template
<
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
Traits
=
load_store_traits
;
using
vector_t
=
typename
Traits
::
vector_t
;
using
SFC_Ys
=
typename
Traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
iAccess
);
// read from bottom tensor
const
vector_t
vec_value
=
get_bottom_tensor_view
().
template
get_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
bool_constant
<
oob_conditional_check
>
{});
#if 1
// write into distributed tensor
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_array
(
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
dst_tensor
.
get_thread_buffer
().
template
at
<
d
>()
=
vec_value
.
template
get_as
<
DataType
>()[
j
];
});
#else
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
static_assert
(
d
%
Traits
::
ScalarPerVector
==
0
);
dst_tensor
.
get_thread_buffer
().
template
get_as
<
vector_t
>()(
number
<
d
/
Traits
::
ScalarPerVector
>
{})
=
bit_cast
<
vector_t
>
(
vec_value
);
#endif
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
});
});
return
dst_tensor
;
}
template
<
typename
DstTile
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
load_raw
(
DstTile
&
dst_tensor
,
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
Traits
=
load_store_traits
;
// using vector_type_t = typename Traits::vector_type_t;
using
vector_t
=
typename
Traits
::
vector_t
;
using
SFC_Ys
=
typename
Traits
::
SFC_Ys
;
static
constexpr
index_t
YElementSize
=
TileDstr
{}.
get_ys_to_d_descriptor
().
get_element_space_size
();
static_assert
(
YElementSize
%
Traits
::
ScalarPerVector
==
0
);
using
vectorized_tbuf
=
array
<
vector_t
,
YElementSize
/
Traits
::
ScalarPerVector
>
;
// StaticBuffer<address_space_enum::vgpr,
// vector_t,
// YElementSize / Traits::ScalarPerVector,
// true>;
constexpr
auto
tile_dstr
=
TileDstr
{};
auto
&
dst_vec_tbuf
=
reinterpret_cast
<
vectorized_tbuf
&>
(
dst_tensor
.
get_thread_buffer
());
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
iAccess
);
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
static_assert
(
d
%
Traits
::
ScalarPerVector
==
0
);
get_bottom_tensor_view
().
template
get_vectorized_elements_raw
<
vector_t
>(
dst_vec_tbuf
.
template
at
<
d
/
Traits
::
ScalarPerVector
>(),
bottom_tensor_thread_coord
,
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
});
});
}
// TODO: currently async load only implemented in inline asm
template
<
typename
LdsTileWindow_
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
async_load
(
LdsTileWindow_
&&
lds_tile
,
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
LdsTileWindow
=
remove_cvref_t
<
LdsTileWindow_
>
;
// using LdsTensorView = typename LdsTileWindow::BottomTensorView;
using
LdsDataType
=
typename
LdsTileWindow
::
DataType
;
// using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
// issues * warps * lanes
static_assert
(
LdsTileWindow
::
get_num_of_dimension
()
==
3
);
// TODO: hard coded
const
index_t
size_per_buf
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
0
>
{},
number
<
0
>
{},
number
<
0
>
{}))
*
sizeof
(
LdsDataType
);
const
index_t
size_per_wave
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
0
>
{},
number
<
1
>
{},
number
<
0
>
{}))
*
sizeof
(
LdsDataType
)
-
size_per_buf
;
const
index_t
size_per_issue
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
1
>
{},
number
<
0
>
{},
number
<
0
>
{}))
*
sizeof
(
LdsDataType
)
-
size_per_buf
;
const
index_t
m0_init_value
=
size_per_buf
+
size_per_wave
*
get_warp_id
();
m0_set_with_memory
(
m0_init_value
);
// This should be wave independent
using
Traits
=
load_store_traits
;
// using vector_type_t = typename Traits::vector_type_t;
using
vector_t
=
typename
Traits
::
vector_t
;
using
SFC_Ys
=
typename
Traits
::
SFC_Ys
;
LdsDataType
*
smem
=
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
;
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// read from bottom tensor
get_bottom_tensor_view
().
template
async_get_vectorized_elements
<
vector_t
>(
smem
,
bottom_tensor_thread_coord
);
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
m0_inc_with_memory
(
size_per_issue
);
}
});
});
}
template
<
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
Traits
=
load_store_traits
;
// using vector_type_t = typename Traits::vector_type_t;
using
vector_t
=
typename
Traits
::
vector_t
;
using
SFC_Ys
=
typename
Traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
iAccess
);
// read from distributed tensor
// vector_type_t vec;
vector_t
vec_value
;
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_array
(
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
});
// const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
// write into bottom tensor
get_bottom_tensor_view
().
template
set_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
});
});
}
    CK_TILE_DEVICE void
    store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
    {
        using Traits = load_store_traits;

        using vector_t = typename Traits::vector_t;
        using SFC_Ys   = typename Traits::SFC_Ys;

        constexpr auto tile_dstr                    = TileDstr{};
        static constexpr bool oob_conditional_check = true;

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

                // read from distributed tensor
                vector_t vec_value;

                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_array(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<NDimY>{});

                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                    vec_value.template get_as<DataType>()(j) =
                        dstr_tensor.get_thread_buffer().template at<d>();
                });

                // write into bottom tensor
                get_bottom_tensor_view()
                    .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
                        bottom_tensor_thread_coord, vec_value);

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys =
                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }
    // move thread's bottom tensor coordinate
    // [x0', x1', ... ] ==> [offset]
    // also move window-origin
    CK_TILE_DEVICE void move(const BottomTensorIndex& step)
    {
        window_origin_ += step;

        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                                   pre_computed_coords_(iCoord)(I1),
                                   step);
        });
    }
    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;

    //
    WindowLengths window_lengths_;

    // origin ([x0', x1', ...]) of window on bottom tensor
    BottomTensorIndex window_origin_;

    // Tile tensor distribution, which contains:
    //   1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
    //   2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
    TileDstr tile_dstr_;

    // this contains:
    //   per-thread coordinate for window adaptor
    //   per-thread coordinate for bottom tensor
    array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
};
// TODO: use strategy
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord = 1>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
                 const WindowLengths_& window_lengths,
                 const multi_index<TensorView_::get_num_of_dimension()>& origin,
                 const StaticTileDistribution_& tile_distribution,
                 number<NumCoord> = {})
{
    return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
                                                remove_cvref_t<WindowLengths_>,
                                                remove_cvref_t<StaticTileDistribution_>,
                                                NumCoord>{
        tensor_view, window_lengths, origin, tile_distribution};
}
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          index_t NumCoord>
CK_TILE_DEVICE void
move_tile_window(tile_window_with_static_distribution<TensorView_,
                                                      WindowLengths_,
                                                      StaticTileDistribution_,
                                                      NumCoord>& window,
                 const typename tile_window_with_static_distribution<
                     TensorView_,
                     WindowLengths_,
                     StaticTileDistribution_,
                     NumCoord>::BottomTensorIndex& step)
{
    window.move(step);
}
template <typename BottomTensorView_, typename WindowLengths_>
struct tile_window_with_static_lengths
{
    using BottomTensorView = remove_reference_t<BottomTensorView_>;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;
    using BottomTensorDesc = typename BottomTensorView::TensorDesc;
    using DataType         = typename BottomTensorView::DataType;

    static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();

    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    using BottomTensorIndex = array<index_t, NDimBottomTensor>;

    CK_TILE_DEVICE constexpr tile_window_with_static_lengths() = default;

    CK_TILE_DEVICE constexpr tile_window_with_static_lengths(
        const BottomTensorView& bottom_tensor_view,
        const WindowLengths& window_lengths,
        const BottomTensorIndex& window_origin)
        : bottom_tensor_view_{bottom_tensor_view},
          window_lengths_{window_lengths},
          window_origin_{window_origin}
    {
    }

    CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }

    // move window-origin
    CK_TILE_DEVICE void move(const BottomTensorIndex& step) { window_origin_ += step; }

    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;

    //
    WindowLengths window_lengths_;

    // origin ([x0', x1', ...]) of window on bottom tensor
    BottomTensorIndex window_origin_;
};
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE constexpr auto
make_tile_window(const TensorView_& tensor_view,
                 const WindowLengths_& window_lengths,
                 const multi_index<TensorView_::get_num_of_dimension()>& origin)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths_>::value,
                  "wrong! lengths should be static");

    return tile_window_with_static_lengths<remove_cvref_t<TensorView_>,
                                           remove_cvref_t<WindowLengths_>>{
        tensor_view, window_lengths, origin};
}
template <typename TensorView_, typename WindowLengths_>
CK_TILE_DEVICE void
move_tile_window(tile_window_with_static_lengths<TensorView_, WindowLengths_>& window,
                 const typename tile_window_with_static_lengths<
                     TensorView_,
                     WindowLengths_>::BottomTensorIndex& step)
{
    window.move(step);
}

} // namespace ck_tile
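For orientation, a minimal device-side usage sketch of the two helpers above (not part of this diff; `TensorView`, `TileDistr`, the 64x64 tile shape, and `sweep_columns_example` are hypothetical stand-ins for whatever a calling kernel defines):

template <typename TensorView, typename TileDistr>
CK_TILE_DEVICE void sweep_columns_example(const TensorView& view, int num_col_tiles)
{
    // compile-time window lengths, as required by the static_assert above
    constexpr auto lengths = ck_tile::make_tuple(ck_tile::number<64>{}, ck_tile::number<64>{});

    auto window = ck_tile::make_tile_window(view, lengths, {0, 0}, TileDistr{});

    for(int i = 0; i < num_col_tiles; ++i)
    {
        // ... load/compute/store through `window` here ...
        ck_tile::move_tile_window(window, {0, 64}); // advance the origin by 64 columns
    }
}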
include/ck_tile/core/utility/bit_cast.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"

namespace ck_tile {

template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X& x)
{
    static_assert(__has_builtin(__builtin_bit_cast), "");
    static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");

    return __builtin_bit_cast(Y, x);
}

} // namespace ck_tile
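A minimal usage sketch (not part of this diff), reinterpreting a float's bit pattern; `float_bits` is a hypothetical helper name:

CK_TILE_HOST_DEVICE uint32_t float_bits(float f)
{
    // same size (4 bytes), so both static_asserts in bit_cast hold
    return ck_tile::bit_cast<uint32_t>(f);
}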
include/ck_tile/core/utility/functional.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/container/sequence.hpp"

#include <stdint.h>
#include <utility>

namespace ck_tile {

namespace detail {

struct swallow
{
    template <typename... Ts>
    CK_TILE_HOST_DEVICE constexpr swallow(Ts&&...)
    {
    }
};

template <class>
struct static_for_impl;

template <index_t... Is>
struct static_for_impl<sequence<Is...>>
{
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        swallow{(f(number<Is>{}), 0)...};
    }
};

} // namespace detail

// F signature: F(number<Iter>)
template <index_t NBegin, index_t NEnd, index_t Increment>
struct static_for
{
    CK_TILE_HOST_DEVICE constexpr static_for()
    {
        static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
                      "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
        static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
                      "wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
                      "NBegin >= NEnd)");
    }

    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        detail::static_for_impl<typename arithmetic_sequence_gen<NBegin, NEnd, Increment>::type>{}(
            f);
    }
};

struct identity
{
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& arg) const noexcept
    {
        return std::forward<T>(arg);
    }
};

namespace detail {

// RemainLengths: sequence<...>
// Orders: sequence<...>
template <class RemainLengths, class Orders>
struct static_ford_impl
{
    CK_TILE_HOST_DEVICE constexpr static_ford_impl()
    {
        static_assert(RemainLengths::size() > 0, "wrong! should not get here");
    }

    // F signature: F(sequence<...>)
    // CurrentOrderedId: sequence<...>
    template <class F, class CurrentOrderedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
    {
        static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
            static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
                f, CurrentOrderedId::push_back(I));
        });
    }
};

template <class Orders>
struct static_ford_impl<sequence<>, Orders>
{
    // F signature: F(sequence<...>)
    // OrderedId: sequence<...>
    template <class F, class OrderedId>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
    {
        // retrieve unordered Id
        f(OrderedId::reorder_old_to_new(Orders{}));
    }
};

} // namespace detail

// Lengths is sequence<...>, it is the length of each dimension for N-dimensional loop
// Orders is sequence<...>, it is the order of dimension in which static_ford will loop over each
// dimension
template <class Lengths,
          class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
struct static_ford
{
    CK_TILE_HOST_DEVICE constexpr static_ford()
    {
        static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
        static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
    }

    // F signature: F(sequence<...> multi_id)
    // multi_id is the unordered multi-index
    template <class F>
    CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
    {
        constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
        detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
    }
};

namespace detail {

template <typename Indices>
struct unpack_impl;

template <index_t... Is>
struct unpack_impl<sequence<Is...>>
{
    template <typename F, typename X>
    CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
    {
#if 0
        return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
#else
        return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
#endif
    }
};

template <typename Seq0, typename Seq1>
struct unpack2_impl;

// TODO: remove this, after properly implementing unpack that takes any number of containers
template <index_t... Is, index_t... Js>
struct unpack2_impl<sequence<Is...>, sequence<Js...>>
{
    template <typename F, typename X, typename Y>
    CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
    {
#if 0
        return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
                                  std::forward<Y>(y).at(number<Js>{})...);
#else
        return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
                                  std::forward<Y>(y).template at<Js>()...);
#endif
    }
};

} // namespace detail

template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
{
    using X_ = remove_reference_t<X>;
    return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
        std::forward<F>(f), std::forward<X>(x));
}

// TODO: properly implement unpack that takes any number of containers
template <typename F, typename X, typename Y>
CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
{
    using X_ = remove_reference_t<X>;
    using Y_ = remove_reference_t<Y>;
    return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
                                typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
        std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
}

// z = predicate ? x : y
template <bool predicate, typename X, typename Y>
constexpr auto conditional_expr(X&& x, Y&& y)
{
    if constexpr(predicate)
    {
        return std::forward<X>(x);
    }
    else
    {
        return std::forward<Y>(y);
    }
}

} // namespace ck_tile
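A minimal usage sketch (not part of this diff) of the compile-time loop: each iteration receives a `number<I>` whose `value` is a constant expression; `sum_0_to_3` is a hypothetical name:

CK_TILE_HOST_DEVICE constexpr int sum_0_to_3()
{
    int acc = 0;
    // body runs for I = number<0>, number<1>, number<2>, number<3>
    ck_tile::static_for<0, 4, 1>{}([&](auto I) { acc += I.value; });
    return acc; // 6
}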
include/ck_tile/core/utility/ignore.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// https://en.cppreference.com/w/cpp/utility/tuple/ignore

namespace ck_tile {
namespace detail {

struct ignore_t
{
    template <typename T>
    constexpr void operator=(T&&) const noexcept
    {
    }
};

} // namespace detail

inline constexpr detail::ignore_t ignore;

} // namespace ck_tile
include/ck_tile/core/utility/magic_div.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>

namespace ck_tile {

// magic number division
// Caution:
//   1. For uint32_t as dividend: the magic number division implementation being used produces a
//   correct result if the dividend is uint32_t and its value is within the 31-bit value range.
//   2. For int32_t as dividend: magic number division for an int32_t dividend has not been
//   implemented; the int32_t dividend is bit-wise interpreted as uint32_t and the magic number
//   division implementation for uint32_t is then used. Therefore, the dividend value needs to be
//   non-negative.
// TODO:
//   1. Implement magic number division for int32_t
//   2. Implement magic number division for uint32_t with 32-bit value range
struct magic_division32_bit_range
{
    // uint32_t
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
    {
        // WARNING: magic division is only valid for division inside this range.
        // assert(divisor >= 1 && divisor <= INT32_MAX)
        uint32_t shift_u32 = 0;

        while((1U << shift_u32) < divisor)
        {
            shift_u32++;
        };

        uint64_t tmp_u64        = ((1UL << shift_u32) - divisor) << 32;
        uint32_t multiplier_u32 = tmp_u64 / divisor + 1;

        return make_tuple(multiplier_u32, shift_u32);
    }

    template <auto Divisor, typename = std::enable_if_t<(0 < Divisor)>>
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
    {
        constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});

        constexpr uint32_t multiplier = tmp[number<0>{}];
        constexpr uint32_t shift      = tmp[number<1>{}];

        return make_tuple(constant<multiplier>{}, constant<shift>{});
    }

    // magic division for uint32_t
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = __umulhi(dividend, multiplier);
        return (tmp + dividend) >> shift;
    }

    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (static_cast<uint64_t>(dividend) * multiplier) >> 32;
        return (tmp + dividend) >> shift;
    }

    // magic division for int32_t
    // HACK: use dividend_i32 as if it's uint32_t; dividend_i32 needs to be
    // non-negative for the result to be correct
    // TODO: figure out how to do magic number division for int32_t as dividend
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = __umulhi(dividend_u32, multiplier);
        return (tmp + dividend_u32) >> shift;
    }

    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (static_cast<uint64_t>(dividend_u32) * multiplier) >> 32;
        return (tmp + dividend_u32) >> shift;
    }
};

// magic number division
// This version only works for divisor and dividend between [0, 1 << 16]
struct magic_division16_bit_range
{
    // uint32_t
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(uint32_t divisor)
    {
        // WARNING: magic division is only valid for division inside this range.
        // assert(divisor >= 1 && divisor <= (1U << 16));
        uint32_t shift_u32 = 0;

        while((1U << shift_u32) < divisor)
        {
            shift_u32++;
        };

        uint32_t one            = 1;
        uint32_t multiplier_u32 = ((one << 16) * ((one << shift_u32) - divisor)) / divisor + 1;

        return make_tuple(multiplier_u32, shift_u32);
    }

    // integral_constant<uint32_t, .>
    template <auto Divisor>
    CK_TILE_HOST_DEVICE static constexpr auto calculate_magic_numbers(constant<Divisor>)
    {
        constexpr auto tmp = calculate_magic_numbers(uint32_t{Divisor});

        constexpr uint32_t multiplier = tmp[number<0>{}];
        constexpr uint32_t shift      = tmp[number<1>{}];

        return make_tuple(constant<multiplier>{}, constant<shift>{});
    }

    // magic division for uint32_t
    CK_TILE_DEVICE static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (dividend * multiplier) >> 16;
        return (tmp + dividend) >> shift;
    }

    CK_TILE_HOST static constexpr uint32_t
    do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
    {
        uint32_t tmp = (dividend * multiplier) >> 16;
        return (tmp + dividend) >> shift;
    }

    // magic division for int32_t
    // HACK: use dividend_i32 as if it's uint32_t; dividend_i32 needs to be
    // non-negative for the result to be correct
    // TODO: figure out how to do magic number division for int32_t as dividend
    CK_TILE_DEVICE static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (dividend_u32 * multiplier) >> 16;
        return (tmp + dividend_u32) >> shift;
    }

    CK_TILE_HOST static constexpr int32_t
    do_magic_division(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = bit_cast<uint32_t>(dividend_i32);
        uint32_t tmp          = (dividend_u32 * multiplier) >> 16;
        return (tmp + dividend_u32) >> shift;
    }
};

// use 32bit version
using magic_division = magic_division32_bit_range;

struct mdiv
{
    // 1 dword -> 3 dword storage
    uint32_t divisor;
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer construct on host
    CK_TILE_HOST_DEVICE mdiv(uint32_t divisor_) : divisor(divisor_)
    {
        auto tmp = magic_division::calculate_magic_numbers(divisor_);

        multiplier = tmp[number<0>{}];
        shift      = tmp[number<1>{}];
    }

    CK_TILE_HOST_DEVICE mdiv() : divisor(0), multiplier(0), shift(0) {}

    CK_TILE_HOST_DEVICE void update(uint32_t divisor_)
    {
        divisor  = divisor_;
        auto tmp = magic_division::calculate_magic_numbers(divisor_);

        multiplier = tmp[number<0>{}];
        shift      = tmp[number<1>{}];
    }

    CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
    {
        return magic_division::do_magic_division(dividend_, multiplier, shift);
    }

    CK_TILE_HOST_DEVICE void
    divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor);
    }

    CK_TILE_HOST_DEVICE uint32_t get() const { return divisor; }
};

struct mdiv2
{
    // 1 dword -> 2 dword storage; the divisor needs to be provided at runtime
    uint32_t multiplier;
    uint32_t shift; // TODO: 8 bit is enough

    // prefer construct on host
    CK_TILE_HOST_DEVICE mdiv2(uint32_t divisor_)
    {
        auto tmp = magic_division::calculate_magic_numbers(divisor_);

        multiplier = tmp[number<0>{}];
        shift      = tmp[number<1>{}];
    }

    CK_TILE_HOST_DEVICE mdiv2() : multiplier(0), shift(0) {}

    CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
    {
        return magic_division::do_magic_division(dividend_, multiplier, shift);
    }

    CK_TILE_HOST_DEVICE void divmod(uint32_t dividend_,
                                    uint32_t divisor_,
                                    uint32_t& quotient_,
                                    uint32_t& remainder_) const
    {
        quotient_  = div(dividend_);
        remainder_ = dividend_ - (quotient_ * divisor_);
    }
};

} // namespace ck_tile
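A minimal usage sketch (not part of this diff): precompute the magic numbers once, then divide repeatedly without hardware integer division; `mdiv_example` is a hypothetical name:

void mdiv_example()
{
    ck_tile::mdiv m(7); // divisor fixed up front; multiplier/shift computed here
    uint32_t q, r;
    m.divmod(30, q, r); // q == 4, r == 2 (30 is well within the 31-bit range)
}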
include/ck_tile/core/utility/random.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include <stdint.h>
#include <tuple>
#include <type_traits>

namespace ck_tile {

// return 0 if data is not fp16 or fp32
template <typename T, uint32_t seed_>
struct prand_generator_t
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int, T, uint32_t = seed_) { return 0; }
};

// version for fp32
template <uint32_t seed_>
struct prand_generator_t<float, seed_>
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_)
    {
        uint32_t x         = *(reinterpret_cast<uint32_t*>(&val));
        uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
        drop_bits ^= x >> 16;
        drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5);
        drop_bits *= 0x7000149;

        // NOTE: If id is in 64 bit, we are only using lower 32 bit.
        //       So, it can have an effect of using same id for multiple elements when the id is
        //       very large!
        uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
        return rng;
    }
};

// version for fp16
template <uint32_t seed_>
struct prand_generator_t<half_t, seed_>
{
    CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_)
    {
        uint16_t x         = *(reinterpret_cast<uint16_t*>(&val));
        uint32_t drop_bits = uint32_t(x) & 0xFFFFu;
        drop_bits          = ((drop_bits & 31) << 11) | (drop_bits >> 5);
        drop_bits *= 0x7000149;

        // NOTE: If id is in 64 bit, we are only using lower 32 bit.
        //       So, it can have an effect of using same id for multiple elements when the id is
        //       very large!
        uint32_t rng = (drop_bits ^ 0x13371337 ^ (id * 229791) ^ seed);
        return rng;
    }
};

} // namespace ck_tile
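A minimal usage sketch (not part of this diff): the generator hashes the element id, the value's bits, and a compile-time seed into deterministic per-element random bits; `rng_bits_for` is a hypothetical name:

CK_TILE_HOST_DEVICE uint32_t rng_bits_for(int element_id, float value)
{
    ck_tile::prand_generator_t<float, 0xCAFEu> prand;
    return prand(element_id, value); // same inputs always yield the same 32 bits
}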
include/ck_tile/core/utility/to_sequence.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/container/sequence.hpp"
// TODO: use c++20 nontype template with struct to implement this
#if 1
// clang happens to support this feature (__cpp_generic_lambdas >= 201707) in c++17 mode
#define TO_SEQUENCE(a, n) \
_Pragma("clang diagnostic push") _Pragma( \
"clang diagnostic ignored \"-Wc++20-extensions\"")[a]<ck_tile::index_t... IDX_IDX_>( \
ck_tile::sequence<IDX_IDX_...>) \
{ \
return ck_tile::sequence<a.at(ck_tile::number<IDX_IDX_>{})...>{}; \
} \
(ck_tile::make_index_sequence<n>{}); \
_Pragma("clang diagnostic pop")
#else
// Macro function
// convert constexpr array to sequence, both a/n need to be constexpr (can't be a rvalue like 2)
#define TO_SEQUENCE(a, n) \
[a, n] { \
static_assert(a.size() >= n, "wrong! out of bound"); \
static_assert(n <= 10, "not implemented"); \
if constexpr(n == 0) \
{ \
return ck_tile::sequence<>{}; \
} \
else if constexpr(n == 1) \
{ \
return ck_tile::sequence<a[0]>{}; \
} \
else if constexpr(n == 2) \
{ \
return ck_tile::sequence<a[0], a[1]>{}; \
} \
else if constexpr(n == 3) \
{ \
return ck_tile::sequence<a[0], a[1], a[2]>{}; \
} \
else if constexpr(n == 4) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3]>{}; \
} \
else if constexpr(n == 5) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4]>{}; \
} \
else if constexpr(n == 6) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5]>{}; \
} \
else if constexpr(n == 7) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6]>{}; \
} \
else if constexpr(n == 8) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7]>{}; \
} \
else if constexpr(n == 9) \
{ \
return ck_tile::sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8]>{}; \
} \
else if constexpr(n == 10) \
{ \
return ck_tile:: \
sequence<a[0], a[1], a[2], a[3], a[4], a[5], a[6], a[7], a[8], a[9]>{}; \
} \
}()
#endif
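A minimal usage sketch (not part of this diff): the macro converts a constexpr array into a `ck_tile::sequence` type; `make_seq_example` is a hypothetical name, and the exact constexpr-capture rules depend on the compiler (the clang path above is the one exercised here):

constexpr auto make_seq_example()
{
    constexpr ck_tile::array<ck_tile::index_t, 3> lens{2, 4, 8};
    return TO_SEQUENCE(lens, 3); // yields ck_tile::sequence<2, 4, 8>{}
}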
include/ck_tile/core/utility/transpose_vectors.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"

namespace ck_tile {

// S: scalar type (or it can be non-scalar type)
// NX: # of vector before transpose
// NY: # of vector after transpose
// we got [NX, NY] amount of S data to be transposed into [NY, NX] amount of S data
template <typename S_, index_t NX, index_t NY>
struct transpose_vectors
{
    static constexpr index_t s_per_x = NY;
    static constexpr index_t s_per_y = NX;

    using S  = remove_cvref_t<S_>;
    using VX = array<S, s_per_x>;
    using VY = array<S, s_per_y>;

    CK_TILE_DEVICE void operator()(const thread_buffer<VX, NX>& vx_tuple,
                                   thread_buffer<VY, NY>& vy_tuple)
    {
        constexpr auto I1 = number<1>{};
        constexpr auto I2 = number<2>{};
        constexpr auto I3 = number<3>{};
        constexpr auto I4 = number<4>{};

        if constexpr(sizeof(S) == 2)
        {
            static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");

            using S2 = array<S, 2>; // typename array<S, 2>::type;

            // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
            static_for<0, NY, 2>{}([&](auto iy) {
                static_for<0, NX, 2>{}([&](auto ix) {
                    // 2 16bitx2 data from vx_tuple to be transposed
                    const int32_t x_s2_0 =
                        bit_cast<int32_t>(vx_tuple[ix].template get_as<S2>()[iy / I2]);
                    const int32_t x_s2_1 =
                        bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S2>()[iy / I2]);

                    constexpr int32_t m0 = 0x05040100;
                    constexpr int32_t m1 = 0x07060302;

                    // transpose 2x2 16bit
                    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
                    //                   -- -- -- --     -- -- -- --     -  -  -  -
                    //             index  7  6  5  4      3  2  1  0    33 77 44 88
                    // index is reversed because of little endianness (least significant bits first)
                    const int32_t y_s2_0 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m0);
                    const int32_t y_s2_1 = __builtin_amdgcn_perm(x_s2_1, x_s2_0, m1);

                    // 2 16bitx2 data after transposed
                    vy_tuple(iy).template get_as<S2>()(ix / I2)      = bit_cast<S2>(y_s2_0);
                    vy_tuple(iy + I1).template get_as<S2>()(ix / I2) = bit_cast<S2>(y_s2_1);
                });
            });
        }
        else if constexpr(sizeof(S) == 1)
        {
            static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");

            using S4 = array<S, 4>; // typename array<S, 4>::type;

            // loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
            static_for<0, NY, 4>{}([&](auto iy) {
                static_for<0, NX, 4>{}([&](auto ix) {
                    // 4 int8x4 data from vx_tuple
                    const int32_t x_s4_0 =
                        bit_cast<int32_t>(vx_tuple[ix].template get_as<S4>()[iy / I4]);
                    const int32_t x_s4_1 =
                        bit_cast<int32_t>(vx_tuple[ix + I1].template get_as<S4>()[iy / I4]);
                    const int32_t x_s4_2 =
                        bit_cast<int32_t>(vx_tuple[ix + I2].template get_as<S4>()[iy / I4]);
                    const int32_t x_s4_3 =
                        bit_cast<int32_t>(vx_tuple[ix + I3].template get_as<S4>()[iy / I4]);

                    // transpose
                    int32_t t_s4_0, t_s4_1;
                    int32_t y_s4_0, y_s4_1, y_s4_2, y_s4_3;

                    constexpr int32_t m0 = 0x05010400;
                    constexpr int32_t m1 = 0x05040100;
                    constexpr int32_t m2 = 0x07060302;
                    constexpr int32_t m3 = 0x07030602;

                    // ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
                    //                   -- -- -- --     -- -- -- --     -  -  -  -
                    //             index  7  6  5  4      3  2  1  0    33 77 44 88
                    // index is reversed because of little endianness (least significant bits first)
                    t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m0);
                    t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m0);
                    y_s4_0 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
                    y_s4_1 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);

                    t_s4_0 = __builtin_amdgcn_perm(x_s4_1, x_s4_0, m3);
                    t_s4_1 = __builtin_amdgcn_perm(x_s4_3, x_s4_2, m3);
                    y_s4_2 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m1);
                    y_s4_3 = __builtin_amdgcn_perm(t_s4_1, t_s4_0, m2);

                    // 4 int8x4 data from vy_tuple
                    vy_tuple(iy).template get_as<S4>()(ix / I4)      = bit_cast<S4>(y_s4_0);
                    vy_tuple(iy + I1).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_1);
                    vy_tuple(iy + I2).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_2);
                    vy_tuple(iy + I3).template get_as<S4>()(ix / I4) = bit_cast<S4>(y_s4_3);
                });
            });
        }
        else
        {
            static_assert(false, "not implemented");
        }
    }
};

} // namespace ck_tile
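To make the byte-selector constants above concrete, here is a hedged host-side emulation (not part of this diff) of what the `v_perm_b32` builtin computes, under the assumption that the second operand supplies bytes 0..3 and the first operand bytes 4..7 of the little-endian source pair:

#include <cstdint>

static uint32_t perm_emulate(uint32_t hi, uint32_t lo, uint32_t sel)
{
    uint64_t bytes = (uint64_t(hi) << 32) | lo; // bytes 0..3 = lo, bytes 4..7 = hi
    uint32_t out   = 0;
    for(int i = 0; i < 4; ++i)
    {
        uint32_t s = (sel >> (8 * i)) & 0xFF; // selector for output byte i
        out |= uint32_t((bytes >> (8 * s)) & 0xFF) << (8 * i);
    }
    return out;
}
// e.g. sel = 0x05040100 interleaves the low 16-bit halves of lo and hi,
// which is exactly the 2x2 16-bit transpose step performed above.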
include/ck_tile/core/utility/type_traits.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include <type_traits>
#include <stdint.h>

namespace ck_tile {

// remove_cvref_t
template <typename T>
using remove_reference_t = typename std::remove_reference<T>::type;

template <typename T>
using remove_cv_t = typename std::remove_cv<T>::type;

template <typename T>
using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;

template <typename T>
using remove_pointer_t = typename std::remove_pointer<T>::type;

namespace detail {
template <class Default, class AlwaysVoid, template <class...> class Op, class... Args>
struct detector
{
    using value_t = std::false_type;
    using type    = Default;
};

template <class Default, template <class...> class Op, class... Args>
struct detector<Default, std::void_t<Op<Args...>>, Op, Args...>
{
    using value_t = std::true_type;
    using type    = Op<Args...>;
};
} // namespace detail

struct nonesuch
{
    ~nonesuch()                     = delete;
    nonesuch(nonesuch const&)       = delete;
    void operator=(nonesuch const&) = delete;
};

template <template <class...> class Op, class... Args>
using is_detected = typename detail::detector<nonesuch, void, Op, Args...>::value_t;

namespace impl {
template <typename T>
using has_is_static = decltype(T::is_static());

template <typename T>
struct is_static_impl
{
    static constexpr bool value = []() {
        if constexpr(is_detected<has_is_static, T>{})
            return T::is_static();
        else
            return std::is_arithmetic<T>::value;
    }();
};
} // namespace impl

template <typename T>
using is_static = impl::is_static_impl<remove_cvref_t<T>>;

template <typename T>
inline constexpr bool is_static_v = is_static<T>::value;

// TODO: deprecate this
template <typename T>
using is_known_at_compile_time = is_static<T>;

// TODO: if evaluating a rvalue, e.g. a const integer
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?
// FIXME: do we need this anymore?
template <typename PY,
          typename PX,
          typename std::enable_if<std::is_pointer_v<PY> && std::is_pointer_v<PX>, bool>::type =
              false>
CK_TILE_HOST_DEVICE PY c_style_pointer_cast(PX p_x)
{
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
#pragma clang diagnostic ignored "-Wcast-align"
    return (PY)p_x; // NOLINT(old-style-cast, cast-align)
#pragma clang diagnostic pop
}

} // namespace ck_tile
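A minimal usage sketch (not part of this diff) of the detection idiom above; `has_size_op` is a hypothetical detector alias:

#include <utility>
#include <vector>

template <typename T>
using has_size_op = decltype(std::declval<T&>().size());

static_assert(ck_tile::is_detected<has_size_op, std::vector<int>>::value,
              "std::vector exposes size()");
static_assert(!ck_tile::is_detected<has_size_op, int>::value, "int does not");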
include/ck_tile/core/utility/unary_element_function.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename F, typename... Fs>
struct composes : private composes<F>
{
    template <typename FirstArg, typename... RestArgs>
    CK_TILE_HOST_DEVICE constexpr explicit composes(FirstArg&& firstArg, RestArgs&&... restArgs)
        : composes<F>(std::forward<FirstArg>(firstArg)),
          inner_(std::forward<RestArgs>(restArgs)...)
    {
    }

    template <typename Arg>
    CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
    {
        return static_cast<const composes<F>&>(*this)(inner_(std::forward<Arg>(arg)));
    }

    private:
    composes<Fs...> inner_;
};

template <typename F>
struct composes<F>
{
    static_assert(!std::is_reference_v<F>);

    template <typename Arg, typename = std::enable_if_t<std::is_constructible_v<F, Arg>>>
    CK_TILE_HOST_DEVICE constexpr explicit composes(Arg&& arg) : f_(std::forward<Arg>(arg))
    {
    }

    template <typename Arg,
              typename = std::enable_if_t<std::is_invocable_v<std::add_const_t<F>&, Arg>>>
    CK_TILE_HOST_DEVICE constexpr auto operator()(Arg&& arg) const
    {
        return f_(std::forward<Arg>(arg));
    }

    private:
    F f_;
};

/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename... Ts>
__host__ __device__ composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;

template <typename To>
struct saturates
{
    template <typename From>
    CK_TILE_HOST_DEVICE constexpr auto operator()(const From& from) const
        -> std::enable_if_t<std::is_arithmetic_v<From>, From>
    {
        return clamp(from,
                     type_convert<From>(numeric<To>::lowest()),
                     type_convert<From>(numeric<To>::max()));
    }
};

} // namespace ck_tile
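A minimal usage sketch (not part of this diff): `composes{f, g}` applies the last constructor argument first, so `composes{add1{}, dbl{}}(x)` computes `add1(dbl(x))`; `add1`, `dbl`, and `composed_example` are hypothetical names:

struct add1 { CK_TILE_HOST_DEVICE constexpr int operator()(int x) const { return x + 1; } };
struct dbl  { CK_TILE_HOST_DEVICE constexpr int operator()(int x) const { return 2 * x; } };

CK_TILE_HOST_DEVICE constexpr int composed_example(int x)
{
    return ck_tile::composes{add1{}, dbl{}}(x); // for x = 3: dbl -> 6, add1 -> 7
}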
include/ck_tile/host.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp"
include/ck_tile/host/arg_parser.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <string>
#include <iomanip>
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <unordered_map>
#include <vector>

namespace ck_tile {
/*
 * a host side utility, arg parser for
 * -[key0]=[value0] -[key1]=[value1] ...
 */
class ArgParser
{
    public:
    class Arg
    {
        public:
        std::string name;
        std::string value;
        std::string help_text;
    };

    ArgParser() {}

    ArgParser& insert(const std::string& _name,
                      const std::string& _default_value,
                      const std::string& _help_text)
    {
        Arg in;
        in.name      = _name;
        in.value     = _default_value;
        in.help_text = _help_text;

        if(input_map.count(_name) != 0)
        {
            printf("arg:%s already exist\n", _name.c_str());
        }
        else
        {
            input_map[_name] = in;
            keys.push_back(_name);
        }
        return *this;
    }

    void print()
    {
        printf("args:\n");
        for(auto& key : keys)
        {
            auto value = input_map[key];

            std::vector<std::string> help_text_lines;
            size_t pos = 0;
            for(size_t next_pos = value.help_text.find('\n', pos); next_pos != std::string::npos;)
            {
                help_text_lines.push_back(std::string(value.help_text.begin() + pos,
                                                      value.help_text.begin() + next_pos++));
                pos      = next_pos;
                next_pos = value.help_text.find('\n', pos);
            }
            help_text_lines.push_back(
                std::string(value.help_text.begin() + pos, value.help_text.end()));

            std::string default_value = std::string("(default:") + value.value + std::string(")");

            std::cout << std::setw(2) << std::setw(12 - value.name.length()) << "-" << key
                      << std::setw(4) << " " << help_text_lines[0] << " " << default_value
                      << std::endl;

            for(auto help_next_line = std::next(help_text_lines.begin());
                help_next_line != help_text_lines.end();
                ++help_next_line)
            {
                std::cout << std::setw(17) << " " << *help_next_line << std::endl;
            }
        }
    }

    bool parse(int argc, char* argv[], int start_index = 1)
    {
        if(argc < start_index)
        {
            printf("not enough args\n");
            return false;
        }
        for(int i = start_index; i < argc; i++)
        {
            char* cur_arg = argv[i];
            if(cur_arg[0] != '-')
            {
                printf("illegal input\n");
                print();
                return false;
            }
            else
            {
                std::string text(cur_arg + 1);
                if(text == "?")
                {
                    print();
                    return false;
                }
                auto pos = text.find('=');
                if(pos == std::string::npos)
                {
                    printf("arg should be [key]=[value] pair, here:%s\n", text.c_str());
                    return false;
                }
                if(pos >= (text.size() - 1))
                {
                    printf("cant find value after \"=\", here:%s\n", text.c_str());
                    return false;
                }
                auto key   = text.substr(0, pos);
                auto value = text.substr(pos + 1);
                if(input_map.count(key) == 0)
                {
                    printf("no such arg:%s\n", key.c_str());
                    return false;
                }
                input_map[key].value = value;
            }
        }
        return true;
    }

    std::string get_str(const std::string& name) const
    {
        std::string value = input_map.at(name).value;
        return value;
    }

    int get_int(const std::string& name) const
    {
        int value = atoi(input_map.at(name).value.c_str());
        return value;
    }

    uint32_t get_uint32(const std::string& name) const
    {
        uint32_t value = strtoul(input_map.at(name).value.c_str(), nullptr, 10);
        return value;
    }

    uint64_t get_uint64(const std::string& name) const
    {
        uint64_t value = strtoull(input_map.at(name).value.c_str(), nullptr, 10);
        return value;
    }

    bool get_bool(const std::string& name) const
    {
        auto v = input_map.at(name).value;
        if(v.compare("t") == 0 || v.compare("true") == 0)
            return true;
        if(v.compare("f") == 0 || v.compare("false") == 0)
            return false;
        int value = atoi(v.c_str());
        return value == 0 ? false : true;
    }

    float get_float(const std::string& name) const
    {
        double value = atof(input_map.at(name).value.c_str());
        return static_cast<float>(value);
    }

    double get_double(const std::string& name) const
    {
        double value = atof(input_map.at(name).value.c_str());
        return value;
    }

    private:
    std::unordered_map<std::string, Arg> input_map;
    std::vector<std::string> keys;
};

} // namespace ck_tile
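A minimal usage sketch (not part of this diff) of the `-key=value` convention the parser expects; the argument names are hypothetical:

int main(int argc, char* argv[])
{
    ck_tile::ArgParser parser;
    parser.insert("m", "1024", "rows of the problem")
          .insert("v", "1", "run validation (t/f or 0/1)");
    if(!parser.parse(argc, argv)) // e.g. invoked as: ./app -m=2048 -v=t
        return -1;

    int  m        = parser.get_int("m");
    bool validate = parser.get_bool("v");
    (void)m;
    (void)validate;
    return 0;
}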
include/ck_tile/host/check_err.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <iomanip>
#include <iterator>
#include <limits>
#include <type_traits>
#include <vector>

#include "ck_tile/core.hpp"
#include "ck_tile/host/ranges.hpp"

namespace ck_tile {

template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    using size_type = typename std::vector<T>::size_type;

    os << "[";
    for(size_type idx = 0; idx < v.size(); ++idx)
    {
        if(0 < idx)
        {
            os << ", ";
        }
        os << v[idx];
    }
    return os << "]";
}

template <typename Range, typename RefRange>
typename std::enable_if<
    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
        std::is_floating_point_v<ranges::range_value_t<Range>> &&
        !std::is_same_v<ranges::range_value_t<Range>, half_t>,
    bool>::type
CK_TILE_HOST check_err(const Range& out,
                       const RefRange& ref,
                       const std::string& msg = "Error: Incorrect results!",
                       double rtol            = 1e-5,
                       double atol            = 3e-6,
                       bool allow_infinity_ref = false)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    const auto is_infinity_error = [=](auto o, auto r) {
        const bool either_not_finite      = !std::isfinite(o) || !std::isfinite(r);
        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);

        return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
    };

    bool res{true};
    int err_count  = 0;
    double err     = 0;
    double max_err = std::numeric_limits<double>::min();

    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double o = *std::next(std::begin(out), i);
        const double r = *std::next(std::begin(ref), i);
        err            = std::abs(o - r);

        if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }
    }

    if(!res)
    {
        const float error_percent =
            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
        std::cerr << "max err: " << max_err;
        std::cerr << ", number of errors: " << err_count;
        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
    }

    return res;
}

template <typename Range, typename RefRange>
typename std::enable_if<
    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
        std::is_same_v<ranges::range_value_t<Range>, bf16_t>,
    bool>::type
CK_TILE_HOST check_err(const Range& out,
                       const RefRange& ref,
                       const std::string& msg = "Error: Incorrect results!",
                       double rtol            = 1e-3,
                       double atol            = 1e-3,
                       bool allow_infinity_ref = false)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    const auto is_infinity_error = [=](auto o, auto r) {
        const bool either_not_finite      = !std::isfinite(o) || !std::isfinite(r);
        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);

        return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
    };

    bool res{true};
    int err_count = 0;
    double err    = 0;
    // TODO: This is a hack. We should have proper specialization for bf16_t data type.
    double max_err = std::numeric_limits<float>::min();

    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
        err            = std::abs(o - r);

        if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }
    }

    if(!res)
    {
        const float error_percent =
            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
        std::cerr << "max err: " << max_err;
        std::cerr << ", number of errors: " << err_count;
        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
    }

    return res;
}

template <typename Range, typename RefRange>
typename std::enable_if<
    std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
        std::is_same_v<ranges::range_value_t<Range>, half_t>,
    bool>::type
CK_TILE_HOST check_err(const Range& out,
                       const RefRange& ref,
                       const std::string& msg = "Error: Incorrect results!",
                       double rtol            = 1e-3,
                       double atol            = 1e-3,
                       bool allow_infinity_ref = false)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    const auto is_infinity_error = [=](auto o, auto r) {
        const bool either_not_finite      = !std::isfinite(o) || !std::isfinite(r);
        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);

        return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
    };

    bool res{true};
    int err_count = 0;
    double err    = 0;
    double max_err =
        static_cast<double>(std::numeric_limits<ranges::range_value_t<Range>>::min());

    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
        err            = std::abs(o - r);

        if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }
    }

    if(!res)
    {
        const float error_percent =
            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
        std::cerr << "max err: " << max_err;
        std::cerr << ", number of errors: " << err_count;
        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
    }

    return res;
}

template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                  std::is_integral_v<ranges::range_value_t<Range>> &&
                  !std::is_same_v<ranges::range_value_t<Range>, bf16_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                     || std::is_same_v<ranges::range_value_t<Range>, int4_t>
#endif
                 ,
                 bool>
CK_TILE_HOST check_err(const Range& out,
                       const RefRange& ref,
                       const std::string& msg = "Error: Incorrect results!",
                       double                 = 0,
                       double atol            = 0)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    bool res{true};
    int err_count   = 0;
    int64_t err     = 0;
    int64_t max_err = std::numeric_limits<int64_t>::min();

    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const int64_t o = *std::next(std::begin(out), i);
        const int64_t r = *std::next(std::begin(ref), i);
        err             = std::abs(o - r);

        if(err > atol)
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r
                          << std::endl;
            }
            res = false;
        }
    }

    if(!res)
    {
        const float error_percent =
            static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
        std::cerr << "max err: " << max_err;
        std::cerr << ", number of errors: " << err_count;
        std::cerr << ", " << error_percent << "% wrong values" << std::endl;
    }

    return res;
}

template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                  std::is_same_v<ranges::range_value_t<Range>, fp8_t>),
                 bool>
CK_TILE_HOST check_err(const Range& out,
                       const RefRange& ref,
                       const std::string& msg                 = "Error: Incorrect results!",
                       unsigned max_rounding_point_distance   = 1,
                       double atol                            = 1e-1,
                       bool allow_infinity_ref                = false)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    const auto is_infinity_error = [=](auto o, auto r) {
        const bool either_not_finite      = !std::isfinite(o) || !std::isfinite(r);
        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);

        return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
    };

    static const auto get_rounding_point_distance = [](fp8_t o, fp8_t r) -> unsigned {
        static const auto get_sign_bit = [](fp8_t v) -> bool {
            return 0x80 & bit_cast<uint8_t>(v);
        };

        if(get_sign_bit(o) ^ get_sign_bit(r))
        {
            return std::numeric_limits<unsigned>::max();
        }
        else
        {
            return std::abs(bit_cast<int8_t>(o) - bit_cast<int8_t>(r));
        }
    };

    bool res{true};
    int err_count  = 0;
    double err     = 0;
    double max_err = std::numeric_limits<float>::min();

    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const fp8_t o_fp8 = *std::next(std::begin(out), i);
        const fp8_t r_fp8 = *std::next(std::begin(ref), i);

        const double o_fp64 = type_convert<float>(o_fp8);
        const double r_fp64 = type_convert<float>(r_fp8);

        err = std::abs(o_fp64 - r_fp64);

        if(!(less_equal<double>{}(err, atol) ||
             get_rounding_point_distance(o_fp8, r_fp8) <= max_rounding_point_distance) ||
           is_infinity_error(o_fp64, r_fp64))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl;
            }
            res = false;
        }
    }

    if(!res)
    {
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
    }

    return res;
}

template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                  std::is_same_v<ranges::range_value_t<Range>, bf8_t>),
                 bool>
CK_TILE_HOST check_err(const Range& out,
                       const RefRange& ref,
                       const std::string& msg = "Error: Incorrect results!",
                       double rtol            = 1e-3,
                       double atol            = 1e-3,
                       bool allow_infinity_ref = false)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    const auto is_infinity_error = [=](auto o, auto r) {
        const bool either_not_finite      = !std::isfinite(o) || !std::isfinite(r);
        const bool both_infinite_and_same = std::isinf(o) && std::isinf(r) && (o == r);

        return either_not_finite && !(allow_infinity_ref && both_infinite_and_same);
    };

    bool res{true};
    int err_count  = 0;
    double err     = 0;
    double max_err = std::numeric_limits<float>::min();

    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
        err            = std::abs(o - r);

        if(err > atol + rtol * std::abs(r) || is_infinity_error(o, r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }
    }

    if(!res)
    {
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
    }

    return res;
}

} // namespace ck_tile
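A minimal usage sketch (not part of this diff): comparing a device result against a host reference through the floating-point overload and its default tolerances; `validate` is a hypothetical name:

#include <vector>

bool validate(const std::vector<float>& out, const std::vector<float>& ref)
{
    // the float overload above applies rtol = 1e-5, atol = 3e-6 by default
    return ck_tile::check_err(out, ref, "GEMM verification failed:");
}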
include/ck_tile/host/device_memory.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <hip/hip_runtime.h>
#include <stdint.h>
#include <stdexcept>
#include "ck_tile/host/hip_check_error.hpp"

namespace ck_tile {

template <typename T>
__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
{
    for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x)
    {
        p[i] = x;
    }
}

/**
 * @brief Container for storing data in GPU device memory
 *
 */
struct DeviceMem
{
    DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {}
    DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
    {
        HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
    }
    void Realloc(std::size_t mem_size)
    {
        if(mpDeviceBuf)
        {
            HIP_CHECK_ERROR(hipFree(mpDeviceBuf));
        }
        mMemSize = mem_size;
        HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
    }
    void* GetDeviceBuffer() const { return mpDeviceBuf; }
    std::size_t GetBufferSize() const { return mMemSize; }
    void ToDevice(const void* p) const
    {
        if(mpDeviceBuf)
        {
            HIP_CHECK_ERROR(
                hipMemcpy(mpDeviceBuf, const_cast<void*>(p), mMemSize, hipMemcpyHostToDevice));
        }
        else
        {
            throw std::runtime_error("ToDevice with an empty pointer");
        }
    }
    void ToDevice(const void* p, const std::size_t cpySize) const
    {
        HIP_CHECK_ERROR(
            hipMemcpy(mpDeviceBuf, const_cast<void*>(p), cpySize, hipMemcpyHostToDevice));
    }
    void FromDevice(void* p) const
    {
        if(mpDeviceBuf)
        {
            HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, mMemSize, hipMemcpyDeviceToHost));
        }
        else
        {
            throw std::runtime_error("FromDevice with an empty pointer");
        }
    }
    void FromDevice(void* p, const std::size_t cpySize) const
    {
        HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
    }
    void SetZero() const
    {
        if(mpDeviceBuf)
        {
            HIP_CHECK_ERROR(hipMemset(mpDeviceBuf, 0, mMemSize));
        }
    }
    template <typename T>
    void SetValue(T x) const
    {
        if(mMemSize % sizeof(T) != 0)
        {
            throw std::runtime_error("wrong! not entire DeviceMem will be set");
        }
        // TODO: call a gpu kernel to set the value (?)
        set_buffer_value<T><<<1, 1024>>>(static_cast<T*>(mpDeviceBuf), x, mMemSize / sizeof(T));
    }
    ~DeviceMem()
    {
        if(mpDeviceBuf)
        {
            try
            {
                HIP_CHECK_ERROR(hipFree(mpDeviceBuf));
            }
            catch(std::runtime_error& re)
            {
                std::cerr << re.what() << std::endl;
            }
        }
    }

    void* mpDeviceBuf;
    std::size_t mMemSize;
};

} // namespace ck_tile
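A minimal usage sketch (not part of this diff): round-tripping a host buffer through device memory; `roundtrip_example` is a hypothetical name:

#include <vector>

void roundtrip_example()
{
    std::vector<float> host(256, 1.f);

    ck_tile::DeviceMem buf(host.size() * sizeof(float));
    buf.ToDevice(host.data());   // copies mMemSize bytes host -> device
    buf.SetZero();               // hipMemset over the whole allocation
    buf.FromDevice(host.data()); // copies back; host now holds all zeros
}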
include/ck_tile/host/fill.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cmath>
#include <iterator>
#include <optional>
#include <random>
#include <type_traits>
#include <utility>
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
T
>
struct
FillUniformDistribution
{
float
a_
{
-
5.
f
};
float
b_
{
5.
f
};
std
::
optional
<
uint32_t
>
seed_
{
11939
};
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
{
std
::
mt19937
gen
(
seed_
.
has_value
()
?
*
seed_
:
std
::
random_device
{}());
std
::
uniform_real_distribution
<
float
>
dis
(
a_
,
b_
);
std
::
generate
(
first
,
last
,
[
&
dis
,
&
gen
]()
{
return
ck_tile
::
type_convert
<
T
>
(
dis
(
gen
));
});
}
template
<
typename
ForwardRange
>
auto
operator
()(
ForwardRange
&&
range
)
const
->
std
::
void_t
<
decltype
(
std
::
declval
<
const
FillUniformDistribution
&>
()(
std
::
begin
(
std
::
forward
<
ForwardRange
>
(
range
)),
std
::
end
(
std
::
forward
<
ForwardRange
>
(
range
))))
>
{
(
*
this
)(
std
::
begin
(
std
::
forward
<
ForwardRange
>
(
range
)),
std
::
end
(
std
::
forward
<
ForwardRange
>
(
range
)));
}
};
template
<
typename
T
>
struct
FillNormalDistribution
{
float
mean_
{
0.
f
};
float
variance_
{
1.
f
};
std
::
optional
<
uint32_t
>
seed_
{
11939
};
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
{
std
::
mt19937
gen
(
seed_
.
has_value
()
?
*
seed_
:
std
::
random_device
{}());
std
::
normal_distribution
<
float
>
dis
(
mean_
,
std
::
sqrt
(
variance_
));
std
::
generate
(
first
,
last
,
[
&
dis
,
&
gen
]()
{
return
ck_tile
::
type_convert
<
T
>
(
dis
(
gen
));
});
}
template
<
typename
ForwardRange
>
auto
operator
()(
ForwardRange
&&
range
)
const
->
std
::
void_t
<
decltype
(
std
::
declval
<
const
FillNormalDistribution
&>
()(
std
::
begin
(
std
::
forward
<
ForwardRange
>
(
range
)),
std
::
end
(
std
::
forward
<
ForwardRange
>
(
range
))))
>
{
(
*
this
)(
std
::
begin
(
std
::
forward
<
ForwardRange
>
(
range
)),
std
::
end
(
std
::
forward
<
ForwardRange
>
(
range
)));
}
};
// Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below.
// However this produces segfaults in std::mt19937 which look like inifite loop.
// template <typename T>
// struct FillUniformDistributionIntegerValue
// {
// int a_{-5};
// int b_{5};
//
// template <typename ForwardIter>
// void operator()(ForwardIter first, ForwardIter last) const
// {
// std::mt19937 gen(11939);
// std::uniform_int_distribution<int> dis(a_, b_);
// std::generate(
// first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
// }
// };
// Workaround for uniform_int_distribution not working as expected. See note above.<
template
<
typename
T
>
struct
FillUniformDistributionIntegerValue
{
float
a_
{
-
5.
f
};
float
b_
{
5.
f
};
std
::
optional
<
uint32_t
>
seed_
{
11939
};
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
{
std
::
mt19937
gen
(
seed_
.
has_value
()
?
*
seed_
:
std
::
random_device
{}());
std
::
uniform_real_distribution
<
float
>
dis
(
a_
,
b_
);
std
::
generate
(
first
,
last
,
[
&
dis
,
&
gen
]()
{
return
ck_tile
::
type_convert
<
T
>
(
std
::
round
(
dis
(
gen
)));
});
}
template
<
typename
ForwardRange
>
auto
operator
()(
ForwardRange
&&
range
)
const
->
std
::
void_t
<
decltype
(
std
::
declval
<
const
FillUniformDistributionIntegerValue
&>
()(
std
::
begin
(
std
::
forward
<
ForwardRange
>
(
range
)),
std
::
end
(
std
::
forward
<
ForwardRange
>
(
range
))))
>
{
(
*
this
)(
std
::
begin
(
std
::
forward
<
ForwardRange
>
(
range
)),
std
::
end
(
std
::
forward
<
ForwardRange
>
(
range
)));
}
};
template <typename T>
struct FillNormalDistributionIntegerValue
{
    float mean_{0.f};
    float variance_{1.f};
    std::optional<uint32_t> seed_{11939};

    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
        std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
        std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
        std::generate(
            first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(std::round(dis(gen))); });
    }

    template <typename ForwardRange>
    auto operator()(ForwardRange&& range) const
        -> std::void_t<decltype(std::declval<const FillNormalDistributionIntegerValue&>()(
            std::begin(std::forward<ForwardRange>(range)),
            std::end(std::forward<ForwardRange>(range))))>
    {
        (*this)(std::begin(std::forward<ForwardRange>(range)),
                std::end(std::forward<ForwardRange>(range)));
    }
};
template <typename T>
struct FillMonotonicSeq
{
    T init_value_{0};
    T step_{1};

    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
        std::generate(first, last, [=, n = init_value_]() mutable {
            auto tmp = n;
            n += step_;
            return tmp;
        });
    }

    template <typename ForwardRange>
    auto operator()(ForwardRange&& range) const
        -> std::void_t<decltype(std::declval<const FillMonotonicSeq&>()(
            std::begin(std::forward<ForwardRange>(range)),
            std::end(std::forward<ForwardRange>(range))))>
    {
        (*this)(std::begin(std::forward<ForwardRange>(range)),
                std::end(std::forward<ForwardRange>(range)));
    }
};
template <typename T>
struct FillConstant
{
    T value_{0};

    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
        std::fill(first, last, value_);
    }

    template <typename ForwardRange>
    auto operator()(ForwardRange&& range) const
        -> std::void_t<decltype(std::declval<const FillConstant&>()(
            std::begin(std::forward<ForwardRange>(range)),
            std::end(std::forward<ForwardRange>(range))))>
    {
        (*this)(std::begin(std::forward<ForwardRange>(range)),
                std::end(std::forward<ForwardRange>(range)));
    }
};
template <typename T, bool UseCos = true, bool UseAbs = false>
struct FillTrigValue
{
    template <typename T_, bool UseCos_ = true, bool UseAbs_ = false>
    struct LinearTrigGen
    {
        int i{0};

        auto operator()()
        {
            float v = 0;
            if constexpr(UseCos_)
            {
                v = cos(i);
            }
            else
            {
                v = sin(i);
            }
            if constexpr(UseAbs_)
                v = abs(v);
            i++;
            return ck_tile::type_convert<T_>(v);
        }
    };

    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
        LinearTrigGen<T, UseCos, UseAbs> gen;
        std::generate(first, last, gen);
    }

    template <typename ForwardRange>
    auto operator()(ForwardRange&& range) const
        -> std::void_t<decltype(std::declval<const FillTrigValue&>()(
            std::begin(std::forward<ForwardRange>(range)),
            std::end(std::forward<ForwardRange>(range))))>
    {
        (*this)(std::begin(std::forward<ForwardRange>(range)),
                std::end(std::forward<ForwardRange>(range)));
    }
};
} // namespace ck_tile
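A usage sketch (not part of the header; values are illustrative) showing the simpler functors above on a plain std::vector:

// Illustrative: fill, then overwrite, a small host buffer.
std::vector<float> buf(8);
ck_tile::FillMonotonicSeq<float>{0.f, 0.5f}(buf); // 0, 0.5, 1.0, ...
ck_tile::FillConstant<float>{1.f}(buf);           // now all ones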
include/ck_tile/host/hip_check_error.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include <sstream>
#include <stdexcept>
#include <hip/hip_runtime.h>
namespace ck_tile {

// To be removed: this helper cannot report the location of the failing HIP call, since
// __FILE__/__LINE__ below expand inside this header rather than at the call site.
CK_TILE_HOST void hip_check_error(hipError_t x)
{
    if(x != hipSuccess)
    {
        std::ostringstream ss;
        ss << "HIP runtime error: " << hipGetErrorString(x) << ". " << __FILE__ << ": "
           << __LINE__ << " in function: " << __func__;
        throw std::runtime_error(ss.str());
    }
}
} // namespace ck_tile
#define HIP_CHECK_ERROR(retval_or_funcall) \
do \
{ \
hipError_t _tmpVal = retval_or_funcall; \
if(_tmpVal != hipSuccess) \
{ \
std::ostringstream ostr; \
ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \
<< hipGetErrorString(_tmpVal); \
throw std::runtime_error(ostr.str()); \
} \
} while(0)
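A brief sketch of the two error-checking paths (not part of the header; the allocation is illustrative). The macro expands at the call site, so its message carries the caller's __FILE__/__LINE__; the function above reports this header instead:

// Illustrative: both paths throw std::runtime_error on failure.
void* p = nullptr;
HIP_CHECK_ERROR(hipMalloc(&p, 1024)); // message includes the caller's file/line
hip_check_error(hipFree(p));          // message points into hip_check_error.hpp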
include/ck_tile/host/host_tensor.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cassert>
#include <iostream>
#include <iomanip>
#include <numeric>
#include <thread>
#include <utility>
#include <vector>
#include "ck_tile/core.hpp"
#include "ck_tile/host/ranges.hpp"
namespace ck_tile {

template <typename Range>
CK_TILE_HOST std::ostream& LogRange(std::ostream& os,
                                    Range&& range,
                                    std::string delim,
                                    int precision = std::cout.precision(),
                                    int width     = 0)
{
    bool first = true;
    for(auto&& v : range)
    {
        if(first)
            first = false;
        else
            os << delim;
        os << std::setw(width) << std::setprecision(precision) << v;
    }
    return os;
}

template <typename T, typename Range>
CK_TILE_HOST std::ostream& LogRangeAsType(std::ostream& os,
                                          Range&& range,
                                          std::string delim,
                                          int precision = std::cout.precision(),
                                          int width     = 0)
{
    bool first = true;
    for(auto&& v : range)
    {
        if(first)
            first = false;
        else
            os << delim;
        os << std::setw(width) << std::setprecision(precision) << static_cast<T>(v);
    }
    return os;
}
template <typename F, typename T, std::size_t... Is>
CK_TILE_HOST auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
{
    return f(std::get<Is>(args)...);
}

template <typename F, typename T>
CK_TILE_HOST auto call_f_unpack_args(F f, T args)
{
    constexpr std::size_t N = std::tuple_size<T>{};
    return call_f_unpack_args_impl(f, args, std::make_index_sequence<N>{});
}

template <typename F, typename T, std::size_t... Is>
CK_TILE_HOST auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
{
    return F(std::get<Is>(args)...);
}

template <typename F, typename T>
CK_TILE_HOST auto construct_f_unpack_args(F, T args)
{
    constexpr std::size_t N = std::tuple_size<T>{};
    return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
}
struct HostTensorDescriptor
{
    HostTensorDescriptor() = default;

    void CalculateStrides()
    {
        mStrides.clear();
        mStrides.resize(mLens.size(), 0);
        if(mStrides.empty())
            return;

        mStrides.back() = 1;
        std::partial_sum(
            mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies<std::size_t>());
    }

    template <typename X, typename = std::enable_if_t<std::is_convertible_v<X, std::size_t>>>
    HostTensorDescriptor(const std::initializer_list<X>& lens) : mLens(lens.begin(), lens.end())
    {
        this->CalculateStrides();
    }

    template <typename Lengths,
              typename = std::enable_if_t<
                  std::is_convertible_v<ck_tile::ranges::range_value_t<Lengths>, std::size_t>>>
    HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
    {
        this->CalculateStrides();
    }

    template <typename X,
              typename Y,
              typename = std::enable_if_t<std::is_convertible_v<X, std::size_t> &&
                                          std::is_convertible_v<Y, std::size_t>>>
    HostTensorDescriptor(const std::initializer_list<X>& lens,
                         const std::initializer_list<Y>& strides)
        : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
    {
    }

    template <typename Lengths,
              typename Strides,
              typename = std::enable_if_t<
                  std::is_convertible_v<ck_tile::ranges::range_value_t<Lengths>, std::size_t> &&
                  std::is_convertible_v<ck_tile::ranges::range_value_t<Strides>, std::size_t>>>
    HostTensorDescriptor(const Lengths& lens, const Strides& strides)
        : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
    {
    }

    std::size_t get_num_of_dimension() const { return mLens.size(); }

    std::size_t get_element_size() const
    {
        assert(mLens.size() == mStrides.size());
        return std::accumulate(
            mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
    }

    std::size_t get_element_space_size() const
    {
        std::size_t space = 1;
        for(std::size_t i = 0; i < mLens.size(); ++i)
        {
            if(mLens[i] == 0)
                continue;

            space += (mLens[i] - 1) * mStrides[i];
        }
        return space;
    }

    const std::vector<std::size_t>& get_lengths() const { return mLens; }
    const std::vector<std::size_t>& GetStrides() const { return mStrides; }

    template <typename... Is>
    std::size_t GetOffsetFromMultiIndex(Is... is) const
    {
        assert(sizeof...(Is) == this->get_num_of_dimension());
        std::initializer_list<std::size_t> iss{static_cast<std::size_t>(is)...};
        return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
    }

    std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
    {
        return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
    }

    friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);

    private:
    std::vector<std::size_t> mLens;
    std::vector<std::size_t> mStrides;
};
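A small worked example of the descriptor (not part of the header), relying only on the packed row-major strides computed by CalculateStrides:

// Illustrative: a 2x3 row-major descriptor; strides become {3, 1}.
ck_tile::HostTensorDescriptor desc({2, 3});
// desc.get_element_size() == 6
// desc.get_element_space_size() == 1 + 1*3 + 2*1 == 6
// desc.GetOffsetFromMultiIndex(1, 2) == 1*3 + 2*1 == 5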
template <typename New2Old>
CK_TILE_HOST HostTensorDescriptor
transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a,
                                               const New2Old& new2old)
{
    std::vector<std::size_t> new_lengths(a.get_num_of_dimension());
    std::vector<std::size_t> new_strides(a.get_num_of_dimension());

    for(std::size_t i = 0; i < a.get_num_of_dimension(); i++)
    {
        new_lengths[i] = a.get_lengths()[new2old[i]];
        new_strides[i] = a.GetStrides()[new2old[i]];
    }

    return HostTensorDescriptor(new_lengths, new_strides);
}
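And a sketch of the transpose helper (not part of the header), reusing the 2x3 descriptor from the previous example:

// Illustrative: view the 2x3 descriptor as its 3x2 transpose.
std::array<std::size_t, 2> new2old{1, 0};
auto t = ck_tile::transpose_host_tensor_descriptor_given_new2old(desc, new2old);
// t.get_lengths() == {3, 2}, t.GetStrides() == {1, 3}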
struct joinable_thread : std::thread
{
    template <typename... Xs>
    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
    {
    }

    joinable_thread(joinable_thread&&) = default;
    joinable_thread& operator=(joinable_thread&&) = default;

    ~joinable_thread()
    {
        if(this->joinable())
            this->join();
    }
};
template <typename F, typename... Xs>
struct ParallelTensorFunctor
{
    F mF;
    static constexpr std::size_t NDIM = sizeof...(Xs);
    std::array<std::size_t, NDIM> mLens;
    std::array<std::size_t, NDIM> mStrides;
    std::size_t mN1d;

    ParallelTensorFunctor(F f, Xs... xs) : mF(f), mLens({static_cast<std::size_t>(xs)...})
    {
        mStrides.back() = 1;
        std::partial_sum(mLens.rbegin(),
                         mLens.rend() - 1,
                         mStrides.rbegin() + 1,
                         std::multiplies<std::size_t>());
        mN1d = mStrides[0] * mLens[0];
    }

    std::array<std::size_t, NDIM> GetNdIndices(std::size_t i) const
    {
        std::array<std::size_t, NDIM> indices;

        for(std::size_t idim = 0; idim < NDIM; ++idim)
        {
            indices[idim] = i / mStrides[idim];
            i -= indices[idim] * mStrides[idim];
        }

        return indices;
    }

    void operator()(std::size_t num_thread = 1) const
    {
        std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;

        std::vector<joinable_thread> threads(num_thread);

        for(std::size_t it = 0; it < num_thread; ++it)
        {
            std::size_t iw_begin = it * work_per_thread;
            std::size_t iw_end   = std::min((it + 1) * work_per_thread, mN1d);

            auto f = [this, iw_begin, iw_end] {
                for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
                {
                    call_f_unpack_args(this->mF, this->GetNdIndices(iw));
                }
            };
            threads[it] = joinable_thread(f);
        }
    }
};

template <typename F, typename... Xs>
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
{
    return ParallelTensorFunctor<F, Xs...>(f, xs...);
}
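A usage sketch for the parallel functor (not part of the header; the lambda body is illustrative). Each thread walks a slice of the flattened 1-D range, and GetNdIndices/call_f_unpack_args turn the flat index back into per-dimension arguments:

// Illustrative: visit a 4x8 index space on 2 host threads.
auto body = [](std::size_t i, std::size_t j) { /* touch element (i, j) */ };
ck_tile::make_ParallelTensorFunctor(body, 4, 8)(2);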
template <typename T>
struct HostTensor
{
    using Descriptor = HostTensorDescriptor;
    using Data       = std::vector<T>;

    template <typename X>
    HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.get_element_space_size())
    {
    }

    template <typename X, typename Y>
    HostTensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
        : mDesc(lens, strides), mData(mDesc.get_element_space_size())
    {
    }

    template <typename Lengths>
    HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size())
    {
    }

    template <typename Lengths, typename Strides>
    HostTensor(const Lengths& lens, const Strides& strides)
        : mDesc(lens, strides), mData(get_element_space_size())
    {
    }

    HostTensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.get_element_space_size()) {}

    template <typename OutT>
    HostTensor<OutT> CopyAsType() const
    {
        HostTensor<OutT> ret(mDesc);

        std::transform(mData.cbegin(), mData.cend(), ret.mData.begin(), [](auto value) {
            return ck_tile::type_convert<OutT>(value);
        });

        return ret;
    }

    HostTensor()                  = delete;
    HostTensor(const HostTensor&) = default;
    HostTensor(HostTensor&&)      = default;

    ~HostTensor() = default;

    HostTensor& operator=(const HostTensor&) = default;
    HostTensor& operator=(HostTensor&&)      = default;

    template <typename FromT>
    explicit HostTensor(const HostTensor<FromT>& other) : HostTensor(other.template CopyAsType<T>())
    {
    }

    decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
    decltype(auto) GetStrides() const { return mDesc.GetStrides(); }

    std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
    std::size_t get_element_size() const { return mDesc.get_element_size(); }
    std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); }

    std::size_t get_element_space_size_in_bytes() const
    {
        return sizeof(T) * get_element_space_size();
    }

    // void SetZero() { ck_tile::ranges::fill<T>(mData, 0); }
    void SetZero() { std::fill(mData.begin(), mData.end(), 0); }

    template <typename F>
    void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
    {
        if(rank == mDesc.get_num_of_dimension())
        {
            f(*this, idx);
            return;
        }
        // else
        for(size_t i = 0; i < mDesc.get_lengths()[rank]; i++)
        {
            idx[rank] = i;
            ForEach_impl(std::forward<F>(f), idx, rank + 1);
        }
    }

    template <typename F>
    void ForEach(F&& f)
    {
        std::vector<size_t> idx(mDesc.get_num_of_dimension(), 0);
        ForEach_impl(std::forward<F>(f), idx, size_t(0));
    }

    template <typename F>
    void ForEach_impl(const F&& f, std::vector<size_t>& idx, size_t rank) const
    {
        if(rank == mDesc.get_num_of_dimension())
        {
            f(*this, idx);
            return;
        }
        // else
        for(size_t i = 0; i < mDesc.get_lengths()[rank]; i++)
        {
            idx[rank] = i;
            ForEach_impl(std::forward<const F>(f), idx, rank + 1);
        }
    }

    template <typename F>
    void ForEach(const F&& f) const
    {
        std::vector<size_t> idx(mDesc.get_num_of_dimension(), 0);
        ForEach_impl(std::forward<const F>(f), idx, size_t(0));
    }

    template <typename G>
    void GenerateTensorValue(G g, std::size_t num_thread = 1)
    {
        switch(mDesc.get_num_of_dimension())
        {
        case 1: {
            auto f = [&](auto i) { (*this)(i) = g(i); };
            make_ParallelTensorFunctor(f, mDesc.get_lengths()[0])(num_thread);
            break;
        }
        case 2: {
            auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); };
            make_ParallelTensorFunctor(f, mDesc.get_lengths()[0], mDesc.get_lengths()[1])(
                num_thread);
            break;
        }
        case 3: {
            auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); };
            make_ParallelTensorFunctor(
                f, mDesc.get_lengths()[0], mDesc.get_lengths()[1], mDesc.get_lengths()[2])(
                num_thread);
            break;
        }
        case 4: {
            auto f = [&](auto i0, auto i1, auto i2, auto i3) {
                (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3);
            };
            make_ParallelTensorFunctor(f,
                                       mDesc.get_lengths()[0],
                                       mDesc.get_lengths()[1],
                                       mDesc.get_lengths()[2],
                                       mDesc.get_lengths()[3])(num_thread);
            break;
        }
        case 5: {
            auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) {
                (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4);
            };
            make_ParallelTensorFunctor(f,
                                       mDesc.get_lengths()[0],
                                       mDesc.get_lengths()[1],
                                       mDesc.get_lengths()[2],
                                       mDesc.get_lengths()[3],
                                       mDesc.get_lengths()[4])(num_thread);
            break;
        }
        case 6: {
            auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
                (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
            };
            make_ParallelTensorFunctor(f,
                                       mDesc.get_lengths()[0],
                                       mDesc.get_lengths()[1],
                                       mDesc.get_lengths()[2],
                                       mDesc.get_lengths()[3],
                                       mDesc.get_lengths()[4],
                                       mDesc.get_lengths()[5])(num_thread);
            break;
        }
        default: throw std::runtime_error("unsupported dimension");
        }
    }

    template <typename... Is>
    std::size_t GetOffsetFromMultiIndex(Is... is) const
    {
        return mDesc.GetOffsetFromMultiIndex(is...);
    }

    template <typename... Is>
    T& operator()(Is... is)
    {
        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
    }

    template <typename... Is>
    const T& operator()(Is... is) const
    {
        return mData[mDesc.GetOffsetFromMultiIndex(is...)];
    }

    T& operator()(std::vector<std::size_t> idx)
    {
        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
    }

    const T& operator()(std::vector<std::size_t> idx) const
    {
        return mData[mDesc.GetOffsetFromMultiIndex(idx)];
    }

    typename Data::iterator begin() { return mData.begin(); }
    typename Data::iterator end() { return mData.end(); }
    typename Data::pointer data() { return mData.data(); }

    typename Data::const_iterator begin() const { return mData.begin(); }
    typename Data::const_iterator end() const { return mData.end(); }
    typename Data::const_pointer data() const { return mData.data(); }

    typename Data::size_type size() const { return mData.size(); }

    template <typename U = T>
    auto AsSpan() const
    {
        constexpr std::size_t FromSize = sizeof(T);
        constexpr std::size_t ToSize   = sizeof(U);

        using Element = std::add_const_t<std::remove_reference_t<U>>;
        return ck_tile::span<Element>{reinterpret_cast<Element*>(data()),
                                      size() * FromSize / ToSize};
    }

    template <typename U = T>
    auto AsSpan()
    {
        constexpr std::size_t FromSize = sizeof(T);
        constexpr std::size_t ToSize   = sizeof(U);

        using Element = std::remove_reference_t<U>;
        return ck_tile::span<Element>{reinterpret_cast<Element*>(data()),
                                      size() * FromSize / ToSize};
    }

    Descriptor mDesc;
    Data mData;
};
} // namespace ck_tile
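Tying the pieces together, a sketch of HostTensor construction and parallel initialization (not part of the header; shape and generator are illustrative):

// Illustrative: build a 2x3 float tensor and fill it from an index generator.
ck_tile::HostTensor<float> t({2, 3});
t.GenerateTensorValue([](auto i0, auto i1) { return static_cast<float>(i0 * 3 + i1); });
// t(1, 2) == 5.f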
include/ck_tile/host/kernel_launch.hpp
0 → 100644
View file @
20ddaeba
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <hip/hip_runtime.h>
#include <cstddef>
namespace ck_tile {

template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
#if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
#endif
    __global__ void kentry(Kernel f, Args... args)
{
    f(args...);
}
template <typename... Args, typename F>
CK_TILE_HOST float launch_and_time_kernel(const stream_config& s,
                                          F kernel,
                                          dim3 grid_dim,
                                          dim3 block_dim,
                                          std::size_t lds_byte,
                                          Args... args)
{
#if CK_TILE_TIME_KERNEL
    if(s.time_kernel_)
    {
        // warm up
        for(int i = 0; i < s.cold_niters_; ++i)
        {
            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
            hip_check_error(hipGetLastError());
        }

        const int nrepeat = s.nrepeat_;

        hipEvent_t start, stop;

        HIP_CHECK_ERROR(hipEventCreate(&start));
        HIP_CHECK_ERROR(hipEventCreate(&stop));

        HIP_CHECK_ERROR(hipDeviceSynchronize());
        HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));

        for(int i = 0; i < nrepeat; ++i)
        {
            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
            hip_check_error(hipGetLastError());
        }

        HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
        HIP_CHECK_ERROR(hipEventSynchronize(stop));

        float total_time = 0;

        HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));

        return total_time / nrepeat;
    }
    else
    {
        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
        hip_check_error(hipGetLastError());

        return 0;
    }
#else
    kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
    hip_check_error(hipGetLastError());

    return 0;
#endif
}
template <typename... Args, typename F, typename PreProcessFunc>
CK_TILE_HOST float launch_and_time_kernel_with_preprocess(const stream_config& s,
                                                          PreProcessFunc preprocess,
                                                          F kernel,
                                                          dim3 grid_dim,
                                                          dim3 block_dim,
                                                          std::size_t lds_byte,
                                                          Args... args)
{
#if CK_TILE_TIME_KERNEL
    if(s.time_kernel_)
    {
#if CK_TILE_DEBUG_LOG
        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d}\n",
               __func__,
               grid_dim.x,
               grid_dim.y,
               grid_dim.z,
               block_dim.x,
               block_dim.y,
               block_dim.z);

        printf("Warm up 1 time\n");
#endif
        // warm up
        preprocess();
        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
        hip_check_error(hipGetLastError());

        const int nrepeat = 10;
#if CK_TILE_DEBUG_LOG
        printf("Start running %d times...\n", nrepeat);
#endif
        hipEvent_t start, stop;

        HIP_CHECK_ERROR(hipEventCreate(&start));
        HIP_CHECK_ERROR(hipEventCreate(&stop));

        HIP_CHECK_ERROR(hipDeviceSynchronize());
        HIP_CHECK_ERROR(hipEventRecord(start, s.stream_id_));

        for(int i = 0; i < nrepeat; ++i)
        {
            preprocess();
            kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
            hip_check_error(hipGetLastError());
        }

        HIP_CHECK_ERROR(hipEventRecord(stop, s.stream_id_));
        HIP_CHECK_ERROR(hipEventSynchronize(stop));

        float total_time = 0;

        HIP_CHECK_ERROR(hipEventElapsedTime(&total_time, start, stop));

        return total_time / nrepeat;
    }
    else
    {
        preprocess();
        kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
        hip_check_error(hipGetLastError());

        return 0;
    }
#else
    kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
    hip_check_error(hipGetLastError());

    return 0;
#endif
}
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
          int MinBlockPerCu     = CK_TILE_MIN_BLOCK_PER_CU,
          typename KernelImpl,
          typename... Args>
CK_TILE_HOST float launch_kernel(const stream_config& s,
                                 KernelImpl kernel_impl,
                                 dim3 grid_dim,
                                 dim3 block_dim,
                                 std::size_t dynamic_smem_byte,
                                 Args... args)
{
    const auto kernel = kentry<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;

    return launch_and_time_kernel(
        s, kernel, grid_dim, block_dim, dynamic_smem_byte, kernel_impl, args...);
}
} // namespace ck_tile
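Finally, a hedged sketch of a functor-style kernel launched through launch_kernel (not part of the header; SetOne and dev_ptr are hypothetical, and the launch lines are commented out because they assume a valid device allocation):

// Illustrative functor: kentry invokes operator() on the device.
struct SetOne
{
    float* p;
    __device__ void operator()() const { p[blockIdx.x * blockDim.x + threadIdx.x] = 1.f; }
};

// float* dev_ptr = /* device allocation of 16 * 64 floats */;
// ck_tile::stream_config s{}; // timing disabled by default
// ck_tile::launch_kernel(s, SetOne{dev_ptr}, dim3(16), dim3(64), /*dynamic_smem_byte=*/0);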