Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7971bb5b
Commit
7971bb5b
authored
Aug 04, 2024
by
carlushuang
Browse files
add test for scatter/gather
parent
d311c953
Changes
12
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
2688 additions
and
44 deletions
+2688
-44
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+1
-0
include/ck_tile/core/algorithm/coordinate_transform.hpp
include/ck_tile/core/algorithm/coordinate_transform.hpp
+200
-0
include/ck_tile/core/algorithm/indexing_adaptor.hpp
include/ck_tile/core/algorithm/indexing_adaptor.hpp
+60
-0
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+9
-0
include/ck_tile/core/numeric/math.hpp
include/ck_tile/core/numeric/math.hpp
+951
-28
include/ck_tile/core/tensor/tile_distribution.hpp
include/ck_tile/core/tensor/tile_distribution.hpp
+33
-16
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+1
-0
include/ck_tile/ops/elementwise.hpp
include/ck_tile/ops/elementwise.hpp
+7
-0
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
.../ck_tile/ops/elementwise/unary_element_wise_operation.hpp
+1151
-0
test/CMakeLists.txt
test/CMakeLists.txt
+2
-0
test/scatter_gather/CMakeLists.txt
test/scatter_gather/CMakeLists.txt
+2
-0
test/scatter_gather/scatter_gather.cpp
test/scatter_gather/scatter_gather.cpp
+271
-0
No files found.
include/ck_tile/core.hpp
View file @
7971bb5b
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/arch.hpp"
...
...
include/ck_tile/core/algorithm/coordinate_transform.hpp
View file @
7971bb5b
...
@@ -23,6 +23,7 @@ enum struct coord_transform_enum
...
@@ -23,6 +23,7 @@ enum struct coord_transform_enum
replicate
,
replicate
,
xor_t
,
xor_t
,
offset
,
offset
,
indexing
,
};
};
template
<
index_t
NDimLow
,
index_t
NDimUp
>
template
<
index_t
NDimLow
,
index_t
NDimUp
>
...
@@ -1549,6 +1550,184 @@ struct offset : public base_transform<1, 1>
...
@@ -1549,6 +1550,184 @@ struct offset : public base_transform<1, 1>
}
}
};
};
#if 0
template <typename UpLength,
typename Index>
struct indexing : public base_transform<1, 1>
{
static constexpr index_t NDimUp = 1;
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
using Indices = decltype(make_tuple(Index{}));
UpLengths up_lengths_;
Indices indices_;
CK_TILE_HOST_DEVICE constexpr indexing() = default;
CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
const Index& index)
: up_lengths_{make_tuple(up_length)}, indices_{make_tuple(indices)}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::indexing;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& /*idx_up*/) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = indices_[number<0>{}];
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& /*idx_diff_up*/,
LowIdx& /*idx_low*/,
const UpIdx& /*idx_up*/) const
{
// TODO: nonthing changed here
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_diff_low(number<0>{}) = 0;
//static_for<0, NDimUp, 1>{}(
// [&](auto i) { idx_diff_low(number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
// idx_low += idx_up;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<Indices>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("indices_: ");
print(indices_);
printf("}");
}
};
#endif
template
<
typename
UpLength
,
typename
IndexingAdaptor
>
struct
indexing
:
public
base_transform
<
1
,
1
>
{
static
constexpr
index_t
NDimUp
=
1
;
using
LowerIndex
=
multi_index
<
1
>
;
using
UpperIndex
=
multi_index
<
1
>
;
using
UpLengths
=
decltype
(
make_tuple
(
UpLength
{}));
UpLengths
up_lengths_
;
IndexingAdaptor
iadaptor_
;
CK_TILE_HOST_DEVICE
constexpr
indexing
()
=
default
;
CK_TILE_HOST_DEVICE
constexpr
indexing
(
const
UpLength
&
up_length
,
const
IndexingAdaptor
&
iadaptor
)
:
up_lengths_
{
make_tuple
(
up_length
)},
iadaptor_
{
iadaptor
}
{
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_type_enum
()
{
return
coord_transform_enum
::
indexing
;
}
CK_TILE_HOST_DEVICE
constexpr
const
auto
&
get_upper_lengths
()
const
{
return
up_lengths_
;
}
template
<
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
calculate_lower_index
(
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
)
const
{
static_assert
(
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
NDimUp
,
"wrong! inconsistent # of dimension"
);
iadaptor_
.
calculate_lower_index
(
idx_low
,
idx_up
);
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
void
update_lower_index
(
LowIdxDiff
&
idx_diff_low
,
const
UpIdxDiff
&
idx_diff_up
,
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
)
const
{
// TODO: nonthing changed here
static_assert
(
LowIdxDiff
::
size
()
==
1
&&
UpIdxDiff
::
size
()
==
NDimUp
&&
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
NDimUp
,
"wrong! inconsistent # of dimension"
);
iadaptor_
.
update_lower_index
(
idx_diff_low
,
idx_diff_up
,
idx_low
,
idx_up
);
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_valid_upper_index_always_mapped_to_valid_lower_index
()
{
return
true
;
}
template
<
typename
UpIdx
>
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_valid_upper_index_mapped_to_valid_lower_index
(
const
UpIdx
&
/* idx_up */
)
{
return
true
;
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_known_at_compile_time
()
{
return
ck_tile
::
is_known_at_compile_time
<
UpLengths
>::
value
&&
IndexingAdaptor
::
is_known_at_compile_time
();
}
CK_TILE_HOST_DEVICE
void
print
()
const
{
printf
(
"embed{"
);
//
printf
(
"up_lengths_: "
);
print
(
up_lengths_
);
printf
(
", "
);
printf
(
"}"
);
}
};
//*******************************************************************************************************
//*******************************************************************************************************
template
<
typename
LowLength
>
template
<
typename
LowLength
>
...
@@ -1670,3 +1849,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
...
@@ -1670,3 +1849,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
}
}
}
// namespace ck_tile
}
// namespace ck_tile
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
namespace
ck_tile
{
template
<
typename
UpLength
,
typename
Indices
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_indexing_transform
(
const
UpLength
&
up_lengths
,
const
Indices
&
indices
)
{
// by default we use the simplest one
return
indexing
<
UpLength
,
indexing_adaptor_onshot_cached
<
remove_cvref_t
<
Indices
>>>
{
up_lengths
,
indexing_adaptor_onshot_cached
<
remove_cvref_t
<
Indices
>>
{
indices
}};
}
template
<
typename
UpLength
,
typename
IndexingAdaptor
>
CK_TILE_HOST_DEVICE
constexpr
auto
make_indexing_transform_with_adaptor
(
const
UpLength
&
up_lengths
,
const
IndexingAdaptor
&
iadaptor
)
{
return
indexing
<
UpLength
,
IndexingAdaptor
>
{
up_lengths
,
iadaptor
};
}
}
// namespace ck_tile
include/ck_tile/core/algorithm/indexing_adaptor.hpp
0 → 100644
View file @
7971bb5b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
// pre-defined indexing adaptor used for indexing(scatter/gather)
// this version cache the index inside thread register(which is also prefered in real senario)
// however it's user's responsibility that each thread only provide one indexing, which means
// move coordinate will not change on this dim
template
<
typename
IndexingType
>
struct
indexing_adaptor_onshot_cached
{
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor_onshot_cached
()
=
default
;
CK_TILE_HOST_DEVICE
constexpr
indexing_adaptor_onshot_cached
(
const
IndexingType
&
idx
)
:
cached_idx_
(
idx
)
{
}
IndexingType
cached_idx_
;
template
<
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
calculate_lower_index
(
LowIdx
&
idx_low
,
const
UpIdx
&
/*idx_up*/
)
const
{
static_assert
(
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_low
(
number
<
0
>
{})
=
cached_idx_
;
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
>
CK_TILE_HOST_DEVICE
void
update_lower_index
(
LowIdxDiff
&
idx_diff_low
,
const
UpIdxDiff
&
idx_diff_up
,
LowIdx
&
/*idx_low*/
,
const
UpIdx
&
/*idx_up*/
)
const
{
// TODO: nonthing changed here
static_assert
(
LowIdxDiff
::
size
()
==
1
&&
UpIdxDiff
::
size
()
==
1
&&
LowIdx
::
size
()
==
1
&&
UpIdx
::
size
()
==
1
,
"wrong! inconsistent # of dimension"
);
idx_diff_low
(
number
<
0
>
{})
=
idx_diff_up
[
number
<
0
>
{}];
// pass the diff to lower, but not changing the actually index
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
is_known_at_compile_time
()
{
return
ck_tile
::
is_known_at_compile_time
<
IndexingType
>::
value
;
}
};
}
// namespace ck_tile
include/ck_tile/core/config.hpp
View file @
7971bb5b
...
@@ -31,12 +31,16 @@
...
@@ -31,12 +31,16 @@
#define CK_TILE_HOST inline __host__
#define CK_TILE_HOST inline __host__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_HOST_EXTERN __host__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else
#else
#define CK_TILE_HOST inline
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_HOST_EXTERN
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
...
@@ -191,3 +195,8 @@
...
@@ -191,3 +195,8 @@
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif
#endif
// workaround: compiler not emiting reciprocal instruction frm __frcp_rn()
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif
include/ck_tile/core/numeric/math.hpp
View file @
7971bb5b
...
@@ -41,9 +41,8 @@ struct scales
...
@@ -41,9 +41,8 @@ struct scales
Scale
lhs_
;
Scale
lhs_
;
};
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template
<
typename
Scale
>
template
<
typename
Scale
>
__host__
__device__
scales
(
Scale
)
->
scales
<
Scale
>
;
CK_TILE_HOST_DEVICE_EXTERN
scales
(
Scale
)
->
scales
<
Scale
>
;
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
struct
plus
struct
plus
...
@@ -66,8 +65,7 @@ struct plus<void, void>
...
@@ -66,8 +65,7 @@ struct plus<void, void>
}
}
};
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
CK_TILE_HOST_DEVICE_EXTERN
plus
()
->
plus
<
void
,
void
>
;
__host__
__device__
plus
()
->
plus
<
void
,
void
>
;
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
struct
minus
struct
minus
...
@@ -90,8 +88,7 @@ struct minus<void, void>
...
@@ -90,8 +88,7 @@ struct minus<void, void>
}
}
};
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
CK_TILE_HOST_DEVICE_EXTERN
minus
()
->
minus
<
void
,
void
>
;
__host__
__device__
minus
()
->
minus
<
void
,
void
>
;
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
struct
multiplies
struct
multiplies
...
@@ -114,8 +111,7 @@ struct multiplies<void, void>
...
@@ -114,8 +111,7 @@ struct multiplies<void, void>
}
}
};
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
CK_TILE_HOST_DEVICE_EXTERN
multiplies
()
->
multiplies
<
void
,
void
>
;
__host__
__device__
multiplies
()
->
multiplies
<
void
,
void
>
;
template
<
typename
T
>
template
<
typename
T
>
struct
maximize
struct
maximize
...
@@ -345,8 +341,7 @@ struct equal<void, void>
...
@@ -345,8 +341,7 @@ struct equal<void, void>
}
}
};
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
CK_TILE_HOST_DEVICE_EXTERN
equal
()
->
equal
<
void
,
void
>
;
__host__
__device__
equal
()
->
equal
<
void
,
void
>
;
template
<
>
template
<
>
struct
equal
<
float
,
float
>
struct
equal
<
float
,
float
>
...
@@ -387,8 +382,7 @@ struct less<void, void>
...
@@ -387,8 +382,7 @@ struct less<void, void>
}
}
};
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
CK_TILE_HOST_DEVICE_EXTERN
less
()
->
less
<
void
,
void
>
;
__host__
__device__
less
()
->
less
<
void
,
void
>
;
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
template
<
typename
Left
=
void
,
typename
Right
=
Left
>
struct
less_equal
struct
less_equal
...
@@ -411,8 +405,7 @@ struct less_equal<void, void>
...
@@ -411,8 +405,7 @@ struct less_equal<void, void>
}
}
};
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
CK_TILE_HOST_DEVICE_EXTERN
less_equal
()
->
less_equal
<
void
,
void
>
;
__host__
__device__
less_equal
()
->
less_equal
<
void
,
void
>
;
template
<
>
template
<
>
struct
less_equal
<
float
,
float
>
struct
less_equal
<
float
,
float
>
...
@@ -488,19 +481,19 @@ template <typename T = double>
...
@@ -488,19 +481,19 @@ template <typename T = double>
constexpr
T
log2e_v
=
log2e
<
T
>::
value
;
constexpr
T
log2e_v
=
log2e
<
T
>::
value
;
// math
// math
CK_TILE_HOST_DEVICE
//
CK_TILE_HOST_DEVICE
float
abs
(
const
float
&
x
)
//
float abs(const float& x)
{
//
{
union
//
union
{
//
{
float
f32
;
//
float f32;
uint32_t
u32
;
//
uint32_t u32;
}
y
;
//
} y;
y
.
f32
=
x
;
//
y.f32 = x;
y
.
u32
=
y
.
u32
&
0x7fffffff
;
//
y.u32 = y.u32 & 0x7fffffff;
return
y
.
f32
;
//
return y.f32;
}
//
}
#if 0
CK_TILE_HOST_DEVICE
CK_TILE_HOST_DEVICE
bool isnan(const float& x)
bool isnan(const float& x)
{
{
...
@@ -523,18 +516,20 @@ float exp(float x) { return __ocml_exp_f32(x); };
...
@@ -523,18 +516,20 @@ float exp(float x) { return __ocml_exp_f32(x); };
CK_TILE_HOST
CK_TILE_HOST
float exp(float x) { return std::expf(x); }
float exp(float x) { return std::expf(x); }
#endif
CK_TILE_DEVICE
CK_TILE_DEVICE
float
exp2
(
float
x
)
{
return
exp2f
(
x
);
};
float
exp2
(
float
x
)
{
return
exp2f
(
x
);
};
CK_TILE_HOST
CK_TILE_HOST
float
exp2
(
float
x
)
{
return
std
::
exp2f
(
x
);
};
float
exp2
(
float
x
)
{
return
std
::
exp2f
(
x
);
};
#if 0
CK_TILE_DEVICE
CK_TILE_DEVICE
float log(float x) { return __logf(x); };
float log(float x) { return __logf(x); };
CK_TILE_HOST
CK_TILE_HOST
float log(float x) { return std::logf(x); };
float log(float x) { return std::logf(x); };
#endif
CK_TILE_DEVICE
uint32_t
sad
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
CK_TILE_DEVICE
uint32_t
sad
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
{
{
...
@@ -547,4 +542,932 @@ CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
...
@@ -547,4 +542,932 @@ CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
return
(
x
>
y
?
(
x
-
y
)
:
(
y
-
x
))
+
acc
;
return
(
x
>
y
?
(
x
-
y
)
:
(
y
-
x
))
+
acc
;
}
}
///////////////////////////////////////////////////////////////
}
// namespace ck_tile
// blow function need data type pre-defined
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
namespace
ck_tile
{
#if CK_TILE_WORKAROUND_SWDEV_383542
extern
"C"
CK_TILE_DEVICE
float
__ocml_native_recip_f32
(
float
);
#endif
// math functions for the host, some are implemented by calling C++ std functions
CK_TILE_HOST
float
abs
(
float
x
)
{
return
std
::
abs
(
x
);
};
CK_TILE_HOST
double
abs
(
double
x
)
{
return
std
::
abs
(
x
);
};
CK_TILE_HOST
int8_t
abs
(
int8_t
x
)
{
int8_t
sgn
=
x
>>
(
8
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
CK_TILE_HOST
int32_t
abs
(
int32_t
x
)
{
int32_t
sgn
=
x
>>
(
32
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
CK_TILE_HOST
fp16_t
abs
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
uint16_t
abs_xx
=
xx
&
0x7fff
;
fp16_t
abs_x
=
bit_cast
<
fp16_t
>
(
abs_xx
);
return
abs_x
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST
int4_t
abs
(
int4_t
x
)
{
int4_t
sgn
=
x
>>
(
4
-
1
);
return
(
x
^
sgn
)
-
sgn
;
}
#endif
CK_TILE_HOST
bool
isnan
(
float
x
)
{
return
std
::
isnan
(
x
);
};
CK_TILE_HOST
bool
isnan
(
double
x
)
{
return
std
::
isnan
(
x
);
};
CK_TILE_HOST
bool
isnan
(
int8_t
x
)
{
(
void
)
x
;
return
false
;
};
CK_TILE_HOST
bool
isnan
(
int32_t
x
)
{
(
void
)
x
;
return
false
;
};
CK_TILE_HOST
bool
isnan
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST
bool
isnan
(
int4_t
x
)
{
(
void
)
x
;
return
false
;
};
#endif
CK_TILE_HOST
fp16_t
sqrt
(
fp16_t
x
)
{
return
static_cast
<
fp16_t
>
(
std
::
sqrt
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_HOST
float
sqrt
(
float
x
)
{
return
std
::
sqrt
(
x
);
};
CK_TILE_HOST
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
tanh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
tanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
tanh
<
float
>
(
float
x
)
{
return
std
::
tanhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
tanh
<
double
>
(
double
x
)
{
return
std
::
tanh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
acos
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
acosf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
acos
<
float
>
(
float
x
)
{
return
std
::
acosf
(
x
);
};
template
<
>
CK_TILE_HOST
double
acos
<
double
>
(
double
x
)
{
return
std
::
acos
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
neg
(
T
x
)
{
return
type_convert
<
T
>
(
-
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
neg
<
float
>
(
float
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_HOST
double
neg
<
double
>
(
double
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_HOST
int32_t
neg
<
int32_t
>
(
int32_t
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_HOST
int8_t
neg
<
int8_t
>
(
int8_t
x
)
{
return
-
x
;
};
template
<
typename
T
>
CK_TILE_HOST
T
atan
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
atanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
atan
<
float
>
(
float
x
)
{
return
std
::
atanf
(
x
);
};
template
<
>
CK_TILE_HOST
double
atan
<
double
>
(
double
x
)
{
return
std
::
atan
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
sin
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
sinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
sin
<
float
>
(
float
x
)
{
return
std
::
sinf
(
x
);
};
template
<
>
CK_TILE_HOST
double
sin
<
double
>
(
double
x
)
{
return
std
::
sin
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
asin
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
asinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
asin
<
float
>
(
float
x
)
{
return
std
::
asinf
(
x
);
};
template
<
>
CK_TILE_HOST
double
asin
<
double
>
(
double
x
)
{
return
std
::
asin
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
asinh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
asinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
asinh
<
float
>
(
float
x
)
{
return
std
::
asinhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
asinh
<
double
>
(
double
x
)
{
return
std
::
asinh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
cos
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
cosf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
cos
<
float
>
(
float
x
)
{
return
std
::
cosf
(
x
);
};
template
<
>
CK_TILE_HOST
double
cos
<
double
>
(
double
x
)
{
return
std
::
cos
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
acosh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
acoshf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
acosh
<
float
>
(
float
x
)
{
return
std
::
acoshf
(
x
);
};
template
<
>
CK_TILE_HOST
double
acosh
<
double
>
(
double
x
)
{
return
std
::
acosh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
tan
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
tanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
tan
<
float
>
(
float
x
)
{
return
std
::
tanf
(
x
);
};
template
<
>
CK_TILE_HOST
double
tan
<
double
>
(
double
x
)
{
return
std
::
tan
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
atanh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
atanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
atanh
<
float
>
(
float
x
)
{
return
std
::
atanhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
atanh
<
double
>
(
double
x
)
{
return
std
::
atanh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
sinh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
sinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
sinh
<
float
>
(
float
x
)
{
return
std
::
sinhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
sinh
<
double
>
(
double
x
)
{
return
std
::
sinh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
ceil
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
ceilf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
ceil
<
float
>
(
float
x
)
{
return
std
::
ceilf
(
x
);
};
template
<
>
CK_TILE_HOST
double
ceil
<
double
>
(
double
x
)
{
return
std
::
ceil
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
cosh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
coshf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
cosh
<
float
>
(
float
x
)
{
return
std
::
coshf
(
x
);
};
template
<
>
CK_TILE_HOST
double
cosh
<
double
>
(
double
x
)
{
return
std
::
cosh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
floor
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
floorf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
floor
<
float
>
(
float
x
)
{
return
std
::
floorf
(
x
);
};
template
<
>
CK_TILE_HOST
double
floor
<
double
>
(
double
x
)
{
return
std
::
floor
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
rcp
(
T
x
)
{
return
type_convert
<
T
>
(
1.
f
/
type_convert
<
float
>
(
x
));
};
template
<
typename
T
>
CK_TILE_HOST
T
exp
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
expf
(
type_convert
<
float
>
(
x
)));
}
template
<
>
CK_TILE_HOST
float
exp
<
float
>
(
float
x
)
{
return
std
::
expf
(
x
);
}
template
<
>
CK_TILE_HOST
double
exp
<
double
>
(
double
x
)
{
return
std
::
exp
(
x
);
}
template
<
typename
T
>
CK_TILE_HOST
T
log
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
logf
(
type_convert
<
float
>
(
x
)));
}
template
<
>
CK_TILE_HOST
float
log
<
float
>
(
float
x
)
{
return
std
::
logf
(
x
);
}
template
<
>
CK_TILE_HOST
double
log
<
double
>
(
double
x
)
{
return
std
::
log
(
x
);
}
template
<
typename
T
>
CK_TILE_HOST
T
pow
(
T
x
,
T
gamma
)
{
return
type_convert
<
T
>
(
std
::
powf
(
type_convert
<
float
>
(
x
),
type_convert
<
float
>
(
gamma
)));
}
template
<
>
CK_TILE_HOST
float
pow
<
float
>
(
float
x
,
float
gamma
)
{
return
std
::
powf
(
x
,
gamma
);
}
template
<
>
CK_TILE_HOST
double
pow
<
double
>
(
double
x
,
double
gamma
)
{
return
std
::
pow
(
x
,
gamma
);
}
template
<
typename
T
>
CK_TILE_HOST
T
expm1
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
expm1f
(
type_convert
<
float
>
(
x
)));
}
template
<
>
CK_TILE_HOST
float
expm1
<
float
>
(
float
x
)
{
return
std
::
expm1f
(
x
);
}
template
<
>
CK_TILE_HOST
double
expm1
<
double
>
(
double
x
)
{
return
std
::
expm1
(
x
);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
CK_TILE_DEVICE
float
abs
(
float
x
)
{
union
{
float
f32
;
uint32_t
u32
;
}
y
;
y
.
f32
=
x
;
y
.
u32
=
y
.
u32
&
0x7fffffff
;
return
y
.
f32
;
};
CK_TILE_DEVICE
double
abs
(
double
x
)
{
return
::
abs
(
x
);
};
CK_TILE_DEVICE
int8_t
abs
(
int8_t
x
)
{
int8_t
sgn
=
x
>>
(
8
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
CK_TILE_DEVICE
int32_t
abs
(
int32_t
x
)
{
int32_t
sgn
=
x
>>
(
32
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE
int4_t
abs
(
int4_t
x
)
{
int4_t
sgn
=
x
>>
(
4
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
#endif
CK_TILE_DEVICE
fp16_t
abs
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
uint16_t
abs_xx
=
xx
&
0x7fff
;
fp16_t
abs_x
=
bit_cast
<
fp16_t
>
(
abs_xx
);
return
abs_x
;
};
CK_TILE_DEVICE
bool
isnan
(
float
x
)
{
return
::
isnan
(
x
);
};
CK_TILE_DEVICE
bool
isnan
(
double
x
)
{
return
::
isnan
(
x
);
};
CK_TILE_DEVICE
bool
isnan
(
int8_t
x
)
{
(
void
)
x
;
return
false
;
};
CK_TILE_DEVICE
bool
isnan
(
int32_t
x
)
{
(
void
)
x
;
return
false
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE
bool
isnan
(
int4_t
x
)
{
(
void
)
x
;
return
false
;
};
#endif
CK_TILE_DEVICE
bool
isnan
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
CK_TILE_DEVICE
fp16_t
sqrt
(
fp16_t
x
)
{
return
static_cast
<
fp16_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_DEVICE
float
sqrt
(
float
x
)
{
return
__builtin_amdgcn_sqrtf
(
x
);
};
CK_TILE_DEVICE
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
tanh
(
T
x
)
{
return
type_convert
<
T
>
(
::
tanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
tanh
<
float
>
(
float
x
)
{
return
::
tanhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
tanh
<
double
>
(
double
x
)
{
return
::
tanh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
acos
(
T
x
)
{
return
type_convert
<
T
>
(
::
acosf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
acos
<
float
>
(
float
x
)
{
return
::
acosf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
acos
<
double
>
(
double
x
)
{
return
::
acos
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
neg
(
T
x
)
{
return
type_convert
<
T
>
(
-
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
neg
<
float
>
(
float
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
double
neg
<
double
>
(
double
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
int32_t
neg
<
int32_t
>
(
int32_t
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
int8_t
neg
<
int8_t
>
(
int8_t
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
fp16_t
neg
<
fp16_t
>
(
fp16_t
x
)
{
return
__hneg
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
atan
(
T
x
)
{
return
type_convert
<
T
>
(
::
atanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
atan
<
float
>
(
float
x
)
{
return
::
atanf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
atan
<
double
>
(
double
x
)
{
return
::
atan
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
sin
(
T
x
)
{
return
type_convert
<
T
>
(
::
sinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
sin
<
float
>
(
float
x
)
{
return
::
sinf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
sin
<
double
>
(
double
x
)
{
return
::
sin
(
x
);
};
template
<
>
CK_TILE_DEVICE
fp16_t
sin
<
fp16_t
>
(
fp16_t
x
)
{
return
::
hsin
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
asin
(
T
x
)
{
return
type_convert
<
T
>
(
::
asinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
asin
<
float
>
(
float
x
)
{
return
::
asinf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
asin
<
double
>
(
double
x
)
{
return
::
asin
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
asinh
(
T
x
)
{
return
type_convert
<
T
>
(
::
asinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
asinh
<
float
>
(
float
x
)
{
return
::
asinhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
asinh
<
double
>
(
double
x
)
{
return
::
asinh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
acosh
(
T
x
)
{
return
type_convert
<
T
>
(
::
acoshf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
acosh
<
float
>
(
float
x
)
{
return
::
acoshf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
acosh
<
double
>
(
double
x
)
{
return
::
acosh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
tan
(
T
x
)
{
return
type_convert
<
T
>
(
::
tanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
tan
<
float
>
(
float
x
)
{
return
::
tanf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
tan
<
double
>
(
double
x
)
{
return
::
tan
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
atanh
(
T
x
)
{
return
type_convert
<
T
>
(
::
atanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
atanh
<
float
>
(
float
x
)
{
return
::
atanhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
atanh
<
double
>
(
double
x
)
{
return
::
atanh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
sinh
(
T
x
)
{
return
type_convert
<
T
>
(
::
sinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
sinh
<
float
>
(
float
x
)
{
return
::
sinhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
sinh
<
double
>
(
double
x
)
{
return
::
sinh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
ceil
(
T
x
)
{
return
type_convert
<
T
>
(
::
ceilf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
ceil
<
float
>
(
float
x
)
{
return
::
ceilf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
ceil
<
double
>
(
double
x
)
{
return
::
ceil
(
x
);
};
template
<
>
CK_TILE_DEVICE
fp16_t
ceil
<
fp16_t
>
(
fp16_t
x
)
{
return
::
hceil
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
cosh
(
T
x
)
{
return
type_convert
<
T
>
(
::
coshf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
cosh
<
float
>
(
float
x
)
{
return
::
coshf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
cosh
<
double
>
(
double
x
)
{
return
::
cosh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
floor
(
T
x
)
{
return
type_convert
<
T
>
(
::
floorf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
floor
<
float
>
(
float
x
)
{
return
::
floorf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
floor
<
double
>
(
double
x
)
{
return
::
floor
(
x
);
};
template
<
>
CK_TILE_DEVICE
fp16_t
floor
<
fp16_t
>
(
fp16_t
x
)
{
return
::
hfloor
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
rcp
(
T
x
)
{
#if !CK_TILE_WORKAROUND_SWDEV_383542
return
__frcp_rn
(
x
);
#else
return
__ocml_native_recip_f32
(
x
);
#endif
};
template
<
typename
T
>
CK_TILE_DEVICE
T
exp
(
T
x
)
{
return
type_convert
<
T
>
(
__ocml_exp_f32
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
fp16_t
exp
<
fp16_t
>
(
fp16_t
x
)
{
return
hexp
(
x
);
};
template
<
>
CK_TILE_DEVICE
float
exp
<
float
>
(
float
x
)
{
return
__ocml_exp_f32
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
exp
<
double
>
(
double
x
)
{
return
exp
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
log
(
T
x
)
{
return
type_convert
<
T
>
(
__logf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
fp16_t
log
<
fp16_t
>
(
fp16_t
x
)
{
return
hlog
(
x
);
};
template
<
>
CK_TILE_DEVICE
float
log
<
float
>
(
float
x
)
{
return
__logf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
log
<
double
>
(
double
x
)
{
return
log
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
pow
(
T
x
,
T
gamma
)
{
return
type_convert
<
T
>
(
powf
(
type_convert
<
float
>
(
x
),
type_convert
<
float
>
(
gamma
)));
};
template
<
>
CK_TILE_DEVICE
float
pow
<
float
>
(
float
x
,
float
gamma
)
{
return
powf
(
x
,
gamma
);
};
template
<
>
CK_TILE_DEVICE
double
pow
<
double
>
(
double
x
,
double
gamma
)
{
return
pow
(
x
,
gamma
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
expm1
(
T
x
)
{
return
type_convert
<
T
>
(
expm1f
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
expm1
<
float
>
(
float
x
)
{
return
expm1f
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
expm1
<
double
>
(
double
x
)
{
return
expm1
(
x
);
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/tensor/tile_distribution.hpp
View file @
7971bb5b
...
@@ -17,6 +17,14 @@
...
@@ -17,6 +17,14 @@
namespace
ck_tile
{
namespace
ck_tile
{
namespace
detail
{
template
<
typename
Distribution
>
CK_TILE_HOST_DEVICE
auto
get_partition_index
(
Distribution
)
{
return
Distribution
::
_get_partition_index
();
}
}
// namespace detail
// distributed span
// distributed span
template
<
index_t
...
PartialHsLengths
>
template
<
index_t
...
PartialHsLengths
>
struct
tile_distributed_span
struct
tile_distributed_span
...
@@ -83,6 +91,21 @@ struct tile_distribution
...
@@ -83,6 +91,21 @@ struct tile_distribution
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_p
()
{
return
NDimP
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_p
()
{
return
NDimP
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_r
()
{
return
NDimR
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_r
()
{
return
NDimR
;
}
CK_TILE_HOST_DEVICE
static
auto
_get_partition_index
()
{
// only support warp-tile and block-tile
static_assert
(
NDimP
==
1
or
NDimP
==
2
,
"wrong!"
);
if
constexpr
(
NDimP
==
1
)
{
return
array
<
index_t
,
1
>
{
get_lane_id
()};
}
else
if
constexpr
(
NDimP
==
2
)
{
return
array
<
index_t
,
2
>
{
get_warp_id
(),
get_lane_id
()};
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_lengths
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_lengths
()
{
{
#if 0
#if 0
...
@@ -149,6 +172,16 @@ struct tile_distribution
...
@@ -149,6 +172,16 @@ struct tile_distribution
}
}
#endif
#endif
template
<
typename
PartitionIndex
=
decltype
(
_get_partition_index
())>
CK_TILE_HOST_DEVICE
auto
calculate_index
(
const
PartitionIndex
&
ps_idx
=
_get_partition_index
())
const
{
const
auto
ps_ys_idx
=
container_concat
(
ps_idx
,
array
<
index_t
,
NDimY
>
{
0
});
const
auto
window_adaptor_thread_coord_tmp
=
make_tensor_adaptor_coordinate
(
ps_ys_to_xs_
,
ps_ys_idx
);
return
window_adaptor_thread_coord_tmp
.
get_bottom_index
();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_distributed_spans
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_distributed_spans
()
{
{
constexpr
auto
distributed_spans_impl
=
DstrEncode
::
detail
::
distributed_spans_lengthss_
;
constexpr
auto
distributed_spans_impl
=
DstrEncode
::
detail
::
distributed_spans_lengthss_
;
...
@@ -500,22 +533,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
...
@@ -500,22 +533,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
namespace
detail
{
namespace
detail
{
template
<
typename
Distribution
>
CK_TILE_HOST_DEVICE
auto
get_partition_index
(
Distribution
)
{
// only support warp-tile and block-tile
static_assert
(
Distribution
::
NDimP
==
1
or
Distribution
::
NDimP
==
2
,
"wrong!"
);
if
constexpr
(
Distribution
::
NDimP
==
1
)
{
return
array
<
index_t
,
1
>
{
get_lane_id
()};
}
else
if
constexpr
(
Distribution
::
NDimP
==
2
)
{
return
array
<
index_t
,
2
>
{
get_warp_id
(),
get_lane_id
()};
}
}
template
<
typename
,
typename
,
typename
,
index_t
>
template
<
typename
,
typename
,
typename
,
index_t
>
struct
reverse_slice_sequence_impl
;
struct
reverse_slice_sequence_impl
;
...
...
include/ck_tile/core/tensor/tile_window.hpp
View file @
7971bb5b
...
@@ -41,6 +41,7 @@ struct tile_window_with_static_distribution
...
@@ -41,6 +41,7 @@ struct tile_window_with_static_distribution
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static_assert
(
NumCoord
==
1
);
// TODO: check WindowLengths and StaticTileDistribution are consistent
// TODO: check WindowLengths and StaticTileDistribution are consistent
...
...
include/ck_tile/ops/elementwise.hpp
0 → 100644
View file @
7971bb5b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
0 → 100644
View file @
7971bb5b
This diff is collapsed.
Click to expand it.
test/CMakeLists.txt
View file @
7971bb5b
...
@@ -217,3 +217,5 @@ if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_
...
@@ -217,3 +217,5 @@ if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_
add_subdirectory
(
smfmac_op
)
add_subdirectory
(
smfmac_op
)
endif
()
endif
()
add_subdirectory
(
position_embedding
)
add_subdirectory
(
position_embedding
)
add_subdirectory
(
scatter_gather
)
test/scatter_gather/CMakeLists.txt
0 → 100644
View file @
7971bb5b
add_test_executable
(
test_scatter_gather scatter_gather.cpp
)
# target_compile_options(test_scatter_gather PRIVATE -v --save-temps -Wno-gnu-line-marker)
test/scatter_gather/scatter_gather.cpp
0 → 100644
View file @
7971bb5b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#ifndef TEST_SCATTER_GATHER_VERBOSE
#define TEST_SCATTER_GATHER_VERBOSE 0
#endif
#define HIP_CALL(call) \
do \
{ \
hipError_t err = call; \
if(err != hipSuccess) \
{ \
printf("[hiperror](%d) fail to call %s", static_cast<int>(err), #call); \
exit(0); \
} \
} while(0)
template
<
ck_tile
::
index_t
ROW_TILE_SIZE
=
8
,
ck_tile
::
index_t
COL_TILE_SIZE
=
32
*
8
,
ck_tile
::
index_t
BLOCK_SIZE
=
256
,
ck_tile
::
index_t
ALIGNMENT
=
8
,
typename
INDEX_BUF_TYPE
=
ck_tile
::
index_t
,
typename
DATA_TYPE
=
ck_tile
::
fp16_t
>
__global__
void
row_scatter_gather
(
const
INDEX_BUF_TYPE
*
src_row_idx_ptr
,
const
INDEX_BUF_TYPE
*
dst_row_idx_ptr
,
const
DATA_TYPE
*
src_ptr
,
DATA_TYPE
*
dst_ptr
,
ck_tile
::
index_t
n_row_total
,
ck_tile
::
index_t
/*n_row_select*/
,
ck_tile
::
index_t
n_cols
)
{
using
namespace
ck_tile
;
// some constexpr vars
constexpr
index_t
vec
=
ALIGNMENT
;
static_assert
(
COL_TILE_SIZE
%
vec
==
0
);
constexpr
index_t
col_lanes
=
COL_TILE_SIZE
/
vec
;
constexpr
index_t
warp_size
=
ck_tile
::
get_warp_size
();
static_assert
(
warp_size
%
col_lanes
==
0
);
constexpr
index_t
row_lanes
=
warp_size
/
col_lanes
;
constexpr
index_t
num_warps
=
BLOCK_SIZE
/
warp_size
;
static_assert
(
ROW_TILE_SIZE
%
(
num_warps
*
row_lanes
)
==
0
);
constexpr
index_t
row_repeat
=
ROW_TILE_SIZE
/
(
num_warps
*
row_lanes
);
static_assert
(
row_repeat
==
1
,
"currently indexing not support(and would be not performant) if row_repeat has more"
);
// tile partitioner
index_t
tile_col_idx
=
0
;
index_t
tile_row_idx
=
blockIdx
.
x
*
ROW_TILE_SIZE
;
// create our tild distribution, which tell us the location of different threads
constexpr
auto
src_dist
=
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
row_repeat
,
num_warps
,
row_lanes
>
,
sequence
<
col_lanes
,
vec
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
const
auto
coord
=
src_dist
.
calculate_index
();
const
auto
row_coord
=
coord
[
number
<
0
>
{}]
+
tile_row_idx
;
// load the current row index from the indexing buffer. we do not use ck_tile utility here
INDEX_BUF_TYPE
src_row_id
=
src_row_idx_ptr
[
row_coord
];
INDEX_BUF_TYPE
dst_row_id
=
dst_row_idx_ptr
[
row_coord
];
// printf("-- tid:%d, src_row_id:%d, dst_row_id:%d\n", static_cast<int>(threadIdx.x),
// static_cast<int>(src_row_id), static_cast<int>(dst_row_id));
const
auto
src_view
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
src_ptr
,
make_tuple
(
n_row_total
,
n_cols
),
make_tuple
(
n_cols
,
1
),
number
<
vec
>
{},
// alignement
number
<
1
>
{});
const
auto
src_gather_view
=
transform_tensor_view
(
src_view
,
make_tuple
(
make_indexing_transform
(
n_row_total
,
src_row_id
),
// here we replace row_idx which is loaded from another buffer
make_pass_through_transform
(
n_cols
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
auto
src_tile
=
make_tile_window
(
src_gather_view
,
make_tuple
(
number
<
ROW_TILE_SIZE
>
{},
number
<
COL_TILE_SIZE
>
{}),
{
tile_row_idx
,
tile_col_idx
},
src_dist
);
const
auto
dst_view
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
dst_ptr
,
make_tuple
(
n_row_total
,
n_cols
),
make_tuple
(
n_cols
,
1
),
number
<
vec
>
{},
number
<
1
>
{});
const
auto
dst_scatter_view
=
transform_tensor_view
(
dst_view
,
make_tuple
(
make_indexing_transform
(
n_row_total
,
dst_row_id
),
// here we replace row_idx which is loaded from another buffer
make_pass_through_transform
(
n_cols
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
auto
dst_tile
=
make_tile_window
(
dst_scatter_view
,
make_tuple
(
number
<
ROW_TILE_SIZE
>
{},
number
<
COL_TILE_SIZE
>
{}),
{
tile_row_idx
,
tile_col_idx
},
src_dist
/*reuse distribution*/
);
// we finished descriptor construction and index calculation, now start load/store
for
(
auto
i
=
0
;
i
<
n_cols
;
i
+=
COL_TILE_SIZE
)
{
// note that scatter/gather are just the same API when doing load store as normal memory
// operation
auto
data
=
load_tile
(
src_tile
);
store_tile
(
dst_tile
,
data
);
move_tile_window
(
src_tile
,
{
0
,
COL_TILE_SIZE
});
move_tile_window
(
dst_tile
,
{
0
,
COL_TILE_SIZE
});
}
}
union
pixel
{
struct
__attribute__
((
packed
))
{
unsigned
int
r
:
6
;
unsigned
int
c
:
10
;
};
ushort
data
;
};
struct
unique_linear_rand
{
unique_linear_rand
(
int
capacity_
)
:
capacity
(
capacity_
)
{}
std
::
unordered_set
<
int
>
set
;
int
gen
()
{
if
(
static_cast
<
int
>
(
set
.
size
())
>=
capacity
)
{
printf
(
"overflow, but will give you an number as well
\n
"
);
return
std
::
rand
()
%
capacity
;
}
while
(
1
)
{
int
r
=
std
::
rand
()
%
capacity
;
if
(
set
.
count
(
r
)
==
1
)
{
continue
;
}
set
.
insert
(
r
);
return
r
;
}
}
int
capacity
;
};
int
main
()
{
int
row_total
=
64
;
int
row_select
=
8
*
2
;
int
col
=
256
*
2
;
using
fp16_t
=
ck_tile
::
fp16_t
;
constexpr
int
row_tile
=
8
;
constexpr
int
col_tile
=
256
;
fp16_t
*
src
=
reinterpret_cast
<
fp16_t
*>
(
malloc
(
row_total
*
col
*
sizeof
(
fp16_t
)));
for
(
int
i_r
=
0
;
i_r
<
row_total
;
i_r
++
)
{
for
(
int
i_c
=
0
;
i_c
<
col
;
i_c
++
)
{
int
i
=
i_r
*
col
+
i_c
;
pixel
p
;
p
.
r
=
i_r
;
p
.
c
=
i_c
;
ushort
d
=
p
.
data
;
src
[
i
]
=
ck_tile
::
bit_cast
<
fp16_t
>
(
d
);
// for simplicity, just cast
}
}
fp16_t
*
dst
=
reinterpret_cast
<
fp16_t
*>
(
malloc
(
row_total
*
col
*
sizeof
(
fp16_t
)));
int
*
src_idx
=
reinterpret_cast
<
int
*>
(
malloc
(
row_select
*
sizeof
(
int
)));
int
*
dst_idx
=
reinterpret_cast
<
int
*>
(
malloc
(
row_select
*
sizeof
(
int
)));
// std::srand(std::time(std::nullptr));
// std::srand(11935);
std
::
srand
(
std
::
time
(
nullptr
));
auto
src_gen
=
unique_linear_rand
(
row_total
);
auto
dst_gen
=
unique_linear_rand
(
row_total
);
// dst index must be unique. src is fine
for
(
int
i_r
=
0
;
i_r
<
row_select
;
i_r
++
)
{
src_idx
[
i_r
]
=
src_gen
.
gen
();
dst_idx
[
i_r
]
=
dst_gen
.
gen
();
}
void
*
dev_src
;
void
*
dev_dst
;
void
*
dev_src_idx
;
void
*
dev_dst_idx
;
HIP_CALL
(
hipMalloc
(
&
dev_src
,
row_total
*
col
*
sizeof
(
fp16_t
)));
HIP_CALL
(
hipMalloc
(
&
dev_dst
,
row_total
*
col
*
sizeof
(
fp16_t
)));
HIP_CALL
(
hipMalloc
(
&
dev_src_idx
,
row_select
*
sizeof
(
int
)));
HIP_CALL
(
hipMalloc
(
&
dev_dst_idx
,
row_select
*
sizeof
(
int
)));
HIP_CALL
(
hipMemcpy
(
dev_src
,
src
,
row_total
*
col
*
sizeof
(
fp16_t
),
hipMemcpyHostToDevice
));
HIP_CALL
(
hipMemcpy
(
dev_src_idx
,
src_idx
,
row_select
*
sizeof
(
int
),
hipMemcpyHostToDevice
));
HIP_CALL
(
hipMemcpy
(
dev_dst_idx
,
dst_idx
,
row_select
*
sizeof
(
int
),
hipMemcpyHostToDevice
));
constexpr
int
bdim
=
256
;
int
gdim
=
(
row_select
+
row_tile
-
1
)
/
row_tile
;
row_scatter_gather
<
row_tile
,
col_tile
><<<
gdim
,
bdim
>>>
(
reinterpret_cast
<
int
*>
(
dev_src_idx
),
reinterpret_cast
<
int
*>
(
dev_dst_idx
),
reinterpret_cast
<
fp16_t
*>
(
dev_src
),
reinterpret_cast
<
fp16_t
*>
(
dev_dst
),
row_total
,
row_select
,
col
);
HIP_CALL
(
hipMemcpy
(
dst
,
dev_dst
,
row_total
*
col
*
sizeof
(
fp16_t
),
hipMemcpyDeviceToHost
));
#if TEST_SCATTER_GATHER_VERBOSE
printf
(
"select row:"
);
for
(
int
i_r
=
0
;
i_r
<
row_select
;
i_r
++
)
{
printf
(
"%d->%d->%d "
,
i_r
,
src_idx
[
i_r
],
dst_idx
[
i_r
]);
}
printf
(
"
\n
"
);
#endif
int
err_cnt
=
0
;
for
(
int
i_r
=
0
;
i_r
<
row_select
;
i_r
++
)
{
for
(
int
i_c
=
0
;
i_c
<
col
;
i_c
++
)
{
int
i
=
dst_idx
[
i_r
]
*
col
+
i_c
;
pixel
p
=
ck_tile
::
bit_cast
<
pixel
>
(
dst
[
i
]);
bool
is_ok
=
p
.
r
==
src_idx
[
i_r
]
&&
p
.
c
==
i_c
;
if
(
!
is_ok
)
{
if
(
i_c
==
0
)
printf
(
"(%d)pixel: %dx%d -> %d
\n
"
,
i_r
,
p
.
r
,
p
.
c
,
dst_idx
[
i_r
]);
err_cnt
++
;
}
}
}
#if TEST_SCATTER_GATHER_VERBOSE
printf
(
"err:%d
\n
"
,
err_cnt
);
#endif
free
(
src
);
free
(
dst
);
free
(
src_idx
);
free
(
dst_idx
);
return
err_cnt
==
0
?
0
:
-
1
;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment