gaoqiong / composable_kernel_ROCM / Commits

Commit 300337cd, authored May 30, 2024 by letaoqin
Merge branch 'develop' into jizhan/reduce_threadwise_multi_d
Parents: f306d02e, 02fa2c29
Changes: 255 files; showing 15 changed files with 3567 additions and 0 deletions (+3567, -0)
Files changed:
  include/ck_tile/core/numeric/numeric.hpp                    +191  -0
  include/ck_tile/core/numeric/type_convert.hpp                +66  -0
  include/ck_tile/core/numeric/vector_type.hpp                +185  -0
  include/ck_tile/core/tensor/buffer_view.hpp                +1068  -0
  include/ck_tile/core/tensor/load_tile.hpp                    +81  -0
  include/ck_tile/core/tensor/null_tensor.hpp                  +12  -0
  include/ck_tile/core/tensor/null_tile_window.hpp             +88  -0
  include/ck_tile/core/tensor/shuffle_tile.hpp                +177  -0
  include/ck_tile/core/tensor/slice_tile.hpp                   +92  -0
  include/ck_tile/core/tensor/static_distributed_tensor.hpp   +190  -0
  include/ck_tile/core/tensor/store_tile.hpp                   +93  -0
  include/ck_tile/core/tensor/sweep_tile.hpp                   +30  -0
  include/ck_tile/core/tensor/tensor_adaptor.hpp              +945  -0
  include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp   +257  -0
  include/ck_tile/core/tensor/tensor_coordinate.hpp            +92  -0
include/ck_tile/core/numeric/numeric.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"

#include <limits>
#include <stdint.h>

namespace ck_tile {

// this struct carries
// 1. the limits of a certain type, similar to std::numeric_limits
// 2. some pre-defined values: zero, one, ...
template <typename T>
struct numeric
{
    // minimum finite value, or minimum positive normalized value for float
    CK_TILE_HOST_DEVICE static constexpr T min() { return std::numeric_limits<T>::min(); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr T lowest() { return std::numeric_limits<T>::lowest(); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr T max() { return std::numeric_limits<T>::max(); }

    // difference between 1.0 and the next representable floating-point value
    CK_TILE_HOST_DEVICE static constexpr T epsilon() { return std::numeric_limits<T>::epsilon(); }

    // maximum rounding error
    CK_TILE_HOST_DEVICE static constexpr T round_error()
    {
        return std::numeric_limits<T>::round_error();
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr T infinity() { return std::numeric_limits<T>::infinity(); }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr T quiet_NaN()
    {
        return std::numeric_limits<T>::quiet_NaN();
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr T signaling_NaN()
    {
        return std::numeric_limits<T>::signaling_NaN();
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr T denorm_min()
    {
        return std::numeric_limits<T>::denorm_min();
    }

    CK_TILE_HOST_DEVICE static constexpr T zero() { return static_cast<T>(0); }

    CK_TILE_HOST_DEVICE static constexpr T one() { return static_cast<T>(1); }

#ifndef C_LOG2E
#define C_LOG2E 1.44269504088896340736 // log2(e)
#endif
    CK_TILE_HOST_DEVICE static constexpr T log2e()
    {
        if constexpr(std::is_same_v<T, float> || std::is_same_v<T, double>)
        {
            return static_cast<T>(C_LOG2E);
        }
        else
        {
            return 0; // TODO: integer?
        }
    }
};

template <typename T>
struct numeric_traits;

template <>
struct numeric_traits<float>
{
    static constexpr int exp  = 8;
    static constexpr int mant = 23;
    static constexpr int bias = 127;

    static constexpr uint32_t nan_mask  = 0x7F800000;
    static constexpr uint32_t head_mask = 0xFF800000;
    static constexpr uint32_t mant_mask = 0x7FFFFF;
    static constexpr uint32_t exp_mask  = 0xFF;
    static constexpr uint32_t Inf       = 0x7F800000;
    static constexpr uint32_t NegInf    = 0xFF800000;
    static constexpr uint32_t NaN       = 0x7F800001;
    static constexpr uint32_t Neg0      = 0x80000000;

    using bitwise_type = uint32_t;
};

} // namespace ck_tile
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_) \
attr_ bool operator==(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) == static_cast<float>(y); \
} \
attr_ bool operator!=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) != static_cast<float>(y); \
} \
attr_ bool operator<(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) < static_cast<float>(y); \
} \
attr_ bool operator<=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) <= static_cast<float>(y); \
} \
attr_ bool operator>(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) > static_cast<float>(y); \
} \
attr_ bool operator>=(const type_& x, const type_& y) \
{ \
return static_cast<float>(x) >= static_cast<float>(y); \
} \
attr_ type_ operator+(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) + static_cast<float>(y)); \
} \
attr_ type_ operator-(const type_& x) \
{ \
constexpr uint32_t bits = sizeof(type_) * 8; \
constexpr uint32_t mask = 1 << (bits - 1); \
type_ y = x; \
y.data ^= static_cast<typename type_::raw_type>(mask); \
return y; \
} \
attr_ type_ operator-(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) - static_cast<float>(y)); \
} \
attr_ type_ operator*(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) * static_cast<float>(y)); \
} \
attr_ type_ operator/(const type_& x, const type_& y) \
{ \
return type_(static_cast<float>(x) / static_cast<float>(y)); \
} \
attr_ type_& operator+=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) + static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator-=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) - static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator*=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) * static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator/=(type_& x, const type_& y) \
{ \
x = type_(static_cast<float>(x) / static_cast<float>(y)); \
return x; \
} \
attr_ type_& operator++(type_& x) \
{ \
x = type_(static_cast<float>(x) + 1.f); \
return x; \
} \
attr_ type_& operator--(type_& x) \
{ \
x = type_(static_cast<float>(x) - 1.f); \
return x; \
} \
attr_ type_ operator++(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) + 1.f); \
return y; \
} \
attr_ type_ operator--(type_& x, int) \
{ \
type_ y(x); \
x = type_(static_cast<float>(x) - 1.f); \
return y; \
}
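
Usage sketch (not part of the diff): CK_TILE_ARITHMETIC_USING_FLOAT expects a wrapper type that exposes a raw_type alias and a data member and converts to and from float; one invocation then stamps out the full comparison and arithmetic operator set routed through float. The my_bf16 type below is a hypothetical truncating bfloat16 stand-in for illustration, not a ck_tile type (the real ones live in bfloat16.hpp and friends):

#include <cstdint>
#include <cstring>

struct my_bf16
{
    using raw_type = uint16_t;
    raw_type data; // the macro's unary operator- flips this value's sign bit

    my_bf16() = default;
    explicit my_bf16(float f)
    {
        uint32_t u;
        std::memcpy(&u, &f, sizeof(u));
        data = static_cast<raw_type>(u >> 16); // truncate mantissa to bf16
    }
    explicit operator float() const
    {
        uint32_t u = static_cast<uint32_t>(data) << 16;
        float f;
        std::memcpy(&f, &u, sizeof(f));
        return f;
    }
};

// one macro invocation defines ==, !=, <, <=, >, >=, +, -, *, /, +=, ++, -- ...
CK_TILE_ARITHMETIC_USING_FLOAT(inline, my_bf16)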
include/ck_tile/core/numeric/type_convert.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <stdint.h>
#include <tuple>
#include <type_traits>

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"

namespace ck_tile {

#if CK_TILE_USE_CUSTOM_DATA_TYPE
template <typename Y, typename X>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<Y> type_convert(const X& x)
{
    return static_cast<Y>(x);
}
#else
// Convert X to Y; both X and Y are non-const data types.
template <typename Y,
          typename X,
          std::enable_if_t<!(std::is_const_v<Y> || std::is_const_v<X>), bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);

    return static_cast<Y>(x);
}

// Convert X to Y; either X or Y is a const data type.
template <typename Y,
          typename X,
          std::enable_if_t<std::is_const_v<Y> || std::is_const_v<X>, bool> = false>
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
{
    static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);

    using non_const_y = std::remove_const_t<Y>;
    using non_const_x = std::remove_const_t<X>;
    return static_cast<Y>(type_convert<non_const_y, non_const_x>(x));
}

#define CK_TILE_TYPE_CONVERT(dtype_, dname_, stype_, sname_)                    \
    template <>                                                                 \
    CK_TILE_HOST_DEVICE constexpr dtype_ type_convert<dtype_, stype_>(stype_ x) \
    {                                                                           \
        return sname_##_to_##dname_(x);                                         \
    }

CK_TILE_TYPE_CONVERT(float, float, fp16_t, fp16)
CK_TILE_TYPE_CONVERT(float, float, bf16_t, bf16)
CK_TILE_TYPE_CONVERT(float, float, fp8_t, fp8)
CK_TILE_TYPE_CONVERT(float, float, bf8_t, bf8)

CK_TILE_TYPE_CONVERT(fp16_t, fp16, float, float)
CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT(fp8_t, fp8, float, float)
CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float)

#undef CK_TILE_TYPE_CONVERT
#endif

} // namespace ck_tile
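
Usage sketch (not part of the diff): a round trip through fp16 exercises the two explicit specializations generated by CK_TILE_TYPE_CONVERT above, assuming the fp16_to_float/float_to_fp16 helpers that the macro expansion references are supplied by half.hpp:

CK_TILE_HOST_DEVICE float fp16_round_trip(float x)
{
    ck_tile::fp16_t h = ck_tile::type_convert<ck_tile::fp16_t>(x); // float -> fp16
    return ck_tile::type_convert<float>(h);                        // fp16  -> float
}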
include/ck_tile/core/numeric/vector_type.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// this structure is used to pick up the <base> type inside
//   using xxx = <base> __attribute__((ext_vector_type(N)));
// because clang only allows a native type + bool in this position (a custom type will fail).
// overload this structure to supply the proper <base> type.
template <typename T>
struct native_t
{
    using type = remove_cvref_t<T>;
};

// we name this ext_vector on purpose, because the clang ext_vector_type extension only accepts
// literal basic types when constructing an ext_vector_type. Be very careful using this, or you
// will get a lot of compiler errors,
// e.g. struct A; using Ax2_t = A __attribute__((ext_vector_type(2))); -> compiler error
namespace impl {

template <typename T_, index_t N_>
struct ext_vector
{
    static constexpr index_t N = N_;
    using value_type           = typename native_t<remove_cvref_t<T_>>::type;
    static_assert(!std::is_class_v<value_type>);
    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};

template <typename V_, index_t Vs_, index_t N_>
struct ext_vector<V_ __attribute__((ext_vector_type(Vs_))), N_>
{
    static constexpr index_t N = Vs_ * N_;
    using value_type           = typename native_t<remove_cvref_t<V_>>::type;
    static_assert(!std::is_class_v<value_type>);
    using type = value_type __attribute__((ext_vector_type(N))); // this is dangerous
};

} // namespace impl

template <typename T, index_t N>
using ext_vector_t = typename impl::ext_vector<T, N>::type;

// by default, any type results in vector_size = 1 with scalar_type = T ...
// ... unless another vector_traits specialization applies.
template <typename T>
struct vector_traits
{
    using scalar_type                    = remove_cvref_t<T>;
    static constexpr index_t vector_size = 1;
};

// specialization for ext_vector_type()
template <typename T, index_t N>
struct vector_traits<T __attribute__((ext_vector_type(N)))>
{
    using scalar_type                    = T;
    static constexpr index_t vector_size = N;
};

template <typename X, typename Y>
using has_same_scalar_type = std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                          typename vector_traits<remove_cvref_t<Y>>::scalar_type>;

// below are some pre-defined ext_vector_type aliases
// attention! two vector types could be exactly the same type

// fp64
using fp64_t   = double;
using fp64x2_t = double __attribute__((ext_vector_type(2)));
using fp64x4_t = double __attribute__((ext_vector_type(4)));

// fp32
using fp32_t    = float;
using fp32x2_t  = float __attribute__((ext_vector_type(2)));
using fp32x4_t  = float __attribute__((ext_vector_type(4)));
using fp32x8_t  = float __attribute__((ext_vector_type(8)));
using fp32x16_t = float __attribute__((ext_vector_type(16)));
using fp32x32_t = float __attribute__((ext_vector_type(32)));
using fp32x64_t = float __attribute__((ext_vector_type(64)));

// fp16
// using fp16_t = ...
using fp16x2_t  = _Float16 __attribute__((ext_vector_type(2)));
using fp16x4_t  = _Float16 __attribute__((ext_vector_type(4)));
using fp16x8_t  = _Float16 __attribute__((ext_vector_type(8)));
using fp16x16_t = _Float16 __attribute__((ext_vector_type(16)));
using fp16x32_t = _Float16 __attribute__((ext_vector_type(32)));
using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));

// bf16
// using bf16_t = ...
using bf16x2_t  = bf16_raw_t __attribute__((ext_vector_type(2)));
using bf16x4_t  = bf16_raw_t __attribute__((ext_vector_type(4)));
using bf16x8_t  = bf16_raw_t __attribute__((ext_vector_type(8)));
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));

// i32
// using int32_t = ...
using int32x2_t  = int32_t __attribute__((ext_vector_type(2)));
using int32x4_t  = int32_t __attribute__((ext_vector_type(4)));
using int32x8_t  = int32_t __attribute__((ext_vector_type(8)));
using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
using int32x64_t = int32_t __attribute__((ext_vector_type(64)));

// i16
// using int16_t = ...
using int16x2_t  = int16_t __attribute__((ext_vector_type(2)));
using int16x4_t  = int16_t __attribute__((ext_vector_type(4)));
using int16x8_t  = int16_t __attribute__((ext_vector_type(8)));
using int16x16_t = int16_t __attribute__((ext_vector_type(16)));
using int16x32_t = int16_t __attribute__((ext_vector_type(32)));
using int16x64_t = int16_t __attribute__((ext_vector_type(64)));

// u16
// using uint16_t
using uint16x2_t  = uint16_t __attribute__((ext_vector_type(2)));
using uint16x4_t  = uint16_t __attribute__((ext_vector_type(4)));
using uint16x8_t  = uint16_t __attribute__((ext_vector_type(8)));
using uint16x16_t = uint16_t __attribute__((ext_vector_type(16)));
using uint16x32_t = uint16_t __attribute__((ext_vector_type(32)));
using uint16x64_t = uint16_t __attribute__((ext_vector_type(64)));

// i8
// using int8_t
using int8x2_t  = int8_t __attribute((ext_vector_type(2)));
using int8x4_t  = int8_t __attribute((ext_vector_type(4)));
using int8x8_t  = int8_t __attribute((ext_vector_type(8)));
using int8x16_t = int8_t __attribute((ext_vector_type(16)));
using int8x32_t = int8_t __attribute((ext_vector_type(32)));
using int8x64_t = int8_t __attribute((ext_vector_type(64)));

#if CK_TILE_USE_CUSTOM_DATA_TYPE
// f8
// using fp8_t
using fp8x2_t  = fp8_raw_t __attribute((ext_vector_type(2)));
using fp8x4_t  = fp8_raw_t __attribute((ext_vector_type(4)));
using fp8x8_t  = fp8_raw_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_raw_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_raw_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_raw_t __attribute((ext_vector_type(64)));

// bf8
// using bf8_t
using bf8x2_t  = bf8_raw_t __attribute((ext_vector_type(2)));
using bf8x4_t  = bf8_raw_t __attribute((ext_vector_type(4)));
using bf8x8_t  = bf8_raw_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_raw_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_raw_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_raw_t __attribute((ext_vector_type(64)));
#else
// f8
// using fp8_t
using fp8x2_t  = fp8_t __attribute((ext_vector_type(2)));
using fp8x4_t  = fp8_t __attribute((ext_vector_type(4)));
using fp8x8_t  = fp8_t __attribute((ext_vector_type(8)));
using fp8x16_t = fp8_t __attribute((ext_vector_type(16)));
using fp8x32_t = fp8_t __attribute((ext_vector_type(32)));
using fp8x64_t = fp8_t __attribute((ext_vector_type(64)));

// bf8
// using bf8_t
using bf8x2_t  = bf8_t __attribute((ext_vector_type(2)));
using bf8x4_t  = bf8_t __attribute((ext_vector_type(4)));
using bf8x8_t  = bf8_t __attribute((ext_vector_type(8)));
using bf8x16_t = bf8_t __attribute((ext_vector_type(16)));
using bf8x32_t = bf8_t __attribute((ext_vector_type(32)));
using bf8x64_t = bf8_t __attribute((ext_vector_type(64)));
#endif

} // namespace ck_tile
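
Usage sketch (not part of the diff): how ext_vector_t, vector_traits, and has_same_scalar_type compose; the static_asserts only restate the definitions above:

#include <type_traits>

using f32x4 = ck_tile::ext_vector_t<float, 4>; // float __attribute__((ext_vector_type(4)))
using f32x8 = ck_tile::ext_vector_t<f32x4, 2>; // nested vectors flatten: 4 * 2 = 8 lanes

static_assert(ck_tile::vector_traits<f32x8>::vector_size == 8);
static_assert(std::is_same_v<ck_tile::vector_traits<f32x8>::scalar_type, float>);
static_assert(ck_tile::has_same_scalar_type<f32x4, f32x8>::value);
static_assert(ck_tile::vector_traits<float>::vector_size == 1); // scalar fallback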
include/ck_tile/core/tensor/buffer_view.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// T may be scalar or vector
// X may be scalar or vector
// T and X have the same scalar type
// X contains multiple T
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
//        transforms of tensor_view/Tensor
// FIXME: amd_buffer_coherence_enum is only meaningful for buffer addressing. Need to split the
//        buffer_view definition for the different memory address spaces (Global/Generic/Lds/Vgpr)
template <address_space_enum BufferAddressSpace,
          typename T,
          typename BufferSizeType,
          bool InvalidElementUseNumericalZeroValue,
          amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default>
struct buffer_view;
// Address Space: generic
// T may be scalar or vector
// X may be scalar or vector
// T and X have the same scalar type
// X contains multiple T
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
//        transforms of tensor_view/Tensor
template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
struct buffer_view<address_space_enum::generic,
                   T,
                   BufferSizeType,
                   InvalidElementUseNumericalZeroValue,
                   amd_buffer_coherence_enum::coherence_default>
{
    using type = T;

    T* p_data_ = nullptr;
    BufferSizeType buffer_size_;
    remove_cvref_t<T> invalid_element_value_ = T{0};

    CK_TILE_HOST_DEVICE constexpr buffer_view()
        : p_data_{}, buffer_size_{}, invalid_element_value_{}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                              BufferSizeType buffer_size,
                                              T invalid_element_value)
        : p_data_{p_data},
          buffer_size_{buffer_size},
          invalid_element_value_{invalid_element_value}
    {
    }

    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::generic;
    }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        if(is_valid_element)
        {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;

            __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

            return tmp;
#else
            return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return X{numeric<remove_cvref_t<T>>::zero()};
            }
            else
            {
                return X{invalid_element_value_};
            }
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
            this->template set<X>(i, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
            auto tmp = this->template get<X>(i, is_valid_element);
            this->template set<X>(i, is_valid_element, x + tmp);
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        if(is_valid_element)
        {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp = x;

            __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
            *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
        }
    }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("buffer_view{");

        // AddressSpace
        printf("AddressSpace: generic, ");

        // p_data_
        printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));

        // buffer_size_
        printf("buffer_size_: ");
        print(buffer_size_);
        printf(", ");

        // invalid_element_value_
        printf("invalid_element_value_: ");
        print(invalid_element_value_);

        printf("}");
    }
};
// Address Space: Global
// T may be scalar or vector
// X may be scalar or vector
// T and X have the same scalar type
// X contains multiple T
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
//        transforms of tensor_view/Tensor
template <typename T,
          typename BufferSizeType,
          bool InvalidElementUseNumericalZeroValue,
          amd_buffer_coherence_enum Coherence>
struct buffer_view<address_space_enum::global,
                   T,
                   BufferSizeType,
                   InvalidElementUseNumericalZeroValue,
                   Coherence>
{
    using type = T;

    T* p_data_ = nullptr;
    BufferSizeType buffer_size_;
    remove_cvref_t<T> invalid_element_value_ = T{0};

    CK_TILE_HOST_DEVICE constexpr buffer_view()
        : p_data_{}, buffer_size_{}, invalid_element_value_{}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                              BufferSizeType buffer_size,
                                              T invalid_element_value)
        : p_data_{p_data},
          buffer_size_{buffer_size},
          invalid_element_value_{invalid_element_value}
    {
    }

    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::global;
    }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_TILE_USE_AMD_BUFFER_LOAD
        constexpr bool use_amd_buffer_addressing = true;
#else
        constexpr bool use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
                                                                   t_per_x,
                                                                   Coherence,
                                                                   oob_conditional_check>(
                    p_data_, i, is_valid_element, buffer_size_);
            }
            else
            {
                return amd_buffer_load_invalid_element_return_customized_value<
                    remove_cvref_t<T>,
                    t_per_x,
                    Coherence,
                    oob_conditional_check>(
                    p_data_, i, is_valid_element, buffer_size_, invalid_element_value_);
            }
        }
        else
        {
            if(is_valid_element)
            {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp;

                __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

                return tmp;
#else
                return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
            }
            else
            {
                if constexpr(InvalidElementUseNumericalZeroValue)
                {
                    return X{numeric<remove_cvref_t<T>>::zero()};
                }
                else
                {
                    return X{invalid_element_value_};
                }
            }
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
    {
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
            dst, p_data_, i, buffer_size_, is_valid_element);
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
    {
        // X is vector of T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
            smem, p_data_, i, buffer_size_);
    }

    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
            this->template set<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_add)
        {
            this->template atomic_add<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == memory_operation_enum::atomic_max)
        {
            this->template atomic_max<X>(i, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
            auto tmp = this->template get<X>(i, is_valid_element);
            this->template set<X>(i, is_valid_element, x + tmp);
            // tmp += x;
            // this->template set<X>(i, is_valid_element, tmp);
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_TILE_USE_AMD_BUFFER_STORE
        constexpr bool use_amd_buffer_addressing = true;
#else
        constexpr bool use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_store<remove_cvref_t<T>, t_per_x, Coherence>(
                x, p_data_, i, is_valid_element, buffer_size_);
        }
        else
        {
            if(is_valid_element)
            {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp = x;

                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
            }
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void set_raw(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

        amd_buffer_store_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
            x, p_data_, i, is_valid_element, buffer_size_);
    }

    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void atomic_add(index_t i, bool is_valid_element, const X& x)
    {
        using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;

        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        static_assert(get_address_space() == address_space_enum::global,
                      "only support global mem");

#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        constexpr bool use_amd_buffer_addressing =
            std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
            std::is_same_v<remove_cvref_t<scalar_t>, float> ||
            (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
        constexpr bool use_amd_buffer_addressing =
            std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        constexpr bool use_amd_buffer_addressing =
            std::is_same_v<remove_cvref_t<scalar_t>, float> ||
            (std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
#else
        constexpr bool use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, buffer_size_);
        }
        else
        {
            if(is_valid_element)
            {
                atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
            }
        }
    }

    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void atomic_max(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        static_assert(get_address_space() == address_space_enum::global,
                      "only support global mem");

#if CK_TILE_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
        using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;

        constexpr bool use_amd_buffer_addressing =
            std::is_same_v<remove_cvref_t<scalar_t>, double>;
#else
        constexpr bool use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, buffer_size_);
        }
        else if(is_valid_element)
        {
            atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
        }
    }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("buffer_view{");

        // AddressSpace
        printf("AddressSpace: Global, ");

        // p_data_
        printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));

        // buffer_size_
        printf("buffer_size_: ");
        print(buffer_size_);
        printf(", ");

        // invalid_element_value_
        printf("invalid_element_value_: ");
        print(invalid_element_value_);

        printf("}");
    }
};
// Address Space: LDS
// T may be scalar or vector
// X may be scalar or vector
// T and X have the same scalar type
// X contains multiple T
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
//        transforms of tensor_view/Tensor
template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
struct buffer_view<address_space_enum::lds,
                   T,
                   BufferSizeType,
                   InvalidElementUseNumericalZeroValue,
                   amd_buffer_coherence_enum::coherence_default>
{
    using type = T;

    T* p_data_ = nullptr;
    BufferSizeType buffer_size_;
    remove_cvref_t<T> invalid_element_value_ = T{0};

    CK_TILE_HOST_DEVICE constexpr buffer_view()
        : p_data_{}, buffer_size_{}, invalid_element_value_{}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                              BufferSizeType buffer_size,
                                              T invalid_element_value)
        : p_data_{p_data},
          buffer_size_{buffer_size},
          invalid_element_value_{invalid_element_value}
    {
    }

    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::lds;
    }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        if(is_valid_element)
        {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;

            __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

            return tmp;
#else
            using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
                                       scalar_per_t_vector * scalar_per_x_vector>;
            // using buf_t = ushort __attribute__((ext_vector_type(8)));
            auto rtn = *c_style_pointer_cast<const buf_t*>(&p_data_[i]);
            return bit_cast<X>(rtn);
#endif
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return X{numeric<remove_cvref_t<T>>::zero()};
            }
            else
            {
                return X{invalid_element_value_};
            }
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
            this->template set<X>(i, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
            auto tmp = this->template get<X>(i, is_valid_element);
            this->template set<X>(i, is_valid_element, x + tmp);
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
        constexpr bool workaround_int8_ds_write_issue = true;
#else
        constexpr bool workaround_int8_ds_write_issue = false;
#endif

        if constexpr(std::is_same<typename vector_traits<remove_cvref_t<T>>::scalar_type,
                                  int8_t>::value &&
                     workaround_int8_ds_write_issue)
        {
            if(is_valid_element)
            {
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
                // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lowered to
                // ds_write_b128
                // TODO: remove this after compiler fix
                static_assert(
                    (std::is_same<remove_cvref_t<T>, int8_t>::value &&
                     std::is_same<remove_cvref_t<X>, int8_t>::value) ||
                        (std::is_same<remove_cvref_t<T>, int8_t>::value &&
                         std::is_same<remove_cvref_t<X>, int8x2_t>::value) ||
                        (std::is_same<remove_cvref_t<T>, int8_t>::value &&
                         std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                        (std::is_same<remove_cvref_t<T>, int8_t>::value &&
                         std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                        (std::is_same<remove_cvref_t<T>, int8_t>::value &&
                         std::is_same<remove_cvref_t<X>, int8x16_t>::value) ||
                        (std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
                         std::is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                        (std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
                         std::is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                        (std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
                         std::is_same<remove_cvref_t<X>, int8x16_t>::value),
                    "wrong! not implemented for this combination, please add "
                    "implementation");

                if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
                             std::is_same<remove_cvref_t<X>, int8_t>::value)
                {
                    // HACK: casting the pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int8_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                  std::is_same<remove_cvref_t<X>, int8x2_t>::value)
                {
                    // HACK: casting the pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int16_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                  std::is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // HACK: casting the pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                  std::is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // HACK: casting the pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T>, int8_t>::value &&
                                  std::is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // HACK: casting the pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                  std::is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // HACK: casting the pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                  std::is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // HACK: casting the pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(std::is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                  std::is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // HACK: casting the pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
            }
        }
        else
        {
            if(is_valid_element)
            {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp = x;

                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
                using buf_t = ext_vector_t<typename vector_traits<remove_cvref_t<T>>::scalar_type,
                                           scalar_per_t_vector * scalar_per_x_vector>;

                *c_style_pointer_cast<buf_t*>(&p_data_[i]) =
                    reinterpret_cast<const buf_t&>(x);
#endif
            }
        }
    }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("buffer_view{");

        // AddressSpace
        printf("AddressSpace: Lds, ");

        // p_data_
        printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));

        // buffer_size_
        printf("buffer_size_: ");
        print(buffer_size_);
        printf(", ");

        // invalid_element_value_
        printf("invalid_element_value_: ");
        print(invalid_element_value_);

        printf("}");
    }
};
// Address Space: Vgpr
// T may be scalar or vector
// X may be scalar or vector
// T and X have the same scalar type
// X contains multiple T
// FIXME: InvalidElementUseNumericalZeroValue and invalid_element_value_ should be a property of
//        transforms of tensor_view/Tensor
template <typename T, typename BufferSizeType, bool InvalidElementUseNumericalZeroValue>
struct buffer_view<address_space_enum::vgpr,
                   T,
                   BufferSizeType,
                   InvalidElementUseNumericalZeroValue,
                   amd_buffer_coherence_enum::coherence_default>
{
    using type = T;

    T* p_data_ = nullptr;
    BufferSizeType buffer_size_;
    remove_cvref_t<T> invalid_element_value_ = T{0};

    CK_TILE_HOST_DEVICE constexpr buffer_view()
        : p_data_{}, buffer_size_{}, invalid_element_value_{}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
    {
    }

    CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                              BufferSizeType buffer_size,
                                              T invalid_element_value)
        : p_data_{p_data},
          buffer_size_{buffer_size},
          invalid_element_value_{invalid_element_value}
    {
    }

    CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
    {
        return address_space_enum::vgpr;
    }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    // i is offset of T
    // FIXME: doesn't do is_valid check
    CK_TILE_DEVICE constexpr T& operator()(index_t i) { return p_data_[i]; }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              bool oob_conditional_check = true,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE constexpr auto
    get(index_t i, bool is_valid_element, bool_constant<oob_conditional_check> = {}) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        if(is_valid_element)
        {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp;

            __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

            return tmp;
#else
            return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
        }
        else
        {
            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return X{numeric<remove_cvref_t<T>>::zero()};
            }
            else
            {
                return X{invalid_element_value_};
            }
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <memory_operation_enum Op,
              typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == memory_operation_enum::set)
        {
            this->template set<X>(i, is_valid_element, x);
        }
        // FIXME: remove memory_operation_enum::add
        else if constexpr(Op == memory_operation_enum::add)
        {
            auto tmp = this->template get<X>(i, is_valid_element);
            this->template set<X>(i, is_valid_element, x + tmp);
        }
    }

    // i is offset of T, not X. i should be aligned to X
    template <typename X,
              typename std::enable_if<
                  std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                               typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                  bool>::type = false>
    CK_TILE_DEVICE void set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        if(is_valid_element)
        {
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
            X tmp = x;

            __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
            *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
        }
    }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_static_buffer() { return false; }

    // FIXME: remove
    CK_TILE_DEVICE static constexpr bool is_dynamic_buffer() { return true; }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("buffer_view{");

        // AddressSpace
        printf("AddressSpace: Vgpr, ");

        // p_data_
        printf("p_data_: %p, ", static_cast<void*>(const_cast<remove_cvref_t<T>*>(p_data_)));

        // buffer_size_
        printf("buffer_size_: ");
        print(buffer_size_);
        printf(", ");

        // invalid_element_value_
        printf("invalid_element_value_: ");
        print(invalid_element_value_);

        printf("}");
    }
};
template <address_space_enum BufferAddressSpace,
          amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
          typename T,
          typename BufferSizeType>
CK_TILE_HOST_DEVICE constexpr auto make_buffer_view(T* p, BufferSizeType buffer_size)
{
    return buffer_view<BufferAddressSpace, T, BufferSizeType, true, Coherence>{p, buffer_size};
}

template <address_space_enum BufferAddressSpace,
          amd_buffer_coherence_enum Coherence = amd_buffer_coherence_enum::coherence_default,
          typename T,
          typename BufferSizeType,
          typename X,
          typename std::enable_if<std::is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value,
                                  bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto
make_buffer_view(T* p, BufferSizeType buffer_size, X invalid_element_value)
{
    return buffer_view<BufferAddressSpace, T, BufferSizeType, false, Coherence>{
        p, buffer_size, invalid_element_value};
}

} // namespace ck_tile
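
Usage sketch (not part of the diff): a device-side global-memory buffer_view built with make_buffer_view; p and n are assumed kernel arguments, and offsets follow the "i is offset of T" convention documented above:

CK_TILE_DEVICE void buffer_view_example(float* p, ck_tile::index_t n)
{
    // out-of-bounds / invalid elements read back as zero
    // (this make_buffer_view overload sets InvalidElementUseNumericalZeroValue = true)
    auto buf = ck_tile::make_buffer_view<ck_tile::address_space_enum::global>(p, n);

    // vector load of 4 floats at element 0, vector store at element 4
    auto v = buf.template get<ck_tile::fp32x4_t>(0, /*is_valid_element=*/true);
    buf.template set<ck_tile::fp32x4_t>(4, /*is_valid_element=*/true, v);
}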
include/ck_tile/core/tensor/load_tile.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"

namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    return tile_window.load(bool_constant<oob_conditional_check>{});
}

template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto
load_tile_raw(T& tile,
              const tile_window_with_static_distribution<BottomTensorView_,
                                                         WindowLengths_,
                                                         TileDistribution_,
                                                         NumCoord>& tile_window,
              bool_constant<oob_conditional_check> = {})
{
    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
}

template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord>
CK_TILE_DEVICE auto
async_load_tile_raw(LdsTileWindow_&& lds_tile,
                    const tile_window_with_static_distribution<BottomTensorView_,
                                                               WindowLengths_,
                                                               TileDistribution_,
                                                               NumCoord>& tile_window)
{
    return tile_window.async_load(lds_tile);
}

CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
{
    asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}

template <typename WindowLengths>
CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
{
    return null_tensor{};
}

template <typename T, typename WindowLengths>
CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
{
}

} // namespace ck_tile
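
Call-pattern sketch (not part of the diff): the helpers above only forward to tile_window member functions; make_tile_window and the window types live in tile_window.hpp, which this diff does not show:

// given a constructed tile_window_with_static_distribution `window`:
//   auto tile = ck_tile::load_tile(window);           // load into a distributed register tile
//   ck_tile::load_tile_raw(tile, window);             // raw variant: fill an existing tile
//   ck_tile::async_load_tile_raw(lds_window, window); // global -> LDS async copy
//   ck_tile::async_load_fence();                      // s_waitcnt vmcnt(0): wait for all loads
// passing a null_tile_window instead turns load_tile into a no-op returning null_tensor{}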
include/ck_tile/core/tensor/null_tensor.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

struct null_tensor
{
};

} // namespace ck_tile
include/ck_tile/core/tensor/null_tile_window.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"
namespace ck_tile {

// placeholder type if we want to opt-out a tile window parameter
template <typename WindowLengths_>
struct null_tile_window
{
    using BottomTensorView = null_tensor_view;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;

    using BottomTensorIndex = array<index_t, WindowLengths::size()>;

    CK_TILE_DEVICE constexpr null_tile_window() = default;

    CK_TILE_DEVICE constexpr null_tile_window(const WindowLengths& window_lengths)
        : window_lengths_{window_lengths}
    {
    }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return null_tensor_view{}; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }

    WindowLengths window_lengths_;
};

// utility to check if this is a Null Tile Window
namespace impl {

template <typename>
struct is_null_tile_window : public std::false_type
{
};

template <typename T>
struct is_null_tile_window<null_tile_window<T>> : public std::true_type
{
};

} // namespace impl

template <typename T>
CK_TILE_DEVICE constexpr auto is_null_tile_window(const T&)
{
    return impl::is_null_tile_window<remove_cvref_t<T>>::value;
}

template <typename WindowLengths>
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths& window_lengths)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}

template <typename WindowLengths, typename... Ts>
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
                                               const WindowLengths& window_lengths,
                                               const multi_index<WindowLengths::size()>& /*origin*/,
                                               Ts&&...)
{
    static_assert(ck_tile::is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");

    return null_tile_window<remove_cvref_t<WindowLengths>>{window_lengths};
}

template <typename WindowLengths>
CK_TILE_DEVICE void
move_tile_window(null_tile_window<WindowLengths>&,
                 const typename null_tile_window<WindowLengths>::BottomTensorIndex&)
{
}

} // namespace ck_tile
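A null tile window is how a kernel template opts out of an operand (for example an optional bias) without a second code path: load_tile on it returns a null_tensor and move_tile_window is a no-op. A minimal sketch, assuming the surrounding kernel supplies the window (the add_bias name is illustrative only):

// Sketch: the same function body serves both a real tile window and a
// make_null_tile_window(...) placeholder; the bias branch compiles away.
template <typename BiasWindow>
CK_TILE_DEVICE void add_bias(const BiasWindow& bias_window)
{
    const auto bias = ck_tile::load_tile(bias_window); // null_tensor when opted out

    if constexpr(!std::is_same_v<ck_tile::remove_cvref_t<decltype(bias)>, ck_tile::null_tensor>)
    {
        // ... apply `bias` to the accumulator here ...
    }
}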
include/ck_tile/core/tensor/shuffle_tile.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
namespace ck_tile {

namespace detail {

template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InTensor& in_tensor)
{
    constexpr auto I0 = number<0>{};

    using DataType = typename InTensor::DataType;

    constexpr auto y_in_desc  = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
    constexpr auto y_out_desc = OutTensor::get_tile_distribution().get_ys_to_d_descriptor();

    // y_dim_out_to_in
    constexpr auto get_rh_major_minor_to_y = [](auto dstr_tensor) {
        using DstrEncode = typename decltype(dstr_tensor.get_tile_distribution())::DstrEncode;

        map<array<index_t, 2>, index_t> rh_major_minor_to_y_;

        static_for<0, DstrEncode::NDimY, 1>{}([&](auto i) {
            constexpr index_t rh_major = DstrEncode::ys_to_rhs_major_[i];
            constexpr index_t rh_minor = DstrEncode::ys_to_rhs_minor_[i];

            rh_major_minor_to_y_({rh_major, rh_minor}) = i;
        });

        return rh_major_minor_to_y_;
    };

    constexpr auto rh_major_minor_to_y_in  = get_rh_major_minor_to_y(InTensor{});
    constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{});

    constexpr auto y_dim_out_to_in = [&] {
        map<index_t, index_t> y_dim_out_to_in_;

        for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out)
        {
            y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor];
        }

        return y_dim_out_to_in_;
    }();

    constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();

    constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());

    // input and output vector dim in the order of input Y dims
    constexpr index_t y_dim_vec_in  = NDimY - 1;
    constexpr index_t y_dim_vec_out = y_dim_out_to_in[NDimY - 1];

    // vector lengths
    constexpr index_t vec_length_in  = y_lengths[y_dim_vec_in];
    constexpr index_t vec_length_out = y_lengths[y_dim_vec_out];

    // # of vectors
    constexpr index_t num_vec_in  = vec_length_out;
    constexpr index_t num_vec_out = vec_length_in;

    using InVec  = array<DataType, vec_length_in>;
    using OutVec = array<DataType, vec_length_out>;

    // using InVec = typename InVec::type;
    // using OutVec = typename OutVec::type;

    // SFC
    constexpr auto scalars_per_access_arr = generate_array(
        [&](auto i) { return (i == y_dim_vec_in or i == y_dim_vec_out) ? y_lengths[i] : 1; },
        number<NDimY>{});

    constexpr auto scalars_per_access = TO_SEQUENCE(scalars_per_access_arr, NDimY);

    using SFC_Y = space_filling_curve<decltype(y_lengths),
                                      typename arithmetic_sequence_gen<0, NDimY, 1>::type,
                                      decltype(scalars_per_access)>;

    constexpr index_t num_access = SFC_Y::get_num_of_access();

    static_assert(num_access > 0, "wrong! num_access should be larger than 0");

    // in/out vectors to be transposed
    thread_buffer<InVec, num_vec_in>   in_vectors;
    thread_buffer<OutVec, num_vec_out> out_vectors;

    // loop over SFC and do transpose
    static_for<0, num_access, 1>{}([&](auto iAccess) {
        // data index [y0, y1, ...] in the order of input tensor
        constexpr auto idx_y_start = SFC_Y::get_index(iAccess);

        // get input vectors
        static_for<0, num_vec_in, 1>{}([&](auto i) {
            constexpr auto idx_y_in = generate_array(
                [&](auto ii) {
                    return ii == y_dim_vec_out ? idx_y_start[ii] + i : idx_y_start[ii];
                },
                number<NDimY>{});

            constexpr index_t in_offset = y_in_desc.calculate_offset(idx_y_in);
            static_assert(in_offset % vec_length_in == 0);

            in_vectors(i).template get_as<InVec>()(I0) =
                in_tensor.get_thread_buffer()
                    .template get_as<InVec>()[number<in_offset / vec_length_in>{}];
        });

        // transpose
        transpose_vectors<DataType, num_vec_in, num_vec_out>{}(in_vectors, out_vectors);

        // set output vectors
        static_for<0, num_vec_out, 1>{}([&](auto i) {
            constexpr auto idx_y_out_tmp = generate_array(
                [&](auto ii) {
                    return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii];
                },
                number<NDimY>{});

            constexpr auto idx_y_out =
                container_reorder_given_new2old(idx_y_out_tmp, y_dim_out_to_in);

            constexpr index_t out_offset = y_out_desc.calculate_offset(idx_y_out);
            static_assert(out_offset % vec_length_out == 0);

            out_tensor.get_thread_buffer().template set_as<OutVec>(
                number<out_offset / vec_length_out>{},
                out_vectors[i].template get_as<OutVec>()[I0]);
        });
    });
}

} // namespace detail

template <typename OutTensor, typename InTensor>
CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
{
    using InDataType  = typename InTensor::DataType;
    using OutDataType = typename OutTensor::DataType;

    using InDstrEncode  = typename InTensor::StaticTileDistribution::DstrEncode;
    using OutDstrEncode = typename OutTensor::StaticTileDistribution::DstrEncode;

    // type convert
    const auto in_tmp = tile_elementwise_in(type_convert<OutDataType, InDataType>, in);

    // shuffle
    if constexpr(InDstrEncode::rs_lengths_ == OutDstrEncode::rs_lengths_ &&
                 InDstrEncode::hs_lengthss_ == OutDstrEncode::hs_lengthss_ &&
                 InDstrEncode::ps_to_rhss_major_ == OutDstrEncode::ps_to_rhss_major_ &&
                 InDstrEncode::ps_to_rhss_minor_ == OutDstrEncode::ps_to_rhss_minor_ &&
                 InDstrEncode::NDimY == OutDstrEncode::NDimY)
    {
        detail::shuffle_tile_impl_in_thread(out, in_tmp);
    }
    else
    {
        // NOT implemented
    }
}

} // namespace ck_tile
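shuffle_tile is an in-register transpose: it only handles the case where the input and output encodings agree on the R/H lengths and P mappings and differ in Y-dimension order; anything else silently falls into the empty "NOT implemented" branch above. A minimal sketch, with in_dstr/out_dstr standing in for two such compatible distributions:

// Sketch: move data between two register tiles whose distributions differ
// only in the order of their Y dimensions (e.g. to change the vectorized
// dimension before a store). No LDS traffic is involved.
auto in_tile  = ck_tile::make_static_distributed_tensor<float>(in_dstr);
auto out_tile = ck_tile::make_static_distributed_tensor<float>(out_dstr);

// ... fill in_tile ...

ck_tile::shuffle_tile(out_tile, in_tile); // per-thread vector transpose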
include/ck_tile/core/tensor/slice_tile.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using TileWindow = tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>;

    // NOTE: This API will override the origin of the tile window!
    static_assert(sizeof...(SliceBegins) == sizeof...(SliceEnds));
    static_assert(sizeof...(SliceBegins) == TileWindow::get_num_of_dimension());

    constexpr auto slice_lengths = slice_ends - slice_begins;

    return make_tile_window(tile.get_bottom_tensor_view(),
                            sequence_to_tuple_of_number(slice_lengths),
                            to_multi_index(slice_begins));
}

template <typename DataType_,
          typename StaticTileDistribution_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
get_slice_tile(const static_distributed_tensor<DataType_, StaticTileDistribution_>& tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using DataType     = remove_cvref_t<DataType_>;
    using Distribution = remove_cvref_t<StaticTileDistribution_>;

    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(Distribution{}, slice_begins, slice_ends);

    constexpr auto sliced_dstr      = sliced_dstr_yidx_ylen.template at<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();

    auto sliced_tensor = make_static_distributed_tensor<DataType>(sliced_dstr);

    sliced_tensor.get_thread_buffer() =
        tile.get_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths);

    return sliced_tensor;
}

template <typename DstDataType_,
          typename DstStaticTileDistribution_,
          typename SrcDataType_,
          typename SrcStaticTileDistribution_,
          index_t... SliceBegins,
          index_t... SliceEnds>
CK_TILE_DEVICE constexpr auto
set_slice_tile(static_distributed_tensor<DstDataType_, DstStaticTileDistribution_>& dst_tile,
               const static_distributed_tensor<SrcDataType_, SrcStaticTileDistribution_>& src_tile,
               sequence<SliceBegins...> slice_begins,
               sequence<SliceEnds...> slice_ends)
{
    using DstDistribution = remove_cvref_t<DstStaticTileDistribution_>;

    constexpr auto sliced_dstr_yidx_ylen =
        detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends);

    constexpr auto sliced_dstr      = sliced_dstr_yidx_ylen.template at<0>();
    constexpr auto sliced_y_origins = sliced_dstr_yidx_ylen.template at<1>();
    constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>();

    static_assert(std::is_same_v<decltype(sliced_dstr), DstDistribution>, "wrong!");

    // NOTE: the original called a non-existent SetSlicedThreadData(); the matching
    // member on static_distributed_tensor is set_y_sliced_thread_data().
    dst_tile.set_y_sliced_thread_data(
        sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer());
}

} // namespace ck_tile
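get_slice_tile on a distributed tensor carves a sub-tile out along the X dimensions and returns a new tensor over the correspondingly sliced distribution; on a tile window it simply re-windows the underlying view (overriding the origin, as the NOTE above warns). A minimal sketch, assuming acc is a 64x64 two-dimensional static_distributed_tensor:

// Sketch: take the left 64x32 half of a 64x64 register tile. Begins/ends are
// given per X dimension as compile-time sequences.
auto left_half = ck_tile::get_slice_tile(acc,
                                         ck_tile::sequence<0, 0>{},    // slice begins
                                         ck_tile::sequence<64, 32>{}); // slice ends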
include/ck_tile/core/tensor/static_distributed_tensor.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
namespace ck_tile {

template <typename DataType_, typename StaticTileDistribution_>
struct static_distributed_tensor
{
    using DataType               = remove_cvref_t<DataType_>;
    using StaticTileDistribution = remove_cvref_t<StaticTileDistribution_>;

    static_assert(StaticTileDistribution::is_static(),
                  "wrong! StaticTileDistribution should be known at compile time");

    using ThreadTensorDesc =
        remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;

    static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
    {
        return StaticTileDistribution::get_num_of_dimension_x();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
    {
        return StaticTileDistribution::get_lengths();
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_tile_distribution()
    {
        return StaticTileDistribution{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_distributed_spans()
    {
        return StaticTileDistribution::get_distributed_spans();
    }

    CK_TILE_HOST_DEVICE void initialize(const DataType& x) { thread_buf_.initialize(x); }

    CK_TILE_HOST_DEVICE constexpr const auto& get_thread_buffer() const { return thread_buf_; }

    CK_TILE_HOST_DEVICE constexpr auto& get_thread_buffer() { return thread_buf_; }

    CK_TILE_HOST_DEVICE static constexpr index_t get_thread_buffer_size()
    {
        return kThreadElementSpaceSize;
    }

    template <index_t... YSliceOrigins, index_t... YSliceLengths>
    CK_TILE_HOST_DEVICE auto get_y_sliced_thread_data(sequence<YSliceOrigins...>,
                                                      sequence<YSliceLengths...>) const
    {
        static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
                          sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
                      "wrong!");

        constexpr auto sliced_thread_tensor_desc =
            make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));

        thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
            sliced_thread_data;

        static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
            constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};

            sliced_thread_data(number<sliced_thread_tensor_desc.calculate_offset(idx)>{}) =
                thread_buf_[number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}];
        });

        return sliced_thread_data;
    }

    template <index_t... YSliceOrigins, index_t... YSliceLengths, typename SlicedThreadData>
    CK_TILE_HOST_DEVICE void set_y_sliced_thread_data(sequence<YSliceOrigins...>,
                                                      sequence<YSliceLengths...>,
                                                      const SlicedThreadData& sliced_thread_data)
    {
        static_assert(sizeof...(YSliceOrigins) == StaticTileDistribution::NDimY &&
                          sizeof...(YSliceLengths) == StaticTileDistribution::NDimY,
                      "wrong!");

        constexpr auto sliced_thread_tensor_desc =
            make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));

        static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
            constexpr auto idx_ys = idx + sequence<YSliceOrigins...>{};

            thread_buf_(number<ThreadTensorDesc{}.calculate_offset(idx_ys)>{}) =
                sliced_thread_data[number<sliced_thread_tensor_desc.calculate_offset(idx)>{}];
        });
    }

    template <typename TileDistributedIndices>
    CK_TILE_HOST_DEVICE constexpr const DataType& operator[](TileDistributedIndices) const
    {
        static_assert(is_static_v<TileDistributedIndices>,
                      "wrong! Tile Distributed Indices should be static");

        constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
            TileDistributedIndices{});

        return thread_buf_[number<ThreadTensorDesc{}.calculate_offset(y_idx)>{}];
    }

    template <typename TileDistributedIndices>
    CK_TILE_HOST_DEVICE constexpr DataType& operator()(TileDistributedIndices)
    {
        static_assert(is_static_v<TileDistributedIndices>,
                      "wrong! Tile Distributed Indices should be static");

        constexpr auto y_idx = get_tile_distribution().get_y_indices_from_distributed_indices(
            TileDistributedIndices{});

        return thread_buf_(number<ThreadTensorDesc{}.calculate_offset(y_idx)>{});
    }

    thread_buffer<DataType, kThreadElementSpaceSize> thread_buf_;
};

template <typename DataType, typename StaticTileDistribution>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&)
{
    return static_distributed_tensor<remove_cvref_t<DataType>,
                                     remove_cvref_t<StaticTileDistribution>>{};
}

template <typename DataType, typename StaticTileDistribution, typename ThreadBuffer>
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution&,
                                                                  ThreadBuffer&& thread_buffer_)
{
    return static_distributed_tensor<remove_cvref_t<DataType>,
                                     remove_cvref_t<StaticTileDistribution>>{thread_buffer_};
}

// get X indices from tuple of tile_distributed_index<>
template <typename StaticTileDistribution, typename DistributedIndices>
CK_TILE_HOST_DEVICE constexpr auto
get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution,
                                       DistributedIndices distributed_indices)
{
    const auto partition_index = detail::get_partition_index(tile_distribution);
    constexpr auto y_indices =
        tile_distribution.get_y_indices_from_distributed_indices(distributed_indices);

    const auto x_coord = make_tensor_adaptor_coordinate(
        tile_distribution.get_ps_ys_to_xs_adaptor(),
        container_concat(partition_index,
                         to_array<ck_tile::index_t, y_indices.size()>(y_indices)));

    return x_coord.get_bottom_index();
}

template <typename DataType, typename StaticTileDistribution, typename XIndicesPredicate>
CK_TILE_HOST_DEVICE void
set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_tensor,
            DataType value,
            XIndicesPredicate predicate)
{
    constexpr auto out_spans =
        static_distributed_tensor<DataType, StaticTileDistribution>::get_distributed_spans();

    sweep_tile_span(out_spans[number<0>{}], [&](auto idx0) {
        sweep_tile_span(out_spans[number<1>{}], [&](auto idx1) {
            constexpr auto distributed_indices = make_tuple(idx0, idx1);

            const auto x_indices = get_x_indices_from_distributed_indices(
                StaticTileDistribution{}, distributed_indices);

            if(predicate(x_indices))
            {
                out_tensor(distributed_indices) = value;
            }
        });
    });
}

} // namespace ck_tile
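set_tile_if plus get_x_indices_from_distributed_indices is the standard way to mask a register tile against its global coordinates, e.g. zeroing a tail tile or building a causal mask. A minimal sketch, assuming p_tile is a two-dimensional static_distributed_tensor<float, Dstr> whose X dims are (row, column):

// Sketch: write -inf wherever the global column index exceeds the global row
// index (an upper-triangular mask, as used for causal attention).
ck_tile::set_tile_if(
    p_tile, -ck_tile::numeric<float>::infinity(), [](auto x_indices) {
        return x_indices[ck_tile::number<1>{}] > x_indices[ck_tile::number<0>{}];
    });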
include/ck_tile/core/tensor/store_tile.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using TileDstr = remove_cvref_t<TileDistribution_>;

    static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");

    constexpr auto tile_dstr = TileDstr{};

    auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                        tile_window_tmp.get_window_lengths(),
                                        tile_window_tmp.get_window_origin(),
                                        tile_dstr);

    tile_window.store(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
    using TileDstr = remove_cvref_t<TileDistribution_>;

    static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");

    constexpr auto tile_dstr = TileDstr{};

    auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
                                        tile_window_tmp.get_window_lengths(),
                                        tile_window_tmp.get_window_origin(),
                                        tile_dstr);

    tile_window.store_raw(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                WindowLengths_,
                                                TileDistribution_,
                                                NumCoord>& tile_window,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store(dstr_tensor);
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
                                                    WindowLengths_,
                                                    TileDistribution_,
                                                    NumCoord>& tile_window,
               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    tile_window.store_raw(dstr_tensor);
}

} // namespace ck_tile
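Together with load_tile, these forwarders give the canonical register-staged tile copy. A minimal sketch, assuming src_view/dst_view are tensor views over global memory and dstr is a matching 64x64 tile distribution (all three names are placeholders):

// Sketch: copy one 64x64 tile through registers.
auto src_window = ck_tile::make_tile_window(
    src_view, ck_tile::make_tuple(ck_tile::number<64>{}, ck_tile::number<64>{}), {0, 0}, dstr);
auto dst_window = ck_tile::make_tile_window(
    dst_view, ck_tile::make_tuple(ck_tile::number<64>{}, ck_tile::number<64>{}), {0, 0}, dstr);

auto tile = ck_tile::load_tile(src_window); // global -> registers
ck_tile::store_tile(dst_window, tile);      // registers -> global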
include/ck_tile/core/tensor/sweep_tile.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {

// sweep over a span of a distributed tile and apply lambda function F
template <typename TileDistributedSpan_, // tile_distributed_span<...>
          typename F>                    // signature: F(tile_distributed_index<...>)
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
{
    using DstrSpan = remove_cvref_t<TileDistributedSpan_>;

    static_ford<typename DstrSpan::Impl>{}([&](auto dstr_idx_impl) {
        constexpr auto dstr_idx = detail::make_tile_distributed_index(dstr_idx_impl);

        f(dstr_idx);
    });
}

} // namespace ck_tile
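The usual pattern nests one sweep_tile_span per distributed span of a tile, exactly as set_tile_if in static_distributed_tensor.hpp does, and combines the per-span indices into a distributed index for addressing. A minimal sketch over a two-dimensional distributed tensor t (a placeholder name):

// Sketch: visit every element this thread owns in a 2-D tile and scale it.
constexpr auto spans = decltype(t)::get_distributed_spans();

ck_tile::sweep_tile_span(spans[ck_tile::number<0>{}], [&](auto idx0) {
    ck_tile::sweep_tile_span(spans[ck_tile::number<1>{}], [&](auto idx1) {
        constexpr auto i = ck_tile::make_tuple(idx0, idx1);
        t(i) = 2.f * t[i];
    });
});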
include/ck_tile/core/tensor/tensor_adaptor.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
namespace ck_tile {

// Transforms: Tuple<transforms...>
// LowerDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// UpperDimensionHiddenIdss : Tuple<Sequence<...>, ...>
// BottomDimensionHiddenIds : Sequence<...>
// TopDimensionHiddenIds : Sequence<...>
template <typename Transforms,
          typename LowerDimensionHiddenIdss,
          typename UpperDimensionHiddenIdss,
          typename BottomDimensionHiddenIds,
          typename TopDimensionHiddenIds>
struct tensor_adaptor
{
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_transform()
    {
        return Transforms::size();
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_transforms() const { return transforms_; }

    CK_TILE_HOST_DEVICE static constexpr auto get_lower_dimension_hidden_idss()
    {
        return LowerDimensionHiddenIdss{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_upper_dimension_hidden_idss()
    {
        return UpperDimensionHiddenIdss{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_bottom_dimension_hidden_ids()
    {
        return BottomDimensionHiddenIds{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_hidden_ids()
    {
        return TopDimensionHiddenIds{};
    }

    CK_TILE_HOST_DEVICE static constexpr auto initialize_element_size(const Transforms& transforms)
    {
        const auto lengths = generate_tuple(
            [&](auto idim_top) {
                constexpr index_t idim_hidden = TopDimensionHiddenIds::at(idim_top);

                constexpr auto tmp = get_transform_and_its_upper_dimension(number<idim_hidden>{});

                constexpr index_t itran   = tmp[number<0>{}];
                constexpr index_t idim_up = tmp[number<1>{}];
                constexpr bool found      = tmp[number<2>{}];

                static_assert(found == true,
                              "wrong! not found matching transformation and upper-dimension");

                const auto length =
                    transforms[number<itran>{}].get_upper_lengths()[number<idim_up>{}];

                return length;
            },
            number<ndim_top_>{});

        // TODO: make container_reduce support tuple of number and index_t
        return container_reduce(lengths, multiplies{}, number<1>{});
    }

    template <index_t IDimHidden>
    CK_TILE_HOST_DEVICE static constexpr auto
    get_transform_and_its_upper_dimension(number<IDimHidden>)
    {
        // FIXME: length of a bottom dimension is not known, since info about lower dim
        // lengths is not saved in the transformation
        static_assert(IDimHidden >= ndim_bottom_, "wrong! not implemented");

        index_t itran_found   = 0;
        index_t idim_up_found = 0;
        bool found            = false;

        static_for<0, ntransform_, 1>{}([&](auto itran) {
            constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];

            static_for<0, up_dim_ids.size(), 1>{}([&](auto idim_up) {
                if constexpr(up_dim_ids[idim_up] == IDimHidden)
                {
                    itran_found   = itran;
                    idim_up_found = idim_up;
                    found         = true;
                }
            });
        });

        return make_tuple(itran_found, idim_up_found, found);
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_bottom_dimension()
    {
        return BottomDimensionHiddenIds::size();
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_top_dimension()
    {
        return TopDimensionHiddenIds::size();
    }

    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_hidden_dimension()
    {
        constexpr auto all_low_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            LowerDimensionHiddenIdss{});

        constexpr auto all_up_dim_ids = unpack(
            [](auto&&... xs) constexpr { return merge_sequences(xs...); },
            UpperDimensionHiddenIdss{});

        constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);

        using unique_sort_all_dim_ids = typename sequence_unique_sort<decltype(all_dim_ids),
                                                                      less<index_t>,
                                                                      equal<index_t>>::type;

        return unique_sort_all_dim_ids::size();
    }

    constexpr static index_t ntransform_  = get_num_of_transform();
    constexpr static index_t ndim_hidden_ = get_num_of_hidden_dimension();
    constexpr static index_t ndim_bottom_ = get_num_of_bottom_dimension();
    constexpr static index_t ndim_top_    = get_num_of_top_dimension();

    using HiddenIndex = multi_index<ndim_hidden_>;
    using BottomIndex = multi_index<ndim_bottom_>;
    using TopIndex    = multi_index<ndim_top_>;

    // may be index_t or number<>
    using ElementSize = remove_cv_t<decltype(initialize_element_size(Transforms{}))>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_adaptor() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_adaptor(const Transforms& transforms)
        : transforms_{transforms}, element_size_{initialize_element_size(transforms)}
    {
        static_assert(Transforms::size() == ntransform_ &&
                          LowerDimensionHiddenIdss::size() == ntransform_ &&
                          UpperDimensionHiddenIdss::size() == ntransform_,
                      "wrong! inconsistent # of transformations");

        // TODO check dependency of dimensions is valid
    }

    CK_TILE_HOST_DEVICE constexpr auto get_element_size() const { return element_size_; }

    // FIXME: this logic is wrong when getting bottom dimension lengths
    template <index_t IDimHidden>
    CK_TILE_HOST_DEVICE constexpr auto get_hidden_dimension_length(number<IDimHidden>) const
    {
        static_assert(IDimHidden >= 0 && IDimHidden < ndim_hidden_, "wrong! out of range");

        constexpr auto tmp = get_transform_and_its_upper_dimension(number<IDimHidden>{});

        constexpr index_t itran   = tmp[number<0>{}];
        constexpr index_t idim_up = tmp[number<1>{}];
        constexpr bool found      = tmp[number<2>{}];

        static_assert(found == true,
                      "wrong! not found matching transformation and upper-dimension");

        return transforms_[number<itran>{}].get_upper_lengths()[number<idim_up>{}];
    }

    template <index_t IDimTop>
    CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_length(number<IDimTop> idim_top) const
    {
        return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_top));
    }
#if 0
// FIXME: get_hidden_dimension_length is wrong when getting bottome dimension lengths
template <index_t IDimBottom>
CK_TILE_HOST_DEVICE constexpr index_t
get_bottom_dimension_length(number<IDimBottom> idim_bottom) const
{
return get_hidden_dimension_length(TopDimensionHiddenIds::at(idim_bottom));
}
#endif
    CK_TILE_HOST_DEVICE constexpr auto get_top_dimension_lengths() const
    {
        return generate_tuple([&](auto i) { return get_top_dimension_length(i); },
                              number<ndim_top_>{});
    }
#if 0
// FIXME: get_hidden_dimension_length is wrong when getting bottome dimension lengths
CK_TILE_HOST_DEVICE constexpr auto GetBottomDimensionLengths() const
{
return generate_tuple([&](auto i) { return get_bottom_dimension_length(i); },
number<ndim_bottom_>{});
}
#endif
    template <typename TopIdx>
    CK_TILE_HOST_DEVICE constexpr auto calculate_bottom_index(const TopIdx& idx_top) const
    {
        static_assert(TopIdx::size() == TopDimensionHiddenIds::size(),
                      "wrong! # of dimension inconsistent");

        constexpr index_t ntransform  = get_num_of_transform();
        constexpr index_t ndim_hidden = get_num_of_hidden_dimension();

        multi_index<ndim_hidden> idx_hidden;

        // initialize uppermost index
        set_container_subset(idx_hidden, get_top_dimension_hidden_ids(), idx_top);

        // calculate hidden index
        static_for<ntransform, 0, -1>{}([&](auto itran_p1) {
            auto itran = itran_p1 - number<1>{};

            const auto& tran        = get_transforms().at(itran);
            constexpr auto dims_low = get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = get_upper_dimension_hidden_idss().at(itran);

            const auto idx_up = get_container_subset(idx_hidden, dims_up);

            multi_index<dims_low.size()> idx_low;

            tran.calculate_lower_index(idx_low, idx_up);

            set_container_subset(idx_hidden, dims_low, idx_low);
        });

        return get_container_subset(idx_hidden, BottomDimensionHiddenIds{});
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_static()
    {
        bool is_known = true;

        static_for<0, Transforms::size(), 1>{}([&](auto i) {
            is_known &= remove_cvref_t<decltype(Transforms{}[i])>::is_known_at_compile_time();
        });

        return is_known && ck_tile::is_known_at_compile_time<ElementSize>::value;
    }

    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time() { return is_static(); }

    CK_TILE_HOST_DEVICE static constexpr auto get_top_dimension_safe_vector_length_strides(
        const array<index_t, ndim_hidden_>& guaranteed_vector_lengths,
        const array<index_t, ndim_hidden_>& guaranteed_vector_strides)
    {
        auto vector_lengths = guaranteed_vector_lengths;
        auto vector_strides = guaranteed_vector_strides;

        static_for<0, get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr auto low_dims = get_lower_dimension_hidden_idss().at(itran);
            constexpr auto up_dims  = get_upper_dimension_hidden_idss().at(itran);

            const auto up_guaranteed_vector_lengths =
                get_container_subset(guaranteed_vector_lengths, up_dims);
            const auto up_guaranteed_vector_strides =
                get_container_subset(guaranteed_vector_strides, up_dims);

            // only need type of transform
            auto [up_vector_lengths, up_vector_strides] =
                Transforms{}.at(itran).calculate_upper_dimension_safe_vector_length_strides(
                    get_container_subset(vector_lengths, low_dims),
                    get_container_subset(vector_strides, low_dims));

            if constexpr(up_dims.size() > 0)
            {
                for(index_t i = 0; i < up_dims.size(); ++i)
                {
                    up_vector_lengths(i) = (up_guaranteed_vector_lengths[i] != -1)
                                               ? up_guaranteed_vector_lengths[i]
                                               : up_vector_lengths[i];

                    up_vector_strides(i) = (up_guaranteed_vector_strides[i] != -1)
                                               ? up_guaranteed_vector_strides[i]
                                               : up_vector_strides[i];
                }
            }

            set_container_subset(vector_lengths, up_dims, up_vector_lengths);
            set_container_subset(vector_strides, up_dims, up_vector_strides);
        });

        constexpr auto top_dims = TopDimensionHiddenIds{};

        return make_tuple(get_container_subset(vector_lengths, top_dims),
                          get_container_subset(vector_strides, top_dims));
    }

    CK_TILE_HOST_DEVICE void print() const
    {
        printf("tensor_adaptor{");

        printf("transforms: ");
        print(transforms_);
        printf(", ");

        printf("LowerDimensionHiddenIds: ");
        print(LowerDimensionHiddenIdss{});
        printf(", ");

        printf("UpperDimensionHiddenIds: ");
        print(UpperDimensionHiddenIdss{});
        printf(", ");

        printf("BottomDimensionHiddenIds: ");
        print(BottomDimensionHiddenIds{});
        printf(", ");

        printf("TopDimensionHiddenIds: ");
        print(TopDimensionHiddenIds{});

        printf("}");
    }

    private:
    Transforms transforms_;
    ElementSize element_size_;
};

// Transforms: Tuple<transforms...>
// LowerDimensionOldTopIdss: Tuple<Sequence<...>, ...>
// UpperDimensionNewTopIdss: Tuple<Sequence<...>, ...>
template <typename Transforms, typename LowerDimensionOldTopIdss, typename UpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms,
                                                                    LowerDimensionOldTopIdss,
                                                                    UpperDimensionNewTopIdss)
{
    constexpr index_t ntransform = Transforms::size();

    static_assert(LowerDimensionOldTopIdss::size() == ntransform &&
                      UpperDimensionNewTopIdss::size() == ntransform,
                  "wrong!");

    // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss
    constexpr auto all_low_dim_old_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{});

    constexpr auto all_up_dim_new_top_ids = unpack(
        [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{});

    static_assert(is_valid_sequence_map<decltype(all_low_dim_old_top_ids)>::value &&
                      is_valid_sequence_map<decltype(all_up_dim_new_top_ids)>::value,
                  "wrong!");

    constexpr index_t ndim_old_top = all_low_dim_old_top_ids.size();
    constexpr index_t ndim_new_top = all_up_dim_new_top_ids.size();

    // low_dim_hidden_idss
    constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{};

    // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom
    constexpr auto up_dim_hidden_idss = generate_tuple(
        [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + number<ndim_old_top>{}; },
        number<ntransform>{});

    // bottom_dim_hidden_ids
    constexpr auto bottom_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{};

    // top_dim_hidden_ids
    constexpr auto top_dim_hidden_ids =
        typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + number<ndim_old_top>{};

    return tensor_adaptor<remove_cvref_t<Transforms>,
                          remove_cvref_t<decltype(low_dim_hidden_idss)>,
                          remove_cvref_t<decltype(up_dim_hidden_idss)>,
                          remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
                          remove_cvref_t<decltype(top_dim_hidden_ids)>>{transforms};
}

// TODO: how to fix this? It uses a struct instead of a lambda because a lambda
// doesn't have a constructor, and it is placed outside the scope where it is used
// (transform_tensor_adaptor) because a template cannot be defined inside a function
// template
template <typename NewTransforms>
struct lambda_get_up_dim_num
{
    template <typename I>
    CK_TILE_HOST_DEVICE constexpr auto operator()(I) const
    {
        using Tran = remove_reference_t<decltype(NewTransforms{}.at(I{}))>;
        return number<Tran::get_num_of_upper_dimension()>{};
    }
};

template <typename OldTensorAdaptor,
          typename NewTransforms,
          typename NewLowerDimensionOldTopIdss,
          typename NewUpperDimensionNewTopIdss>
CK_TILE_HOST_DEVICE constexpr auto
transform_tensor_adaptor(const OldTensorAdaptor& old_tensor_adaptor,
                         const NewTransforms& new_transforms,
                         NewLowerDimensionOldTopIdss,
                         NewUpperDimensionNewTopIdss)
{
    // sanity check
    {
        static_assert(NewTransforms::size() == NewLowerDimensionOldTopIdss::size() &&
                          NewTransforms::size() == NewUpperDimensionNewTopIdss::size(),
                      "wrong! inconsistent number of transform");

        constexpr auto all_old_top_ids = unpack(
            [](auto... xs) { return merge_sequences(xs...); }, NewLowerDimensionOldTopIdss{});

        constexpr auto all_new_top_ids = unpack(
            [](auto... xs) { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});

        static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
                          is_valid_sequence_map<decltype(all_new_top_ids)>::value,
                      "wrong!");
    }

    // lower dimension's hidden idss
    // convert lower dimension top idss (tuple of sequences) to hidden idss (tuple of
    // sequences)
    constexpr auto low_dim_hidden_idss = transform_tuples(
        // convert lower dimension top ids (a sequence) to hidden ids (a sequence)
        [](auto low_dim_top_ids) constexpr {
            return transform_sequences(
                // convert lower dimension top id to hidden id
                [](auto low_dim_top_id) constexpr {
                    return OldTensorAdaptor::get_top_dimension_hidden_ids()[low_dim_top_id];
                },
                low_dim_top_ids);
        },
        NewLowerDimensionOldTopIdss{});

    constexpr index_t num_new_transform = NewTransforms::size();

    // upper dimension's hidden idss
    constexpr index_t old_hidden_dim_number = OldTensorAdaptor::get_num_of_hidden_dimension();

    constexpr auto up_dim_numbers =
        generate_sequence(lambda_get_up_dim_num<NewTransforms>{}, number<num_new_transform>{});

    constexpr auto up_dim_numbers_scan = merge_sequences(
        sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, plus<index_t>{}, number<0>{}));

    constexpr auto up_dim_hidden_idss = generate_tuple(
        [old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
            return typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
                                                    old_hidden_dim_number +
                                                        up_dim_numbers_scan[i + 1],
                                                    1>::type{};
        },
        number<num_new_transform>{});

    // new top dimension's hidden ids
    constexpr auto unordered_new_top_dim_hidden_ids = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);

    constexpr auto new_top_dim_unordered2ordered = unpack(
        [](auto... xs) constexpr { return merge_sequences(xs...); }, NewUpperDimensionNewTopIdss{});

    constexpr auto new_top_dim_hidden_ids =
        unordered_new_top_dim_hidden_ids.reorder_old_to_new(new_top_dim_unordered2ordered);

    // put everything together
    const auto all_transforms =
        container_concat(old_tensor_adaptor.get_transforms(), new_transforms);

    constexpr auto all_low_dim_hidden_idss =
        container_concat(OldTensorAdaptor::get_lower_dimension_hidden_idss(), low_dim_hidden_idss);

    constexpr auto all_up_dim_hidden_idss =
        container_concat(OldTensorAdaptor::get_upper_dimension_hidden_idss(), up_dim_hidden_idss);

    return tensor_adaptor<
        remove_cvref_t<decltype(all_transforms)>,
        remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
        remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
        remove_cvref_t<decltype(OldTensorAdaptor::get_bottom_dimension_hidden_ids())>,
        remove_cvref_t<decltype(new_top_dim_hidden_ids)>>{all_transforms};
}

template <typename TensorAdaptor0, typename TensorAdaptor1>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0,
                                                         const TensorAdaptor1& adaptor1)
{
    static_assert(TensorAdaptor0::get_num_of_top_dimension() ==
                      TensorAdaptor1::get_num_of_bottom_dimension(),
                  "wrong!");

    // all_transforms = transform0 + transform1
    const auto all_transforms =
        container_concat(adaptor0.get_transforms(), adaptor1.get_transforms());

    // shift
    constexpr index_t adaptor0_max_hidden_id = [&]() {
        index_t adaptor0_max_hidden_id_ = numeric<index_t>::min();

        static_for<0, TensorAdaptor0::get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor0{}.get_transforms()[itran].get_num_of_lower_dimension();

            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                adaptor0_max_hidden_id_ =
                    max(adaptor0_max_hidden_id_,
                        TensorAdaptor0::get_lower_dimension_hidden_idss()[itran][idim_low].value);
            });

            constexpr index_t ndim_up =
                TensorAdaptor0{}.get_transforms()[itran].get_num_of_upper_dimension();

            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor0_max_hidden_id_ =
                    max(adaptor0_max_hidden_id_,
                        TensorAdaptor0::get_upper_dimension_hidden_idss()[itran][idim_up].value);
            });
        });

        return adaptor0_max_hidden_id_;
    }();

    constexpr index_t adaptor1_min_hidden_id = [&]() {
        index_t adaptor1_min_hidden_id_ = numeric<index_t>::max();

        static_for<0, TensorAdaptor1::get_num_of_transform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor1{}.get_transforms()[itran].get_num_of_lower_dimension();

            // get the min of all lower dimensions, but not bottom dimensions (because their
            // ids will be matched with top ids from adaptor0)
            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
                constexpr index_t low_dim_hidden_id =
                    TensorAdaptor1::get_lower_dimension_hidden_idss()[itran][idim_low].value;

                bool is_bottom_dim = false;
                static_for<0, TensorAdaptor1::get_num_of_bottom_dimension(), 1>{}([&](auto i) {
                    if constexpr(low_dim_hidden_id ==
                                 TensorAdaptor1::get_bottom_dimension_hidden_ids()[i])
                    {
                        is_bottom_dim = true;
                    }
                });

                if(!is_bottom_dim)
                {
                    adaptor1_min_hidden_id_ = min(adaptor1_min_hidden_id_, low_dim_hidden_id);
                }
            });

            constexpr index_t ndim_up =
                TensorAdaptor1{}.get_transforms()[itran].get_num_of_upper_dimension();

            // get the min of all upper dimensions
            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
                adaptor1_min_hidden_id_ =
                    min(adaptor1_min_hidden_id_,
                        TensorAdaptor1::get_upper_dimension_hidden_idss()[itran][idim_up].value);
            });
        });

        return adaptor1_min_hidden_id_;
    }();

    constexpr index_t adaptor1_hidden_id_shift =
        adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id;

    constexpr index_t ndim_bottom_1 = TensorAdaptor1::get_num_of_bottom_dimension();

    // all_low_dim_hidden_idss =
    // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hidden_idss_1))
    constexpr auto low_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_low_1 =
                TensorAdaptor1::get_lower_dimension_hidden_idss()[itran].size();

            constexpr auto low_dim_hidden_ids_1 =
                TensorAdaptor1::get_lower_dimension_hidden_idss()[itran];

            // sequence in, sequence out
            constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr {
                auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);

                // shift hidden id so every dim id is unique
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
                });

                // match hidden id
                static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
                    static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) {
                        // if this low dim is a bottom dim, then do id matching
                        if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
                                     TensorAdaptor1::get_bottom_dimension_hidden_ids()
                                         [idim_bottom_1])
                        {
                            low_dim_hidden_ids_1_mod_(idim_low_1) =
                                TensorAdaptor0::get_top_dimension_hidden_ids()[idim_bottom_1];
                        }
                    });
                });

                return low_dim_hidden_ids_1_mod_;
            }();

            return generate_sequence_v2(
                [&](auto i) constexpr { return number<low_dim_hidden_ids_1_mod[i]>{}; },
                number<ndim_low_1>{});
        },
        number<TensorAdaptor1::get_num_of_transform()>{});

    constexpr auto all_low_dim_hidden_idss =
        container_concat(TensorAdaptor0::get_lower_dimension_hidden_idss(), low_dim_hidden_idss_1);

    // all_up_dim_hidden_idss =
    // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hidden_idss_1)
    constexpr auto up_dim_hidden_idss_1 = generate_tuple(
        // generate sequence of ids for a transform
        [&](auto itran) {
            constexpr auto ndim_up_1 =
                TensorAdaptor1::get_upper_dimension_hidden_idss()[itran].size();

            constexpr auto up_dim_hidden_ids_1 =
                TensorAdaptor1::get_upper_dimension_hidden_idss()[itran];

            // sequence in, constexpr tuple out
            constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr {
                auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);

                // shift hidden id
                static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
                    up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
                });

                return up_dim_hidden_ids_1_mod_;
            }();

            // constexpr tuple to sequence
            return generate_sequence_v2(
                [&](auto i) constexpr { return number<up_dim_hidden_ids_1_mod[i]>{}; },
                number<ndim_up_1>{});
        },
        number<TensorAdaptor1::get_num_of_transform()>{});

    constexpr auto all_up_dim_hidden_idss =
        container_concat(TensorAdaptor0::get_upper_dimension_hidden_idss(), up_dim_hidden_idss_1);

    // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0
    constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::get_bottom_dimension_hidden_ids();

    // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1)
    constexpr auto top_dim_hidden_ids =
        TensorAdaptor1::get_top_dimension_hidden_ids() + number<adaptor1_hidden_id_shift>{};

    // put everything together
    return tensor_adaptor<remove_cvref_t<decltype(all_transforms)>,
                          remove_cvref_t<decltype(all_low_dim_hidden_idss)>,
                          remove_cvref_t<decltype(all_up_dim_hidden_idss)>,
                          remove_cvref_t<decltype(bottom_dim_hidden_ids)>,
                          remove_cvref_t<decltype(top_dim_hidden_ids)>>{all_transforms};
}

template <typename X,
          typename... Xs,
          typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{
    return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
}

} // namespace ck_tile
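As a concrete instance of the adaptor machinery above: a single merge transform turns a 2-D bottom space into a 1-D top space, and calculate_bottom_index walks it backwards. A minimal sketch (make_merge_transform comes from coordinate_transform.hpp, which is included above; the 4x8 extent is arbitrary):

// Sketch: view a 4x8 bottom space through a merge transform as one
// 32-element top dimension, then map a top index back to 2-D.
constexpr auto adaptor = ck_tile::make_single_stage_tensor_adaptor(
    ck_tile::make_tuple(ck_tile::make_merge_transform(
        ck_tile::make_tuple(ck_tile::number<4>{}, ck_tile::number<8>{}))),
    ck_tile::make_tuple(ck_tile::sequence<0, 1>{}), // lower: old top (bottom) dims
    ck_tile::make_tuple(ck_tile::sequence<0>{}));   // upper: new top dims

// top index 13 -> bottom index {1, 5}, since 13 == 1 * 8 + 5
const auto idx_bottom =
    adaptor.calculate_bottom_index(ck_tile::to_multi_index(ck_tile::sequence<13>{}));

chain_tensor_adaptors then lets such stages be composed: adaptor0's top ids are matched to adaptor1's bottom ids by the hidden-id shift logic above.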
// Macro function
// construct constexpr tensor_adaptor from constexpr encoding
// encoded_tensor_adaptor are Tuple of following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following:
// 1.1 name (coord_transform_enum)
// 1.2 meta data for constructor of the transform
// 1.3 num of lower dimension (index_t)
// 1.4 lower dimension Ids (array of fixed size)
// 1.5 num of up dimension (index_t)
// 1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimension (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimension (index_t)
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
[encoded_tensor_adaptor]() { \
using namespace ck_tile; \
\
constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \
constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \
constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
constexpr auto meta_data = encoded_transforms[i].template at<1>(); \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
\
static_assert(name == coord_transform_enum::pass_through || \
name == coord_transform_enum::pad || \
name == coord_transform_enum::embed || \
name == coord_transform_enum::merge || \
name == coord_transform_enum::unmerge || \
name == coord_transform_enum::replicate, \
""); \
\
if constexpr(name == coord_transform_enum::pass_through) \
{ \
index_t pos = 0; \
auto low_len = meta_data.template pop<index_t>(pos); \
\
return make_pass_through_transform(low_len); \
} \
else if constexpr(name == coord_transform_enum::pad) \
{ \
index_t pos = 0; \
auto low_len = meta_data.template pop<index_t>(pos); \
auto left_pad = meta_data.template pop<index_t>(pos); \
auto right_pad = meta_data.template pop<index_t>(pos); \
\
return make_pad_transform(low_len, left_pad, right_pad); \
} \
else if constexpr(name == coord_transform_enum::embed) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
auto coefficients = \
meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_embed_transform(up_lens, coefficients); \
} \
else if constexpr(name == coord_transform_enum::merge) \
{ \
index_t pos = 0; \
auto low_lens = meta_data.template pop<array<index_t, num_low_dim>>(pos); \
\
return make_merge_transform(low_lens); \
} \
else if constexpr(name == coord_transform_enum::unmerge) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_unmerge_transform(up_lens); \
} \
else if constexpr(name == coord_transform_enum::replicate) \
{ \
index_t pos = 0; \
auto up_lens = meta_data.template pop<array<index_t, num_up_dim>>(pos); \
\
return make_replicate_transform(up_lens); \
} \
}, \
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms, &num_transform]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto low_dims = encoded_transforms[i].template at<3>(); \
\
return TO_SEQUENCE(low_dims, num_low_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms, &num_transform] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
constexpr auto up_dims = encoded_transforms[i].template at<5>(); \
\
return TO_SEQUENCE(up_dims, num_up_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
\
return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
remove_cvref_t<decltype(low_dim_idss)>, \
remove_cvref_t<decltype(up_dim_idss)>, \
remove_cvref_t<decltype(bottom_dim_ids)>, \
remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
}()
// Macro function
// construct static tensor_adaptor from constexpr encoding
// encoded_tensor_adaptor are Tuple of following objects:
// 1. encoded transforms (array of fixed size). Each encoded transform is a Tuple of following:
// 1.1 name (coord_transform_enum)
// 1.2 meta data for constructor of the transform
// 1.3 num of lower dimension (index_t)
// 1.4 lower dimension Ids (array of fixed size)
// 1.5 num of up dimension (index_t)
// 1.6 upper dimension Ids (array of fixed size)
// 2. num of transforms (index_t)
// 3. encoded bottom dimension Ids (array of fixed size)
// 4. num of bottom dimension (index_t)
// 5. encoded top dimension Ids (array of fixed size)
// 6. num of top dimension (index_t)
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor) \
[encoded_tensor_adaptor]() { \
using namespace ck_tile; \
\
constexpr auto encoded_transforms = encoded_tensor_adaptor.template at<0>(); \
constexpr index_t num_transform = encoded_tensor_adaptor.template at<1>(); \
constexpr auto encoded_bottom_dims = encoded_tensor_adaptor.template at<2>(); \
constexpr index_t num_bottom_dim = encoded_tensor_adaptor.template at<3>(); \
constexpr auto encoded_top_dims = encoded_tensor_adaptor.template at<4>(); \
constexpr index_t num_top_dim = encoded_tensor_adaptor.template at<5>(); \
\
constexpr auto trans = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) constexpr { \
constexpr auto name = encoded_transforms[i].template at<0>(); \
constexpr auto meta_data = encoded_transforms[i].template at<1>(); \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
\
static_assert(name == coord_transform_enum::pass_through || \
name == coord_transform_enum::pad || \
name == coord_transform_enum::embed || \
name == coord_transform_enum::merge || \
name == coord_transform_enum::unmerge || \
name == coord_transform_enum::replicate, \
""); \
\
if constexpr(name == coord_transform_enum::pass_through) \
{ \
constexpr index_t low_len = meta_data.template get<index_t>(0); \
\
return make_pass_through_transform(number<low_len>{}); \
} \
else if constexpr(name == coord_transform_enum::pad) \
{ \
constexpr index_t low_len = meta_data.template get<index_t>(0); \
\
constexpr index_t left_pad = \
meta_data.template get<index_t>(sizeof(low_len)); \
\
constexpr index_t right_pad = \
meta_data.template pop<index_t>(sizeof(low_len) + sizeof(left_pad)); \
\
return make_pad_transform( \
number<low_len>{}, number<left_pad>{}, number<right_pad>{}); \
} \
else if constexpr(name == coord_transform_enum::embed) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
constexpr auto coefficients = \
meta_data.template get<array<index_t, num_up_dim>>(sizeof(up_lens)); \
\
return make_embed_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim), \
TO_TUPLE_OF_NUMBER(coefficients, num_up_dim)); \
} \
else if constexpr(name == coord_transform_enum::merge) \
{ \
constexpr auto low_lens = \
meta_data.template get<array<index_t, num_low_dim>>(0); \
\
return make_merge_transform(TO_TUPLE_OF_NUMBER(low_lens, num_low_dim)); \
} \
else if constexpr(name == coord_transform_enum::unmerge) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
return make_unmerge_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
} \
else if constexpr(name == coord_transform_enum::replicate) \
{ \
constexpr auto up_lens = \
meta_data.template get<array<index_t, num_up_dim>>(0); \
\
return make_replicate_transform(TO_TUPLE_OF_NUMBER(up_lens, num_up_dim)); \
} \
}, \
number<num_transform>{}); \
}(); \
\
constexpr auto low_dim_idss = [&encoded_transforms]() { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_low_dim = encoded_transforms[i].template at<2>(); \
constexpr auto low_dims = encoded_transforms[i].template at<3>(); \
\
return TO_SEQUENCE(low_dims, num_low_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto up_dim_idss = [&encoded_transforms] { \
return generate_tuple( \
[&encoded_transforms](auto i) { \
constexpr auto num_up_dim = encoded_transforms[i].template at<4>(); \
constexpr auto up_dims = encoded_transforms[i].template at<5>(); \
\
return TO_SEQUENCE(up_dims, num_up_dim); \
}, \
number<num_transform>()); \
}(); \
\
constexpr auto bottom_dim_ids = TO_SEQUENCE(encoded_bottom_dims, num_bottom_dim); \
constexpr auto top_dim_ids = TO_SEQUENCE(encoded_top_dims, num_top_dim); \
\
return tensor_adaptor<remove_cvref_t<decltype(trans)>, \
remove_cvref_t<decltype(low_dim_idss)>, \
remove_cvref_t<decltype(up_dim_idss)>, \
remove_cvref_t<decltype(bottom_dim_ids)>, \
remove_cvref_t<decltype(top_dim_ids)>>{trans}; \
}()
include/ck_tile/core/tensor/tensor_adaptor_coordinate.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
template
<
index_t
NDimHidden
,
typename
BottomDimensionHiddenIds
,
typename
TopDimensionHiddenIds
>
struct
tensor_adaptor_coordinate
{
static
constexpr
index_t
ndim_bottom_
=
BottomDimensionHiddenIds
::
size
();
static
constexpr
index_t
ndim_top_
=
TopDimensionHiddenIds
::
size
();
using
HiddenIndex
=
multi_index
<
NDimHidden
>
;
using
BottomIndex
=
multi_index
<
ndim_bottom_
>
;
using
TopIndex
=
multi_index
<
ndim_top_
>
;
public:
CK_TILE_HOST_DEVICE
constexpr
tensor_adaptor_coordinate
()
=
default
;
CK_TILE_HOST_DEVICE
constexpr
tensor_adaptor_coordinate
(
const
HiddenIndex
&
idx_hidden
)
:
idx_hidden_
{
idx_hidden
}
{
}
CK_TILE_HOST_DEVICE
constexpr
auto
get_top_index
()
const
{
return
get_container_subset
(
idx_hidden_
,
TopDimensionHiddenIds
{});
}
CK_TILE_HOST_DEVICE
constexpr
auto
get_bottom_index
()
const
{
return
get_container_subset
(
idx_hidden_
,
BottomDimensionHiddenIds
{});
}
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get_hidden_index
()
const
{
return
idx_hidden_
;
}
CK_TILE_HOST_DEVICE
constexpr
auto
&
get_hidden_index
()
{
return
idx_hidden_
;
}
//
HiddenIndex
idx_hidden_
;
};
template <typename Adaptor, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  const TopIndex& idx_top)
{
    static_assert(Adaptor::get_num_of_top_dimension() == TopIndex::size(),
                  "wrong! # of dimension inconsistent");

    constexpr index_t ntransform  = Adaptor::get_num_of_transform();
    constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
    constexpr auto bottom_dim_ids = Adaptor::get_bottom_dimension_hidden_ids();
    constexpr auto top_dim_ids    = Adaptor::get_top_dimension_hidden_ids();

    multi_index<ndim_hidden> idx_hidden;

    // initialize visible index
    set_container_subset(idx_hidden, top_dim_ids, idx_top);

    // calculate hidden index
    static_for<ntransform, 0, -1>{}([&adaptor, &idx_hidden](auto itran_p1) {
        auto itran              = itran_p1 - number<1>{};
        const auto& tran        = adaptor.get_transforms().at(itran);
        constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
        constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

        const auto idx_up = get_container_subset(idx_hidden, dims_up);

        multi_index<dims_low.size()> idx_low;

        tran.calculate_lower_index(idx_low, idx_up);

        set_container_subset(idx_hidden, dims_low, idx_low);
    });

    return tensor_adaptor_coordinate<ndim_hidden,
                                     remove_cvref_t<decltype(bottom_dim_ids)>,
                                     remove_cvref_t<decltype(top_dim_ids)>>{idx_hidden};
}
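// --- usage sketch (editor's note, not part of this commit) ---
// A minimal, hypothetical illustration of the function above: build a
// single-stage unmerge adaptor mapping one bottom offset dimension to a
// packed (8, 4) top index, then construct a coordinate from a top index.
// It assumes the make_single_stage_tensor_adaptor helper from
// tensor_adaptor.hpp and the make_multi_index helper from multi_index.hpp;
// names and values are illustrative only.
//
//   const auto adaptor = make_single_stage_tensor_adaptor(
//       make_tuple(make_unmerge_transform(make_tuple(number<8>{}, number<4>{}))),
//       make_tuple(sequence<0>{}),     // lower: the single bottom dimension
//       make_tuple(sequence<0, 1>{})); // upper: the two top dimensions
//
//   auto coord = make_tensor_adaptor_coordinate(adaptor, make_multi_index(3, 1));
//   // packed (8, 4): bottom offset of top index (3, 1) is 3 * 4 + 1 == 13
//   // coord.get_bottom_index()[number<0>{}] == 13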
template <bool JudgeDoTransforms = true,
          typename Adaptor,
          typename AdaptorCoord,
          typename TopIndex,
          typename BottomIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  AdaptorCoord& coord,
                                                                  const TopIndex& idx_diff_top,
                                                                  BottomIndex& idx_diff_bottom)
{
    constexpr index_t ndim_hidden = Adaptor::get_num_of_hidden_dimension();
    constexpr index_t ndim_top    = Adaptor::get_num_of_top_dimension();
    // constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();
    constexpr index_t ntransform = Adaptor::get_num_of_transform();

    // static_assert(TopIndex::size() == ndim_top && BottomIndex::size() == ndim_bottom, "");

    // judge whether calculation of lower diff is needed for each transform
    // use index_t for boolean type
    auto do_transforms = make_zero_multi_index<ntransform>();

    if constexpr(JudgeDoTransforms)
    {
        auto is_non_zero_diff = make_zero_multi_index<ndim_hidden>();

        // decide do_transforms by checking for non-zero index diff components
        multi_index<ndim_top> non_zero_diff_pick_top;

        static_for<0, ndim_top, 1>{}(
            [&](auto i) { non_zero_diff_pick_top(i) = (idx_diff_top[i] != 0); });

        set_container_subset(
            is_non_zero_diff, Adaptor::get_top_dimension_hidden_ids(), non_zero_diff_pick_top);

        static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
            constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

            const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up);

            multi_index<dims_low.size()> non_zero_diff_pick_low;

            // if any of the upper index diff components is non-zero, then
            //   1) this transform needs to be done
            //   2) all components of the lower index diff are assumed to be non-zero and need
            //      to be computed
            const bool idx_diff_up_has_non_zero = container_reduce(
                non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false);

            do_transforms(itran) = idx_diff_up_has_non_zero;

            static_for<0, dims_low.size(), 1>{}(
                [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; });

            set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low);
        });
    }
    else
    {
        static_for<ntransform - 1, -1, -1>{}([&](auto itran) { do_transforms(itran) = 1; });
    }

    // this is what needs to be calculated
    auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();

    // initialize top index diff
    set_container_subset(idx_diff_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_diff_top);

    // this is what needs to be updated
    auto& idx_hidden = coord.get_hidden_index();

    // update top index
    auto idx_hidden_pick_top =
        get_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids());

    idx_hidden_pick_top += idx_diff_top;

    set_container_subset(idx_hidden, Adaptor::get_top_dimension_hidden_ids(), idx_hidden_pick_top);

    // update rest of hidden index
    static_for<ntransform - 1, -1, -1>{}([&](auto itran) {
        if(do_transforms[itran])
        {
            const auto& tran        = adaptor.get_transforms().at(itran);
            constexpr auto dims_low = Adaptor::get_lower_dimension_hidden_idss().at(itran);
            constexpr auto dims_up  = Adaptor::get_upper_dimension_hidden_idss().at(itran);

            const auto idx_up_new  = get_container_subset(idx_hidden, dims_up);
            auto idx_low           = get_container_subset(idx_hidden, dims_low);
            const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up);

            multi_index<dims_low.size()> idx_diff_low;

            tran.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up_new);

            set_container_subset(idx_diff_hidden, dims_low, idx_diff_low);
            set_container_subset(idx_hidden, dims_low, idx_low);
        }
    });

    // set bottom index diff
    idx_diff_bottom =
        get_container_subset(idx_diff_hidden, Adaptor::get_bottom_dimension_hidden_ids());
}
template <bool JudgeDoTransforms = true,
          typename Adaptor,
          typename AdaptorCoord,
          typename TopIndex>
CK_TILE_HOST_DEVICE constexpr void move_tensor_adaptor_coordinate(const Adaptor& adaptor,
                                                                  AdaptorCoord& coord,
                                                                  const TopIndex& idx_diff_top)
{
    constexpr index_t ndim_bottom = Adaptor::get_num_of_bottom_dimension();

    multi_index<ndim_bottom> tmp;

    move_tensor_adaptor_coordinate<JudgeDoTransforms>(adaptor, coord, idx_diff_top, tmp);
}
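// --- usage sketch (editor's note, not part of this commit) ---
// Hypothetical continuation of the unmerge-adaptor sketch above: the
// four-argument overload additionally reports how the bottom index moved.
//
//   multi_index<1> idx_diff_bottom;
//   move_tensor_adaptor_coordinate(adaptor, coord, make_multi_index(0, 2), idx_diff_bottom);
//   // top index (3, 1) -> (3, 3); bottom offset 13 -> 15, so
//   // idx_diff_bottom[number<0>{}] == 2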
template <typename Adaptor, typename AdaptorCoord>
CK_TILE_HOST_DEVICE constexpr bool
adaptor_coordinate_is_valid_assuming_top_index_is_valid(const Adaptor& adaptor,
                                                        const AdaptorCoord& coord)
{
    bool valid = true;

    constexpr index_t ntransform = Adaptor::get_num_of_transform();

    const auto& idx_hidden = coord.get_hidden_index();

    static_for<ntransform - 1, -1, -1>{}([&adaptor, &idx_hidden, &valid](auto itran) {
        const auto tran = adaptor.get_transforms().at(itran);

        // check validity, only if the current transformation does not always have a valid mapping
        if constexpr(!decltype(tran)::is_valid_upper_index_always_mapped_to_valid_lower_index())
        {
            const auto idx_up = get_container_subset(
                idx_hidden, Adaptor::get_upper_dimension_hidden_idss().at(itran));

            // Comment: using valid = valid && ... will result in weird control flow in ISA
            valid &= tran.is_valid_upper_index_mapped_to_valid_lower_index(idx_up);
        }
    });

    return valid;
}
template <typename Adaptor, typename AdaptorCoord>
CK_TILE_HOST_DEVICE constexpr bool adaptor_coordinate_is_valid(const Adaptor& adaptor,
                                                               const AdaptorCoord& coord)
{
    // check top index
    const auto& idx_top = coord.get_top_index();

    bool is_top_index_valid = true;

    static_for<0, Adaptor::get_num_of_dimension(), 1>{}(
        [&is_top_index_valid, &idx_top, &adaptor](auto i) {
            is_top_index_valid =
                is_top_index_valid && (idx_top[i] >= 0 && idx_top[i] < adaptor.get_length(i));
        });

    // check other hidden index
    return is_top_index_valid &&
           adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);
}

} // namespace ck_tile
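// --- usage sketch (editor's note, not part of this commit) ---
// Hypothetical illustration of the two validity checks above. The full check
// also bounds-checks the top index against the adaptor's lengths, while the
// "assuming top index is valid" variant only interrogates transforms that can
// map a valid upper index to an invalid lower one (e.g. a padding transform).
//
//   const bool ok      = adaptor_coordinate_is_valid(adaptor, coord);
//   const bool ok_fast = adaptor_coordinate_is_valid_assuming_top_index_is_valid(adaptor, coord);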
include/ck_tile/core/tensor/tensor_coordinate.hpp
0 → 100644
View file @
300337cd
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <index_t NDimHidden, typename TopDimensionHiddenIds>
struct tensor_coordinate
    : public tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>
{
    using Base = tensor_adaptor_coordinate<NDimHidden, sequence<0>, TopDimensionHiddenIds>;

    // TODO make these private
    static constexpr index_t ndim_top_ = TopDimensionHiddenIds::size();

    using HiddenIndex = multi_index<NDimHidden>;
    using TopIndex    = multi_index<ndim_top_>;

    public:
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate() = default;

    CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const HiddenIndex& idx_hidden)
        : Base{idx_hidden}
    {
    }

    // construct from tensor_adaptor_coordinate base class
    CK_TILE_HOST_DEVICE constexpr tensor_coordinate(const Base& adaptor_coord)
        : Base{adaptor_coord}
    {
    }

    CK_TILE_HOST_DEVICE constexpr auto get_index() const { return Base::get_top_index(); }

    CK_TILE_HOST_DEVICE constexpr index_t get_offset() const
    {
        return Base::get_bottom_index()[number<0>{}];
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_hidden_index() const
    {
        return Base::get_hidden_index();
    }

    CK_TILE_HOST_DEVICE auto& get_hidden_index() { return Base::get_hidden_index(); }
};
template <typename TensorDesc, typename TopIndex>
CK_TILE_HOST_DEVICE constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
                                                          const TopIndex& idx_top)
{
    const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top);

    return tensor_coordinate<TensorDesc::get_num_of_hidden_dimension(),
                             remove_cvref_t<decltype(TensorDesc::get_top_dimension_hidden_ids())>>{
        adaptor_coord};
}
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
CK_TILE_HOST_DEVICE constexpr void
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
{
    move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
}
template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
                                                        const TensorCoord& coord)
{
    return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
}
template <typename TensorDesc, typename TensorCoord>
CK_TILE_HOST_DEVICE constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
                                                               const TensorCoord& coord)
{
    return adaptor_coordinate_is_valid(tensor_desc, coord);
}

} // namespace ck_tile
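// --- usage sketch (editor's note, not part of this commit) ---
// A minimal, hypothetical example of tensor_coordinate on a packed 2-d
// descriptor. It assumes a make_naive_tensor_descriptor_packed helper
// (tensor_descriptor.hpp elsewhere in this commit); values follow row-major
// packing and are illustrative only.
//
//   const auto desc =
//       make_naive_tensor_descriptor_packed(make_tuple(number<8>{}, number<4>{}));
//
//   auto coord = make_tensor_coordinate(desc, make_multi_index(3, 1));
//   // coord.get_offset() == 3 * 4 + 1 == 13
//
//   move_tensor_coordinate(desc, coord, make_multi_index(1, 0));
//   // coord.get_offset() == 13 + 4 == 17
//
//   const bool ok = coordinate_has_valid_offset(desc, coord); // true while in bounds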