Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
253f942b
Commit
253f942b
authored
Sep 22, 2023
by
Umang Yadav
Browse files
changes to make it compile
parent
8f9c0243
Changes
275
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
730 additions
and
477 deletions
+730
-477
include/ck/utility/container_helper.hpp
include/ck/utility/container_helper.hpp
+5
-0
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+247
-16
include/ck/utility/debug.hpp
include/ck/utility/debug.hpp
+50
-49
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+5
-0
include/ck/utility/enable_if.hpp
include/ck/utility/enable_if.hpp
+12
-3
include/ck/utility/functional.hpp
include/ck/utility/functional.hpp
+5
-0
include/ck/utility/functional2.hpp
include/ck/utility/functional2.hpp
+5
-0
include/ck/utility/functional3.hpp
include/ck/utility/functional3.hpp
+5
-0
include/ck/utility/functional4.hpp
include/ck/utility/functional4.hpp
+5
-0
include/ck/utility/generic_memory_space_atomic.hpp
include/ck/utility/generic_memory_space_atomic.hpp
+5
-0
include/ck/utility/get_id.hpp
include/ck/utility/get_id.hpp
+5
-0
include/ck/utility/ignore.hpp
include/ck/utility/ignore.hpp
+5
-0
include/ck/utility/inner_product.hpp
include/ck/utility/inner_product.hpp
+5
-0
include/ck/utility/integral_constant.hpp
include/ck/utility/integral_constant.hpp
+33
-27
include/ck/utility/is_known_at_compile_time.hpp
include/ck/utility/is_known_at_compile_time.hpp
+5
-0
include/ck/utility/magic_division.hpp
include/ck/utility/magic_division.hpp
+133
-139
include/ck/utility/math.hpp
include/ck/utility/math.hpp
+99
-149
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+91
-94
include/ck/utility/multi_index.hpp
include/ck/utility/multi_index.hpp
+5
-0
include/ck/utility/number.hpp
include/ck/utility/number.hpp
+5
-0
No files found.
include/ck/utility/container_helper.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -391,3 +394,5 @@ __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
...
@@ -391,3 +394,5 @@ __host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
}
// namespace ck
}
// namespace ck
#endif
#endif
#pragma clang diagnostic pop
include/ck/utility/data_type.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -5,6 +8,19 @@
...
@@ -5,6 +8,19 @@
#include "ck/utility/statically_indexed_array.hpp"
#include "ck/utility/statically_indexed_array.hpp"
#ifdef __HIPCC_RTC__
/// Definitions from <cstdint>, <cmath> conflict with
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h, so the handful of
/// names this header relies on are spelled out by hand for hipRTC builds.
using int8_t   = signed char;
using uint8_t  = unsigned char;
using int16_t  = short;
using uint16_t = unsigned short;
using float_t  = float;

namespace std {
// Plain alias standing in for std::byte on this path.
// NOTE(review): the standard std::byte is a scoped enum, not unsigned char —
// confirm no caller depends on the difference.
using byte = unsigned char;
} // namespace std
#endif // __HIPCC_RTC__
namespace
ck
{
namespace
ck
{
using
bhalf_t
=
ushort
;
using
bhalf_t
=
ushort
;
...
@@ -19,21 +35,22 @@ template <typename T, index_t N>
...
@@ -19,21 +35,22 @@ template <typename T, index_t N>
struct
vector_type
;
struct
vector_type
;
// Caution: DO NOT REMOVE
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation
failure when trying to
// intentionally have only declaration but no definition to cause compilation
// instantiate this template. The purpose is to catch
user's mistake when trying to make "vector of
//
failure when trying to
instantiate this template. The purpose is to catch
// vectors"
//
user's mistake when trying to make "vector of
vectors"
template
<
typename
T
,
index_t
V
,
index_t
N
>
template
<
typename
T
,
index_t
V
,
index_t
N
>
struct
vector_type
<
T
__attribute__
((
ext_vector_type
(
V
))),
N
>
;
struct
vector_type
<
T
__attribute__
((
ext_vector_type
(
V
))),
N
>
;
// Caution: DO NOT REMOVE
// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation
failure when trying to
// intentionally have only declaration but no definition to cause compilation
// instantiate this template. The purpose is to catch
user's mistake when trying to make "vector of
//
failure when trying to
instantiate this template. The purpose is to catch
// vectors"
//
user's mistake when trying to make "vector of
vectors"
template
<
typename
T
,
index_t
V
,
index_t
N
>
template
<
typename
T
,
index_t
V
,
index_t
N
>
struct
vector_type
<
vector_type
<
T
,
V
>
,
N
>
;
struct
vector_type
<
vector_type
<
T
,
V
>
,
N
>
;
// vector_type_maker
// vector_type_maker
// This is the right way to handle "vector of vectors": making a bigger vector instead
// This is the right way to handle "vector of vectors": making a bigger vector
// instead
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
struct
vector_type_maker
struct
vector_type_maker
{
{
...
@@ -960,21 +977,233 @@ using f8x16_t = typename vector_type<f8_t, 16>::type;
...
@@ -960,21 +977,233 @@ using f8x16_t = typename vector_type<f8_t, 16>::type;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x32_t
=
typename
vector_type
<
f8_t
,
32
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
using
f8x64_t
=
typename
vector_type
<
f8_t
,
64
>::
type
;
template
<
typename
T
>
// Convert X to Y
struct
NumericLimits
template
<
typename
Y
,
typename
X
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
{
__host__
__device__
static
constexpr
T
Min
()
{
return
std
::
numeric_limits
<
T
>::
min
();
}
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
__host__
__device__
static
constexpr
T
Max
()
{
return
std
::
numeric_limits
<
T
>::
max
();
}
return
static_cast
<
Y
>
(
x
);
}
// convert bfp16 to fp32
template
<
>
inline
__host__
__device__
constexpr
float
type_convert
<
float
,
bhalf_t
>
(
bhalf_t
x
)
{
union
{
uint32_t
int32
;
float
fp32
;
}
u
=
{
uint32_t
(
x
)
<<
16
};
__host__
__device__
static
constexpr
T
Lowest
()
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
return
u
.
fp32
;
}
__host__
__device__
static
constexpr
T
QuietNaN
()
// convert fp32 to bfp16
template
<
>
inline
__host__
__device__
constexpr
bhalf_t
type_convert
<
bhalf_t
,
float
>
(
float
x
)
{
union
{
{
return
std
::
numeric_limits
<
T
>::
quiet_NaN
();
float
fp32
;
}
uint32_t
int32
;
}
u
=
{
x
};
return
uint16_t
(
u
.
int32
>>
16
);
}
// bf16 -> fp16: go through fp32 (via the bf16 -> fp32 type_convert), then
// narrow the float to half precision.
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
{
    return static_cast<half_t>(type_convert<float>(x));
}
// fp16 -> bf16: widen the half to fp32, then hand the float to the
// fp32 -> bf16 type_convert.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
    return type_convert<bhalf_t>(static_cast<float>(x));
}
// bf16 -> int32: expand to fp32 first, then apply the usual float-to-integer
// static_cast.
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
    return static_cast<int32_t>(type_convert<float>(x));
}
// int32 -> bf16: convert the integer to fp32, then hand the float to the
// fp32 -> bf16 type_convert.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
    return type_convert<bhalf_t>(static_cast<float>(x));
}
// bf16 -> int8: expand to fp32 first, then apply the usual float-to-integer
// static_cast.
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
    return static_cast<int8_t>(type_convert<float>(x));
}
// int8 -> bf16: convert the integer to fp32, then hand the float to the
// fp32 -> bf16 type_convert.
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
{
    return type_convert<bhalf_t>(static_cast<float>(x));
}
__host__
__device__
static
constexpr
T
Infinity
()
{
return
std
::
numeric_limits
<
T
>::
infinity
();
}
// bf16 conversion using round-to-nearest (RTN). Only a declaration here; the
// explicit specializations supply the definitions.
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);

// fp32 -> bf16 with round-to-nearest-even, for callers that need more
// precision than a plain truncation of the low 16 bits.
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } bits = {x};

    // True when the exponent field is not all ones, i.e. the input is zero,
    // subnormal or normal (anything but Inf/NaN).
    const bool exponent_not_all_ones = ~bits.int32 & 0x7f800000;

    // True when the exponent is all ones (Inf or NaN) and some of the low 16
    // mantissa bits are set. Inf has a zero mantissa, so this is a NaN whose
    // payload would otherwise be lost when the low half is dropped.
    const bool nan_with_low_payload = !exponent_not_all_ones && (bits.int32 & 0xffff);

    if(exponent_not_all_ones)
    {
        // Round to nearest, ties to even: add 0x7FFF plus the current least
        // significant bit of the bf16 mantissa. The kept mantissa is bumped
        // when the discarded 16 bits exceed 0x8000, or equal 0x8000 while the
        // kept mantissa is odd. A carry out of the mantissa increments the
        // exponent, so a subnormal may become normal and the largest finite
        // value may round up to Inf.
        bits.int32 += 0x7fff + ((bits.int32 >> 16) & 1);
    }

    if(nan_with_low_payload)
    {
        // Set the lowest bf16 mantissa bit so a signaling NaN stays a NaN even
        // if all of its surviving mantissa bits were zero.
        bits.int32 |= 0x10000;
    }

    return uint16_t(bits.int32 >> 16);
}
// fp16 -> bf16 with round-to-nearest: widen the half to fp32 and reuse the
// fp32 -> bf16 RTN conversion.
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
    return bf16_convert_rtn<bhalf_t>(static_cast<float>(x));
}
// Per-type numeric limits callable from host and device code. The primary
// template is declared only; each supported type gets an explicit
// specialization.
template <typename T>
struct NumericLimits;

// Limits of a 32-bit signed integer. Infinity() and QuietNaN() return 0.
template <>
struct NumericLimits<int32_t>
{
    // Written as (-2147483647 - 1) because 2147483648 is not a valid int literal.
    __host__ __device__ static constexpr int32_t Lowest() noexcept { return -2147483647 - 1; }

    // For integers the minimum and the lowest value coincide.
    __host__ __device__ static constexpr int32_t Min() noexcept { return Lowest(); }

    __host__ __device__ static constexpr int32_t Max() noexcept { return 2147483647; }

    __host__ __device__ static constexpr int32_t Infinity() noexcept { return 0; }

    __host__ __device__ static constexpr int32_t QuietNaN() { return 0; }
};
// Limits of a 16-bit signed integer. Infinity() and QuietNaN() return 0.
template <>
struct NumericLimits<int16_t>
{
    __host__ __device__ static constexpr int16_t Lowest() noexcept { return -32768; }

    // For integers the minimum and the lowest value coincide.
    __host__ __device__ static constexpr int16_t Min() noexcept { return Lowest(); }

    __host__ __device__ static constexpr int16_t Max() noexcept { return 32767; }

    __host__ __device__ static constexpr int16_t Infinity() noexcept { return 0; }

    __host__ __device__ static constexpr int16_t QuietNaN() { return 0; }
};
// Limits of an 8-bit signed integer. Infinity() and QuietNaN() return 0.
template <>
struct NumericLimits<int8_t>
{
    __host__ __device__ static constexpr int8_t Lowest() noexcept { return -128; }

    // For integers the minimum and the lowest value coincide.
    __host__ __device__ static constexpr int8_t Min() noexcept { return Lowest(); }

    __host__ __device__ static constexpr int8_t Max() noexcept { return 127; }

    __host__ __device__ static constexpr int8_t Infinity() noexcept { return 0; }

    __host__ __device__ static constexpr int8_t QuietNaN() { return 0; }
};
// Limits of a 32-bit unsigned integer. Infinity() and QuietNaN() return 0.
template <>
struct NumericLimits<uint32_t>
{
    __host__ __device__ static constexpr uint32_t Lowest() noexcept { return 0; }

    // For integers the minimum and the lowest value coincide.
    __host__ __device__ static constexpr uint32_t Min() noexcept { return Lowest(); }

    __host__ __device__ static constexpr uint32_t Max() noexcept { return 4294967295U; }

    __host__ __device__ static constexpr uint32_t Infinity() noexcept { return 0; }

    __host__ __device__ static constexpr uint32_t QuietNaN() { return 0; }
};
// Limits of a 16-bit unsigned integer. Infinity() and QuietNaN() return 0.
template <>
struct NumericLimits<uint16_t>
{
    __host__ __device__ static constexpr uint16_t Lowest() noexcept { return 0; }

    // For integers the minimum and the lowest value coincide.
    __host__ __device__ static constexpr uint16_t Min() noexcept { return Lowest(); }

    __host__ __device__ static constexpr uint16_t Max() noexcept { return 65535U; }

    __host__ __device__ static constexpr uint16_t Infinity() noexcept { return 0; }

    __host__ __device__ static constexpr uint16_t QuietNaN() { return 0; }
};
// Limits of an 8-bit unsigned integer. Infinity() and QuietNaN() return 0.
template <>
struct NumericLimits<uint8_t>
{
    __host__ __device__ static constexpr uint8_t Lowest() noexcept { return 0; }

    // For integers the minimum and the lowest value coincide.
    __host__ __device__ static constexpr uint8_t Min() noexcept { return Lowest(); }

    __host__ __device__ static constexpr uint8_t Max() noexcept { return 255U; }

    __host__ __device__ static constexpr uint8_t Infinity() noexcept { return 0; }

    __host__ __device__ static constexpr uint8_t QuietNaN() { return 0; }
};
// Limits of a 32-bit float, expressed as IEEE-754 binary32 bit patterns and
// turned into float values with bit_cast.
template <>
struct NumericLimits<float>
{
    // Smallest positive normal value (exponent 1, zero mantissa).
    static constexpr unsigned int binary_min = 0x00800000;
    // Largest finite value (exponent 0xFE, mantissa all ones).
    static constexpr unsigned int binary_max = 0x7F7FFFFF;
    // Most negative finite value: binary_max with the sign bit set.
    static constexpr unsigned int binary_lowest = 0xFF7FFFFF;
    // Quiet NaN: exponent all ones with the top mantissa bit set.
    static constexpr unsigned int binary_qnan = 0xFFC00001;
    // Positive infinity: sign 0, exponent all ones, zero mantissa. The pattern
    // must be exactly eight hex digits; a nine-digit literal is truncated to
    // 32 bits on conversion to unsigned int and no longer encodes infinity.
    static constexpr unsigned int binary_inf = 0x7F800000;

    // Smallest positive normal float.
    __host__ __device__ static constexpr float Min() { return bit_cast<float>(binary_min); }

    // Largest finite float.
    __host__ __device__ static constexpr float Max() { return bit_cast<float>(binary_max); }

    // Most negative finite float.
    __host__ __device__ static constexpr float Lowest() { return bit_cast<float>(binary_lowest); }

    // A quiet NaN.
    __host__ __device__ static constexpr float QuietNaN() { return bit_cast<float>(binary_qnan); }

    // Positive infinity.
    __host__ __device__ static constexpr float Infinity() { return bit_cast<float>(binary_inf); }
};
};
template
<
>
template
<
>
...
@@ -1024,3 +1253,5 @@ struct NumericLimits<f8_t>
...
@@ -1024,3 +1253,5 @@ struct NumericLimits<f8_t>
};
};
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/debug.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef UTILITY_DEBUG_HPP
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#include "type.hpp"
namespace
ck
{
namespace
ck
{
namespace
debug
{
namespace
debug
{
namespace
detail
{
namespace
detail
{
template
<
typename
T
,
typename
Enable
=
void
>
template
<
typename
T
,
typename
Enable
=
void
>
struct
PrintAsType
;
struct
PrintAsType
;
template
<
typename
T
>
template
<
typename
T
>
struct
PrintAsType
<
T
,
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
>::
type
>
struct
PrintAsType
<
{
T
,
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
>::
type
>
{
using
type
=
float
;
using
type
=
float
;
__host__
__device__
static
void
Print
(
const
T
&
p
)
{
printf
(
"%.3f "
,
static_cast
<
type
>
(
p
));
}
__host__
__device__
static
void
Print
(
const
T
&
p
)
{
printf
(
"%.3f "
,
static_cast
<
type
>
(
p
));
}
};
};
template
<
>
template
<
>
struct
PrintAsType
<
ck
::
half_t
,
void
>
{
struct
PrintAsType
<
ck
::
half_t
,
void
>
using
type
=
float
;
{
__host__
__device__
static
void
Print
(
const
ck
::
half_t
&
p
)
{
using
type
=
float
;
printf
(
"%.3f "
,
static_cast
<
type
>
(
p
));
__host__
__device__
static
void
Print
(
const
ck
::
half_t
&
p
)
}
{
printf
(
"%.3f "
,
static_cast
<
type
>
(
p
));
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
PrintAsType
<
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>
struct
PrintAsType
<
T
,
{
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>
{
using
type
=
int
;
using
type
=
int
;
__host__
__device__
static
void
Print
(
const
T
&
p
)
{
printf
(
"%d "
,
static_cast
<
type
>
(
p
));
}
__host__
__device__
static
void
Print
(
const
T
&
p
)
{
printf
(
"%d "
,
static_cast
<
type
>
(
p
));
}
};
};
}
// namespace detail
}
// namespace detail
// Print at runtime the data in shared memory in 128 bytes per row format given
shared mem pointer
// Print at runtime the data in shared memory in 128 bytes per row format given
// and the number of elements. Can optionally specify strides
between elements and how many bytes'
//
shared mem pointer
and the number of elements. Can optionally specify strides
// worth of data per row.
//
between elements and how many bytes'
worth of data per row.
//
//
// Usage example:
// Usage example:
//
//
// debug::print_shared(a_block_buf.p_data_, index_t(a_block_desc_k0_m_k1.GetElementSpaceSize()));
// debug::print_shared(a_block_buf.p_data_,
// index_t(a_block_desc_k0_m_k1.GetElementSpaceSize()));
//
//
template
<
typename
T
,
index_t
element_stride
=
1
,
index_t
row_bytes
=
128
>
template
<
typename
T
,
index_t
element_stride
=
1
,
index_t
row_bytes
=
128
>
__device__
void
print_shared
(
T
const
*
p_shared
,
index_t
num_elements
)
__device__
void
print_shared
(
T
const
*
p_shared
,
index_t
num_elements
)
{
{
constexpr
index_t
row_elements
=
row_bytes
/
sizeof
(
T
);
constexpr
index_t
row_elements
=
row_bytes
/
sizeof
(
T
);
static_assert
((
element_stride
>=
1
&&
element_stride
<=
row_elements
),
static_assert
((
element_stride
>=
1
&&
element_stride
<=
row_elements
),
"element_stride should between [1, row_elements]"
);
"element_stride should between [1, row_elements]"
);
index_t
wgid
=
blockIdx
.
x
+
blockIdx
.
y
*
gridDim
.
x
+
gridDim
.
x
*
gridDim
.
y
*
blockIdx
.
z
;
index_t
wgid
=
index_t
tid
=
blockIdx
.
x
+
blockIdx
.
y
*
gridDim
.
x
+
gridDim
.
x
*
gridDim
.
y
*
blockIdx
.
z
;
(
threadIdx
.
z
*
(
blockDim
.
x
*
blockDim
.
y
))
+
(
threadIdx
.
y
*
blockDim
.
x
)
+
threadIdx
.
x
;
index_t
tid
=
(
threadIdx
.
z
*
(
blockDim
.
x
*
blockDim
.
y
))
+
(
threadIdx
.
y
*
blockDim
.
x
)
+
threadIdx
.
x
;
__syncthreads
();
__syncthreads
();
if
(
tid
==
0
)
if
(
tid
==
0
)
{
{
printf
(
"
\n
Workgroup id %d, bytes per row %d, element stride %d
\n\n
"
,
wgid
,
printf
(
"
\n
Workgroup id %d, bytes per row %d, element stride %d
\n\n
"
,
row_bytes
,
element_stride
);
wgid
,
for
(
index_t
i
=
0
;
i
<
num_elements
;
i
+=
row_elements
)
{
row_bytes
,
printf
(
"elem %5d: "
,
i
);
element_stride
);
for
(
index_t
j
=
0
;
j
<
row_elements
;
j
+=
element_stride
)
{
for
(
index_t
i
=
0
;
i
<
num_elements
;
i
+=
row_elements
)
detail
::
PrintAsType
<
T
>::
Print
(
p_shared
[
i
+
j
]);
{
}
printf
(
"elem %5d: "
,
i
);
for
(
index_t
j
=
0
;
j
<
row_elements
;
j
+=
element_stride
)
{
detail
::
PrintAsType
<
T
>::
Print
(
p_shared
[
i
+
j
]);
}
printf
(
"
\n
"
);
printf
(
"
\n
"
);
}
printf
(
"
\n
"
);
}
}
printf
(
"
\n
"
);
}
__syncthreads
();
__syncthreads
();
}
}
}
// namespace debug
}
// namespace debug
}
// namespace ck
}
// namespace ck
#endif // UTILITY_DEBUG_HPP
#endif // UTILITY_DEBUG_HPP
#pragma clang diagnostic pop
include/ck/utility/dynamic_buffer.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -405,3 +408,5 @@ make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element
...
@@ -405,3 +408,5 @@ make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element
}
}
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/enable_if.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#ifdef __HIPCC_RTC__
namespace std {
// Only compiled when __HIPCC_RTC__ is defined: provides the enable_if_t
// shorthand on top of std::enable_if.
template <bool B, class T = void>
using enable_if_t = typename enable_if<B, T>::type;
} // namespace std
#endif
namespace
ck
{
namespace
ck
{
template
<
bool
B
,
typename
T
=
void
>
template
<
bool
B
,
typename
T
=
void
>
using
enable_if
=
std
::
enable_if
<
B
,
T
>
;
using
enable_if
=
std
::
enable_if
<
B
,
T
>
;
template
<
bool
B
,
typename
T
=
void
>
template
<
bool
B
,
typename
T
=
void
>
using
enable_if_t
=
typename
std
::
enable_if
<
B
,
T
>::
type
;
using
enable_if_t
=
typename
std
::
enable_if
<
B
,
T
>::
type
;
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/functional.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -129,3 +132,5 @@ constexpr auto conditional_expr(X&& x, Y&& y)
...
@@ -129,3 +132,5 @@ constexpr auto conditional_expr(X&& x, Y&& y)
}
}
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/functional2.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -47,3 +50,5 @@ struct static_for
...
@@ -47,3 +50,5 @@ struct static_for
};
};
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/functional3.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -142,3 +145,5 @@ struct ford
...
@@ -142,3 +145,5 @@ struct ford
};
};
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/functional4.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -63,3 +66,5 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
...
@@ -63,3 +66,5 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
}
// namespace ck
}
// namespace ck
#endif
#endif
#pragma clang diagnostic pop
include/ck/utility/generic_memory_space_atomic.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -121,3 +124,5 @@ __device__ float2_t atomic_max<float2_t>(float2_t* p_dst, const float2_t& x)
...
@@ -121,3 +124,5 @@ __device__ float2_t atomic_max<float2_t>(float2_t* p_dst, const float2_t& x)
}
}
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/get_id.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -26,3 +29,5 @@ __device__ index_t get_grid_size() { return gridDim.x; }
...
@@ -26,3 +29,5 @@ __device__ index_t get_grid_size() { return gridDim.x; }
__device__
index_t
get_block_size
()
{
return
blockDim
.
x
;
}
__device__
index_t
get_block_size
()
{
return
blockDim
.
x
;
}
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/ignore.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -20,3 +23,5 @@ struct ignore_t
...
@@ -20,3 +23,5 @@ struct ignore_t
inline
constexpr
detail
::
ignore_t
ignore
;
inline
constexpr
detail
::
ignore_t
ignore
;
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/inner_product.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -234,3 +237,5 @@ inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t
...
@@ -234,3 +237,5 @@ inner_product<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a, const int8x16_t
}
}
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/integral_constant.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -5,47 +8,50 @@
...
@@ -5,47 +8,50 @@
namespace
ck
{
namespace
ck
{
template
<
class
T
,
T
v
>
template
<
class
T
,
T
v
>
struct
integral_constant
{
struct
integral_constant
static
constexpr
T
value
=
v
;
{
typedef
T
value_type
;
static
constexpr
T
value
=
v
;
typedef
integral_constant
type
;
typedef
T
value_type
;
__host__
__device__
constexpr
operator
value_type
()
const
noexcept
{
typedef
integral_constant
type
;
return
value
;
__host__
__device__
constexpr
operator
value_type
()
const
noexcept
{
return
value
;
}
}
__host__
__device__
constexpr
value_type
operator
()()
const
noexcept
{
return
value
;
}
__host__
__device__
constexpr
value_type
operator
()()
const
noexcept
{
return
value
;
}
};
};
template
<
typename
TX
,
TX
X
,
typename
TY
,
TY
Y
>
template
<
typename
TX
,
TX
X
,
typename
TY
,
TY
Y
>
__host__
__device__
constexpr
auto
operator
+
(
integral_constant
<
TX
,
X
>
,
integral_constant
<
TY
,
Y
>
)
__host__
__device__
constexpr
auto
operator
+
(
integral_constant
<
TX
,
X
>
,
{
integral_constant
<
TY
,
Y
>
)
{
return
integral_constant
<
decltype
(
X
+
Y
),
X
+
Y
>
{};
return
integral_constant
<
decltype
(
X
+
Y
),
X
+
Y
>
{};
}
}
template
<
typename
TX
,
TX
X
,
typename
TY
,
TY
Y
>
template
<
typename
TX
,
TX
X
,
typename
TY
,
TY
Y
>
__host__
__device__
constexpr
auto
operator
-
(
integral_constant
<
TX
,
X
>
,
integral_constant
<
TY
,
Y
>
)
__host__
__device__
constexpr
auto
operator
-
(
integral_constant
<
TX
,
X
>
,
{
integral_constant
<
TY
,
Y
>
)
{
static_assert
(
Y
<=
X
,
"wrong!"
);
static_assert
(
Y
<=
X
,
"wrong!"
);
return
integral_constant
<
decltype
(
X
-
Y
),
X
-
Y
>
{};
return
integral_constant
<
decltype
(
X
-
Y
),
X
-
Y
>
{};
}
}
template
<
typename
TX
,
TX
X
,
typename
TY
,
TY
Y
>
template
<
typename
TX
,
TX
X
,
typename
TY
,
TY
Y
>
__host__
__device__
constexpr
auto
operator
*
(
integral_constant
<
TX
,
X
>
,
integral_constant
<
TY
,
Y
>
)
__host__
__device__
constexpr
auto
operator
*
(
integral_constant
<
TX
,
X
>
,
{
integral_constant
<
TY
,
Y
>
)
{
return
integral_constant
<
decltype
(
X
*
Y
),
X
*
Y
>
{};
return
integral_constant
<
decltype
(
X
*
Y
),
X
*
Y
>
{};
}
}
template
<
typename
TX
,
TX
X
,
typename
TY
,
TY
Y
>
template
<
typename
TX
,
TX
X
,
typename
TY
,
TY
Y
>
__host__
__device__
constexpr
auto
operator
/
(
integral_constant
<
TX
,
X
>
,
integral_constant
<
TY
,
Y
>
)
__host__
__device__
constexpr
auto
operator
/
(
integral_constant
<
TX
,
X
>
,
{
integral_constant
<
TY
,
Y
>
)
{
static_assert
(
Y
>
0
,
"wrong!"
);
static_assert
(
Y
>
0
,
"wrong!"
);
return
integral_constant
<
decltype
(
X
/
Y
),
X
/
Y
>
{};
return
integral_constant
<
decltype
(
X
/
Y
),
X
/
Y
>
{};
}
}
template
<
typename
TX
,
TX
X
,
typename
TY
,
TY
Y
>
template
<
typename
TX
,
TX
X
,
typename
TY
,
TY
Y
>
__host__
__device__
constexpr
auto
operator
%
(
integral_constant
<
TX
,
X
>
,
integral_constant
<
TY
,
Y
>
)
__host__
__device__
constexpr
auto
operator
%
(
integral_constant
<
TX
,
X
>
,
{
integral_constant
<
TY
,
Y
>
)
{
static_assert
(
Y
>
0
,
"wrong!"
);
static_assert
(
Y
>
0
,
"wrong!"
);
return
integral_constant
<
decltype
(
X
%
Y
),
X
%
Y
>
{};
return
integral_constant
<
decltype
(
X
%
Y
),
X
%
Y
>
{};
}
}
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/is_known_at_compile_time.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -54,3 +57,5 @@ struct is_known_at_compile_time<Tuple<Ts...>>
...
@@ -54,3 +57,5 @@ struct is_known_at_compile_time<Tuple<Ts...>>
};
};
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/magic_division.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -6,155 +9,144 @@
...
@@ -6,155 +9,144 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "number.hpp"
#include "type.hpp"
#include "tuple.hpp"
#include "tuple.hpp"
#include "type.hpp"
// Fallback definition of INT32_MAX for builds where <cstdint> is not included
// (presumably the hipRTC path — confirm). Guarded so that a definition already
// supplied by <cstdint> is not redefined.
#ifndef INT32_MAX
#define INT32_MAX 2147483647
#endif
namespace
ck
{
namespace
ck
{
// magic number division
// magic number division
// Caution:
// Caution:
// 1. For uint32_t as dividend: magic number division implementation being used would produce
// 1. For uint32_t as dividend: magic number division implementation being
// correct result if the dividend is uint32_t and its value is within 31-bit value range.
// used would produce correct result if the dividend is uint32_t and its value
// 2. For int32_t as dividend: magic number division for int32_t dividend has not been
// is within 31-bit value range.
// implemented, the int32_t dividend would be bit-wise interpreted as uint32_t and magic number
// 2. For int32_t as dividend: magic number division for int32_t dividend
// division implementation for uint32_t is then used. Therefore, dividend value need to be
// has not been implemented, the int32_t dividend would be bit-wise
// non-negative.
// interpreted as uint32_t and magic number division implementation for
// uint32_t is then used. Therefore, dividend value need to be non-negative.
// TODO:
// TODO:
// 1. Implement magic number divison for int32_t
// 1. Implement magic number divison for int32_t
// 2. Implement magic number divison for unit32_t with 32-bit value range
// 2. Implement magic number divison for unit32_t with 32-bit value range
struct
MagicDivision
struct
MagicDivision
{
{
// uint32_t
// uint32_t
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
CalculateMagicNumbers
(
uint32_t
divisor
)
CalculateMagicNumbers
(
uint32_t
divisor
)
{
{
// WARNING: magic division is only applicable for division inside this
// WARNING: magic division is only applicable for division inside this range.
// range. You should use the return value of CalculateMagicNumbers, if
// You should use the return value of CalculateMagicNumbers, if division is not inside this
// division is not inside this range. The "else" logic below is to quiet
// range. The "else" logic below is to quiet down run-time error.
// down run-time error.
if
(
divisor
>=
1
&&
divisor
<=
INT32_MAX
)
if
(
divisor
>=
1
&&
divisor
<=
INT32_MAX
)
{
{
uint32_t
shift
=
0
;
uint32_t
shift
=
0
;
for
(
shift
=
0
;
shift
<
32
;
++
shift
)
{
for
(
shift
=
0
;
shift
<
32
;
++
shift
)
if
((
1U
<<
shift
)
>=
divisor
)
{
{
break
;
if
((
1U
<<
shift
)
>=
divisor
)
{
break
;
}
}
uint64_t
one
=
1
;
uint64_t
multiplier
=
((
one
<<
32
)
*
((
one
<<
shift
)
-
divisor
))
/
divisor
+
1
;
// assert(multiplier <= 0xffffffffUL);
return
make_tuple
(
uint32_t
(
multiplier
),
shift
);
}
else
{
return
make_tuple
(
uint32_t
(
0
),
uint32_t
(
0
));
}
}
}
}
__host__
__device__
static
constexpr
uint32_t
CalculateMagicMultiplier
(
uint32_t
divisor
)
uint64_t
one
=
1
;
{
uint64_t
multiplier
=
auto
tmp
=
CalculateMagicNumbers
(
divisor
);
((
one
<<
32
)
*
((
one
<<
shift
)
-
divisor
))
/
divisor
+
1
;
// assert(multiplier <= 0xffffffffUL);
return
tmp
[
Number
<
0
>
{}];
}
return
make_tuple
(
uint32_t
(
multiplier
),
shift
);
}
else
{
__host__
__device__
static
constexpr
uint32_t
CalculateMagicShift
(
uint32_t
divisor
)
return
make_tuple
(
uint32_t
(
0
),
uint32_t
(
0
));
{
}
auto
tmp
=
CalculateMagicNumbers
(
divisor
);
}
return
tmp
[
Number
<
1
>
{}];
__host__
__device__
static
constexpr
uint32_t
}
CalculateMagicMultiplier
(
uint32_t
divisor
)
{
auto
tmp
=
CalculateMagicNumbers
(
divisor
);
// integral_constant<uint32_t, .>
template
<
uint32_t
Divisor
>
return
tmp
[
Number
<
0
>
{}];
__host__
__device__
static
constexpr
auto
}
CalculateMagicNumbers
(
integral_constant
<
uint32_t
,
Divisor
>
)
{
__host__
__device__
static
constexpr
uint32_t
constexpr
auto
tmp
=
CalculateMagicNumbers
(
uint32_t
{
Divisor
});
CalculateMagicShift
(
uint32_t
divisor
)
{
auto
tmp
=
CalculateMagicNumbers
(
divisor
);
constexpr
uint32_t
multiplier
=
tmp
[
Number
<
0
>
{}];
constexpr
uint32_t
shift
=
tmp
[
Number
<
1
>
{}];
return
tmp
[
Number
<
1
>
{}];
}
return
make_tuple
(
integral_constant
<
uint32_t
,
multiplier
>
{},
integral_constant
<
uint32_t
,
shift
>
{});
// integral_constant<uint32_t, .>
}
template
<
uint32_t
Divisor
>
__host__
__device__
static
constexpr
auto
template
<
uint32_t
Divisor
>
CalculateMagicNumbers
(
integral_constant
<
uint32_t
,
Divisor
>
)
{
__host__
__device__
static
constexpr
auto
constexpr
auto
tmp
=
CalculateMagicNumbers
(
uint32_t
{
Divisor
});
CalculateMagicMultiplier
(
integral_constant
<
uint32_t
,
Divisor
>
)
{
constexpr
uint32_t
multiplier
=
tmp
[
Number
<
0
>
{}];
constexpr
uint32_t
multiplier
=
CalculateMagicMultiplier
(
uint32_t
{
Divisor
});
constexpr
uint32_t
shift
=
tmp
[
Number
<
1
>
{}];
return
integral_constant
<
uint32_t
,
multiplier
>
{};
return
make_tuple
(
integral_constant
<
uint32_t
,
multiplier
>
{},
}
integral_constant
<
uint32_t
,
shift
>
{});
}
template
<
uint32_t
Divisor
>
__host__
__device__
static
constexpr
auto
template
<
uint32_t
Divisor
>
CalculateMagicShift
(
integral_constant
<
uint32_t
,
Divisor
>
)
__host__
__device__
static
constexpr
auto
{
CalculateMagicMultiplier
(
integral_constant
<
uint32_t
,
Divisor
>
)
{
constexpr
uint32_t
shift
=
CalculateMagicShift
(
uint32_t
{
Divisor
});
constexpr
uint32_t
multiplier
=
CalculateMagicMultiplier
(
uint32_t
{
Divisor
});
return
integral_constant
<
uint32_t
,
shift
>
{};
return
integral_constant
<
uint32_t
,
multiplier
>
{};
}
}
// integral_constant<int32_t, .>
template
<
uint32_t
Divisor
>
template
<
int32_t
Divisor
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
CalculateMagicShift
(
integral_constant
<
uint32_t
,
Divisor
>
)
{
CalculateMagicNumbers
(
integral_constant
<
int32_t
,
Divisor
>
)
constexpr
uint32_t
shift
=
CalculateMagicShift
(
uint32_t
{
Divisor
});
{
return
CalculateMagicNumbers
(
integral_constant
<
uint32_t
,
Divisor
>
{});
return
integral_constant
<
uint32_t
,
shift
>
{};
}
}
template
<
int32_t
Divisor
>
// integral_constant<int32_t, .>
__host__
__device__
static
constexpr
auto
template
<
int32_t
Divisor
>
CalculateMagicMultiplier
(
integral_constant
<
int32_t
,
Divisor
>
)
__host__
__device__
static
constexpr
auto
{
CalculateMagicNumbers
(
integral_constant
<
int32_t
,
Divisor
>
)
{
return
CalculateMagicMultiplier
(
integral_constant
<
uint32_t
,
Divisor
>
{});
return
CalculateMagicNumbers
(
integral_constant
<
uint32_t
,
Divisor
>
{});
}
}
template
<
int32_t
Divisor
>
template
<
int32_t
Divisor
>
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
CalculateMagicShift
(
integral_constant
<
int32_t
,
Divisor
>
)
CalculateMagicMultiplier
(
integral_constant
<
int32_t
,
Divisor
>
)
{
{
return
CalculateMagicMultiplier
(
integral_constant
<
uint32_t
,
Divisor
>
{});
return
CalculateMagicShift
(
integral_constant
<
uint32_t
,
Divisor
>
{});
}
}
template
<
int32_t
Divisor
>
// magic division for uint32_t
__host__
__device__
static
constexpr
auto
__device__
static
constexpr
uint32_t
CalculateMagicShift
(
integral_constant
<
int32_t
,
Divisor
>
)
{
DoMagicDivision
(
uint32_t
dividend
,
uint32_t
multiplier
,
uint32_t
shift
)
return
CalculateMagicShift
(
integral_constant
<
uint32_t
,
Divisor
>
{});
{
}
uint32_t
tmp
=
__umulhi
(
dividend
,
multiplier
);
return
(
tmp
+
dividend
)
>>
shift
;
// magic division for uint32_t
}
__device__
static
constexpr
uint32_t
DoMagicDivision
(
uint32_t
dividend
,
uint32_t
multiplier
,
uint32_t
shift
)
{
__host__
static
constexpr
uint32_t
uint32_t
tmp
=
__umulhi
(
dividend
,
multiplier
);
DoMagicDivision
(
uint32_t
dividend
,
uint32_t
multiplier
,
uint32_t
shift
)
return
(
tmp
+
dividend
)
>>
shift
;
{
}
uint32_t
tmp
=
static_cast
<
uint64_t
>
(
dividend
)
*
multiplier
>>
32
;
return
(
tmp
+
dividend
)
>>
shift
;
__host__
static
constexpr
uint32_t
}
DoMagicDivision
(
uint32_t
dividend
,
uint32_t
multiplier
,
uint32_t
shift
)
{
uint32_t
tmp
=
static_cast
<
uint64_t
>
(
dividend
)
*
multiplier
>>
32
;
// magic division for int32_t
return
(
tmp
+
dividend
)
>>
shift
;
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
}
// non-negative for result to be correct
// TODO: figure out how to do magic number divison for int32_t as dividended
// magic division for int32_t
__device__
static
constexpr
int32_t
// HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be
DoMagicDivision
(
int32_t
dividend_i32
,
uint32_t
multiplier
,
uint32_t
shift
)
// non-negative for result to be correct
{
// TODO: figure out how to do magic number divison for int32_t as dividended
uint32_t
dividend_u32
=
bit_cast
<
uint32_t
>
(
dividend_i32
);
__device__
static
constexpr
int32_t
uint32_t
tmp
=
__umulhi
(
dividend_u32
,
multiplier
);
DoMagicDivision
(
int32_t
dividend_i32
,
uint32_t
multiplier
,
uint32_t
shift
)
{
return
(
tmp
+
dividend_u32
)
>>
shift
;
uint32_t
dividend_u32
=
bit_cast
<
uint32_t
>
(
dividend_i32
);
}
uint32_t
tmp
=
__umulhi
(
dividend_u32
,
multiplier
);
return
(
tmp
+
dividend_u32
)
>>
shift
;
__host__
static
constexpr
int32_t
}
DoMagicDivision
(
int32_t
dividend_i32
,
uint32_t
multiplier
,
uint32_t
shift
)
{
__host__
static
constexpr
int32_t
uint32_t
dividend_u32
=
bit_cast
<
uint32_t
>
(
dividend_i32
);
DoMagicDivision
(
int32_t
dividend_i32
,
uint32_t
multiplier
,
uint32_t
shift
)
{
uint32_t
tmp
=
static_cast
<
uint64_t
>
(
dividend_u32
)
*
multiplier
>>
32
;
uint32_t
dividend_u32
=
bit_cast
<
uint32_t
>
(
dividend_i32
);
return
(
tmp
+
dividend_u32
)
>>
shift
;
uint32_t
tmp
=
static_cast
<
uint64_t
>
(
dividend_u32
)
*
multiplier
>>
32
;
}
return
(
tmp
+
dividend_u32
)
>>
shift
;
}
};
};
struct
MDiv
struct
MDiv
...
@@ -230,3 +222,5 @@ struct MDiv2
...
@@ -230,3 +222,5 @@ struct MDiv2
};
};
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/math.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "enable_if.hpp"
#include "integral_constant.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "number.hpp"
#include "type.hpp"
#include "type.hpp"
#include "enable_if.hpp"
namespace
ck
{
namespace
ck
{
namespace
math
{
namespace
math
{
template
<
typename
T
,
T
s
>
template
<
typename
T
,
T
s
>
struct
scales
{
struct
scales
__host__
__device__
constexpr
T
operator
()(
T
a
)
const
{
return
s
*
a
;
}
{
__host__
__device__
constexpr
T
operator
()(
T
a
)
const
{
return
s
*
a
;
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
plus
{
struct
plus
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
return
a
+
b
;
}
{
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
return
a
+
b
;
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
minus
{
struct
minus
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
return
a
-
b
;
}
{
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
return
a
-
b
;
}
};
};
struct
multiplies
struct
multiplies
{
{
template
<
typename
A
,
typename
B
>
template
<
typename
A
,
typename
B
>
__host__
__device__
constexpr
auto
operator
()(
const
A
&
a
,
const
B
&
b
)
const
{
__host__
__device__
constexpr
auto
operator
()(
const
A
&
a
,
const
B
&
b
)
const
return
a
*
b
;
{
}
return
a
*
b
;
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
maximize
{
struct
maximize
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
{
return
a
>=
b
?
a
:
b
;
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
return
a
>=
b
?
a
:
b
;
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
minimize
{
struct
minimize
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
{
return
a
<=
b
?
a
:
b
;
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
return
a
<=
b
?
a
:
b
;
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
integer_divide_ceiler
{
struct
integer_divide_ceiler
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
{
static_assert
(
is_same
<
T
,
index_t
>
{}
||
is_same
<
T
,
int
>
{},
"wrong type"
);
__host__
__device__
constexpr
T
operator
()(
T
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
index_t
>
{}
||
is_same
<
T
,
int
>
{},
"wrong type"
);
return
(
a
+
b
-
Number
<
1
>
{})
/
b
;
return
(
a
+
b
-
Number
<
1
>
{})
/
b
;
}
}
};
};
template
<
typename
X
,
typename
Y
>
template
<
typename
X
,
typename
Y
>
__host__
__device__
constexpr
auto
integer_divide_floor
(
X
x
,
Y
y
)
__host__
__device__
constexpr
auto
integer_divide_floor
(
X
x
,
Y
y
)
{
{
return
x
/
y
;
return
x
/
y
;
}
}
template
<
typename
X
,
typename
Y
>
template
<
typename
X
,
typename
Y
>
__host__
__device__
constexpr
auto
integer_divide_ceil
(
X
x
,
Y
y
)
__host__
__device__
constexpr
auto
integer_divide_ceil
(
X
x
,
Y
y
)
{
{
return
(
x
+
y
-
Number
<
1
>
{})
/
y
;
return
(
x
+
y
-
Number
<
1
>
{})
/
y
;
}
}
template
<
typename
X
,
typename
Y
>
template
<
typename
X
,
typename
Y
>
__host__
__device__
constexpr
auto
integer_least_multiple
(
X
x
,
Y
y
)
__host__
__device__
constexpr
auto
integer_least_multiple
(
X
x
,
Y
y
)
{
{
return
y
*
integer_divide_ceil
(
x
,
y
);
return
y
*
integer_divide_ceil
(
x
,
y
);
}
}
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
constexpr
T
max
(
T
x
)
{
return
x
;
}
__host__
__device__
constexpr
T
max
(
T
x
)
{
return
x
;
}
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
constexpr
T
max
(
T
x
,
T
y
)
{
__host__
__device__
constexpr
T
max
(
T
x
,
T
y
)
return
x
>
y
?
x
:
y
;
{
return
x
>
y
?
x
:
y
;
}
}
template
<
index_t
X
>
template
<
index_t
X
>
__host__
__device__
constexpr
index_t
max
(
Number
<
X
>
,
index_t
y
)
__host__
__device__
constexpr
index_t
max
(
Number
<
X
>
,
index_t
y
)
{
{
return
X
>
y
?
X
:
y
;
return
X
>
y
?
X
:
y
;
}
}
template
<
index_t
Y
>
template
<
index_t
Y
>
__host__
__device__
constexpr
index_t
max
(
index_t
x
,
Number
<
Y
>
)
__host__
__device__
constexpr
index_t
max
(
index_t
x
,
Number
<
Y
>
)
{
{
return
x
>
Y
?
x
:
Y
;
return
x
>
Y
?
x
:
Y
;
}
}
template
<
typename
X
,
typename
...
Ys
>
template
<
typename
X
,
typename
...
Ys
>
__host__
__device__
constexpr
auto
max
(
X
x
,
Ys
...
ys
)
__host__
__device__
constexpr
auto
max
(
X
x
,
Ys
...
ys
)
{
{
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
return
max
(
x
,
max
(
ys
...));
return
max
(
x
,
max
(
ys
...));
}
}
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
constexpr
T
min
(
T
x
)
{
return
x
;
}
__host__
__device__
constexpr
T
min
(
T
x
)
{
return
x
;
}
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
constexpr
T
min
(
T
x
,
T
y
)
{
__host__
__device__
constexpr
T
min
(
T
x
,
T
y
)
return
x
<
y
?
x
:
y
;
{
return
x
<
y
?
x
:
y
;
}
}
template
<
index_t
X
>
template
<
index_t
X
>
__host__
__device__
constexpr
index_t
min
(
Number
<
X
>
,
index_t
y
)
__host__
__device__
constexpr
index_t
min
(
Number
<
X
>
,
index_t
y
)
{
{
return
X
<
y
?
X
:
y
;
return
X
<
y
?
X
:
y
;
}
}
template
<
index_t
Y
>
template
<
index_t
Y
>
__host__
__device__
constexpr
index_t
min
(
index_t
x
,
Number
<
Y
>
)
__host__
__device__
constexpr
index_t
min
(
index_t
x
,
Number
<
Y
>
)
{
{
return
x
<
Y
?
x
:
Y
;
return
x
<
Y
?
x
:
Y
;
}
}
template
<
typename
X
,
typename
...
Ys
>
template
<
typename
X
,
typename
...
Ys
>
__host__
__device__
constexpr
auto
min
(
X
x
,
Ys
...
ys
)
__host__
__device__
constexpr
auto
min
(
X
x
,
Ys
...
ys
)
{
{
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
static_assert
(
sizeof
...(
Ys
)
>
0
,
"not enough argument"
);
return
min
(
x
,
min
(
ys
...));
return
min
(
x
,
min
(
ys
...));
}
}
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
constexpr
T
clamp
(
const
T
&
x
,
const
T
&
lowerbound
,
const
T
&
upperbound
)
__host__
__device__
constexpr
T
clamp
(
const
T
&
x
,
const
T
&
lowerbound
,
{
const
T
&
upperbound
)
{
return
min
(
max
(
x
,
lowerbound
),
upperbound
);
return
min
(
max
(
x
,
lowerbound
),
upperbound
);
}
}
// disallow implicit type casting
// disallow implicit type casting
template
<
typename
T
>
template
<
typename
T
>
__device__
T
exp
(
T
x
);
__device__
T
exp
(
T
x
);
// TODO: add f16 support using v_exp_f16
// TODO: add f16 support using v_exp_f16
template
<
>
template
<
>
__device__
float
exp
<
float
>
(
float
x
)
{
return
__expf
(
x
);
}
__device__
float
exp
<
float
>
(
float
x
)
{
return
__expf
(
x
);
}
template
<
>
template
<
>
__device__
double
exp
<
double
>
(
double
x
)
{
return
exp
(
x
);
}
__device__
double
exp
<
double
>
(
double
x
)
{
return
exp
(
x
);
}
static
inline
__host__
float
exp
(
float
x
)
{
return
::
expf
(
x
);
}
//
static inline __host__ float exp(float x) { return ::expf(x); }
static
inline
__host__
double
exp
(
double
x
)
{
return
std
::
exp
(
x
);
}
//
static inline __host__ double exp(double x) { return std::exp(x); }
// greatest common divisor, aka highest common factor
// greatest common divisor, aka highest common factor
__host__
__device__
constexpr
index_t
gcd
(
index_t
x
,
index_t
y
)
__host__
__device__
constexpr
index_t
gcd
(
index_t
x
,
index_t
y
)
{
{
if
(
x
<
0
)
{
if
(
x
<
0
)
return
gcd
(
-
x
,
y
);
{
}
else
if
(
y
<
0
)
{
return
gcd
(
-
x
,
y
);
return
gcd
(
x
,
-
y
);
}
}
else
if
(
x
==
y
||
x
==
0
)
{
else
if
(
y
<
0
)
return
y
;
{
}
else
if
(
y
==
0
)
{
return
gcd
(
x
,
-
y
);
return
x
;
}
}
else
if
(
x
>
y
)
{
else
if
(
x
==
y
||
x
==
0
)
return
gcd
(
x
%
y
,
y
);
{
}
else
{
return
y
;
return
gcd
(
x
,
y
%
x
);
}
}
else
if
(
y
==
0
)
{
return
x
;
}
else
if
(
x
>
y
)
{
return
gcd
(
x
%
y
,
y
);
}
else
{
return
gcd
(
x
,
y
%
x
);
}
}
}
template
<
index_t
X
,
index_t
Y
>
template
<
index_t
X
,
index_t
Y
>
__host__
__device__
constexpr
auto
gcd
(
Number
<
X
>
,
Number
<
Y
>
)
__host__
__device__
constexpr
auto
gcd
(
Number
<
X
>
,
Number
<
Y
>
)
{
{
constexpr
auto
r
=
gcd
(
X
,
Y
);
constexpr
auto
r
=
gcd
(
X
,
Y
);
return
Number
<
r
>
{};
return
Number
<
r
>
{};
}
}
template
<
typename
X
,
typename
...
Ys
,
typename
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
template
<
typename
X
,
typename
...
Ys
,
__host__
__device__
constexpr
auto
gcd
(
X
x
,
Ys
...
ys
)
typename
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
{
__host__
__device__
constexpr
auto
gcd
(
X
x
,
Ys
...
ys
)
{
return
gcd
(
x
,
gcd
(
ys
...));
return
gcd
(
x
,
gcd
(
ys
...));
}
}
// least common multiple
// least common multiple
template
<
typename
X
,
typename
Y
>
template
<
typename
X
,
typename
Y
>
__host__
__device__
constexpr
auto
lcm
(
X
x
,
Y
y
)
__host__
__device__
constexpr
auto
lcm
(
X
x
,
Y
y
)
{
{
return
(
x
*
y
)
/
gcd
(
x
,
y
);
return
(
x
*
y
)
/
gcd
(
x
,
y
);
}
}
template
<
typename
X
,
typename
...
Ys
,
typename
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
template
<
typename
X
,
typename
...
Ys
,
__host__
__device__
constexpr
auto
lcm
(
X
x
,
Ys
...
ys
)
typename
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
{
__host__
__device__
constexpr
auto
lcm
(
X
x
,
Ys
...
ys
)
{
return
lcm
(
x
,
lcm
(
ys
...));
return
lcm
(
x
,
lcm
(
ys
...));
}
}
template
<
typename
T
>
template
<
typename
T
>
struct
equal
{
struct
equal
__host__
__device__
constexpr
bool
operator
()(
T
x
,
T
y
)
const
{
{
return
x
==
y
;
__host__
__device__
constexpr
bool
operator
()(
T
x
,
T
y
)
const
{
return
x
==
y
;
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
less
{
struct
less
__host__
__device__
constexpr
bool
operator
()(
T
x
,
T
y
)
const
{
{
return
x
<
y
;
__host__
__device__
constexpr
bool
operator
()(
T
x
,
T
y
)
const
{
return
x
<
y
;
}
}
};
};
template
<
index_t
X
>
template
<
index_t
X
>
...
@@ -258,3 +206,5 @@ __host__ __device__ constexpr auto next_power_of_two(Number<X> x)
...
@@ -258,3 +206,5 @@ __host__ __device__ constexpr auto next_power_of_two(Number<X> x)
}
// namespace math
}
// namespace math
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/math_v2.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -13,177 +16,169 @@
...
@@ -13,177 +16,169 @@
namespace
ck
{
namespace
ck
{
namespace
math
{
namespace
math
{
// math functions for the host, some are implemented by calling C++ std functions
// math functions for the host, some are implemented by calling C++ std
// functions
static
inline
__host__
float
abs
(
float
x
)
{
return
std
::
abs
(
x
)
;
};
static
inline
__host__
float
abs
(
float
x
)
{
return
x
<
0
?
x
*
-
1.0
:
x
;
};
static
inline
__host__
double
abs
(
double
x
)
{
return
std
::
abs
(
x
)
;
};
static
inline
__host__
double
abs
(
double
x
)
{
return
x
<
0
?
x
*
-
1.0
:
x
;
};
static
inline
__host__
int8_t
abs
(
int8_t
x
)
static
inline
__host__
int8_t
abs
(
int8_t
x
)
{
{
int8_t
sgn
=
x
>>
(
8
-
1
);
int8_t
sgn
=
x
>>
(
8
-
1
);
return
(
x
^
sgn
)
-
sgn
;
return
(
x
^
sgn
)
-
sgn
;
};
};
static
inline
__host__
int32_t
abs
(
int32_t
x
)
static
inline
__host__
int32_t
abs
(
int32_t
x
)
{
{
int32_t
sgn
=
x
>>
(
32
-
1
);
int32_t
sgn
=
x
>>
(
32
-
1
);
return
(
x
^
sgn
)
-
sgn
;
return
(
x
^
sgn
)
-
sgn
;
};
};
static
inline
__host__
half_t
abs
(
half_t
x
)
static
inline
__host__
half_t
abs
(
half_t
x
)
{
{
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
uint16_t
abs_xx
=
xx
&
0x7fff
;
uint16_t
abs_xx
=
xx
&
0x7fff
;
half_t
abs_x
=
ck
::
bit_cast
<
half_t
>
(
abs_xx
);
half_t
abs_x
=
ck
::
bit_cast
<
half_t
>
(
abs_xx
);
return
abs_x
;
return
abs_x
;
};
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__host__
int4_t
abs
(
int4_t
x
)
static
inline
__host__
int4_t
abs
(
int4_t
x
)
{
{
int4_t
sgn
=
x
>>
(
4
-
1
);
int4_t
sgn
=
x
>>
(
4
-
1
);
return
(
x
^
sgn
)
-
sgn
;
return
(
x
^
sgn
)
-
sgn
;
}
}
#endif
#endif
static
inline
__host__
bool
isnan
(
float
x
)
{
return
std
::
isnan
(
x
);
};
// TODO: to bit arithmetic to figure it out
static
inline
__host__
bool
isnan
(
float
x
)
{
(
void
)
x
;
return
false
;
};
static
inline
__host__
bool
isnan
(
double
x
)
{
return
std
::
isnan
(
x
);
};
static
inline
__host__
bool
isnan
(
double
x
)
{
(
void
)
x
;
return
false
;
};
static
inline
__host__
bool
isnan
(
int8_t
x
)
static
inline
__host__
bool
isnan
(
int8_t
x
)
{
{
(
void
)
x
;
(
void
)
x
;
return
false
;
return
false
;
};
};
static
inline
__host__
bool
isnan
(
int32_t
x
)
static
inline
__host__
bool
isnan
(
int32_t
x
)
{
{
(
void
)
x
;
(
void
)
x
;
return
false
;
return
false
;
};
};
static
inline
__host__
bool
isnan
(
half_t
x
)
static
inline
__host__
bool
isnan
(
half_t
x
)
{
{
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
return
(
xx
&
0x7FFF
)
>
0x7C00
;
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__host__
bool
isnan
(
int4_t
x
)
static
inline
__host__
bool
isnan
(
int4_t
x
)
{
{
(
void
)
x
;
(
void
)
x
;
return
false
;
return
false
;
};
};
#endif
#endif
static
inline
__host__
half_t
sqrt
(
half_t
x
)
// MIGRAPHX doesn't care about host compilation, just return identity values for
{
// now
return
static_cast
<
half_t
>
(
std
::
sqrt
(
static_cast
<
float
>
(
x
)));
};
static
inline
__host__
floa
t
sqrt
(
floa
t
x
)
{
return
std
::
sqrt
(
x
)
;
};
static
inline
__host__
half_
t
sqrt
(
half_
t
x
)
{
return
x
;
};
static
inline
__host__
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
)
;
};
static
inline
__host__
float
sqrt
(
float
x
)
{
return
x
;
};
static
inline
__host__
half_t
tanh
(
half_t
x
)
static
inline
__host__
double
sqrt
(
double
x
)
{
return
x
;
};
{
return
static_cast
<
half_t
>
(
std
::
tanh
(
static_cast
<
float
>
(
x
)));
};
static
inline
__host__
floa
t
tanh
(
floa
t
x
)
{
return
std
::
tanh
(
x
)
;
};
static
inline
__host__
half_
t
tanh
(
half_
t
x
)
{
return
x
;
};
static
inline
__host__
double
tanh
(
double
x
)
{
return
std
::
tanh
(
x
)
;
};
static
inline
__host__
float
tanh
(
float
x
)
{
return
x
;
};
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static
inline
__host__
double
tanh
(
double
x
)
{
return
x
;
};
// math functions for the HIP kernel, some are implemented by calling hip
// builtin functions
static
inline
__device__
float
abs
(
float
x
)
{
return
::
abs
(
x
);
};
static
inline
__device__
float
abs
(
float
x
)
{
return
::
abs
(
x
);
};
static
inline
__device__
double
abs
(
double
x
)
{
return
::
abs
(
x
);
};
static
inline
__device__
double
abs
(
double
x
)
{
return
::
abs
(
x
);
};
static
inline
__device__
int8_t
abs
(
int8_t
x
)
static
inline
__device__
int8_t
abs
(
int8_t
x
)
{
{
int8_t
sgn
=
x
>>
(
8
-
1
);
int8_t
sgn
=
x
>>
(
8
-
1
);
return
(
x
^
sgn
)
-
sgn
;
return
(
x
^
sgn
)
-
sgn
;
};
};
static
inline
__device__
int32_t
abs
(
int32_t
x
)
static
inline
__device__
int32_t
abs
(
int32_t
x
)
{
{
int32_t
sgn
=
x
>>
(
32
-
1
);
int32_t
sgn
=
x
>>
(
32
-
1
);
return
(
x
^
sgn
)
-
sgn
;
return
(
x
^
sgn
)
-
sgn
;
};
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__device__
int4_t
abs
(
int4_t
x
)
static
inline
__device__
int4_t
abs
(
int4_t
x
)
{
{
int4_t
sgn
=
x
>>
(
4
-
1
);
int4_t
sgn
=
x
>>
(
4
-
1
);
return
(
x
^
sgn
)
-
sgn
;
return
(
x
^
sgn
)
-
sgn
;
};
};
#endif
#endif
static
inline
__device__
half_t
abs
(
half_t
x
)
static
inline
__device__
half_t
abs
(
half_t
x
)
{
{
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
uint16_t
abs_xx
=
xx
&
0x7fff
;
uint16_t
abs_xx
=
xx
&
0x7fff
;
half_t
abs_x
=
ck
::
bit_cast
<
half_t
>
(
abs_xx
);
half_t
abs_x
=
ck
::
bit_cast
<
half_t
>
(
abs_xx
);
return
abs_x
;
return
abs_x
;
};
};
static
inline
__device__
bool
isnan
(
float
x
)
{
return
::
isnan
(
x
);
};
static
inline
__device__
bool
isnan
(
float
x
)
{
return
::
isnan
(
x
);
};
static
inline
__device__
bool
isnan
(
double
x
)
{
return
::
isnan
(
x
);
};
static
inline
__device__
bool
isnan
(
double
x
)
{
return
::
isnan
(
x
);
};
static
inline
__device__
bool
isnan
(
int8_t
x
)
static
inline
__device__
bool
isnan
(
int8_t
x
)
{
{
(
void
)
x
;
(
void
)
x
;
return
false
;
return
false
;
};
};
static
inline
__device__
bool
isnan
(
int32_t
x
)
static
inline
__device__
bool
isnan
(
int32_t
x
)
{
{
(
void
)
x
;
(
void
)
x
;
return
false
;
return
false
;
};
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__device__
bool
isnan
(
int4_t
x
)
static
inline
__device__
bool
isnan
(
int4_t
x
)
{
{
(
void
)
x
;
(
void
)
x
;
return
false
;
return
false
;
};
};
#endif
#endif
static
inline
__device__
bool
isnan
(
half_t
x
)
static
inline
__device__
bool
isnan
(
half_t
x
)
{
{
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
return
(
xx
&
0x7FFF
)
>
0x7C00
;
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
};
static
inline
__device__
half_t
sqrt
(
half_t
x
)
static
inline
__device__
half_t
sqrt
(
half_t
x
)
{
{
return
static_cast
<
half_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
return
static_cast
<
half_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
};
};
static
inline
__device__
float
sqrt
(
float
x
)
{
return
__builtin_amdgcn_sqrtf
(
x
);
};
static
inline
__device__
float
sqrt
(
float
x
)
{
return
__builtin_amdgcn_sqrtf
(
x
);
};
static
inline
__device__
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
static
inline
__device__
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
static
inline
__device__
half_t
tanh
(
half_t
x
)
static
inline
__device__
half_t
tanh
(
half_t
x
)
{
{
return
static_cast
<
half_t
>
(
::
tanhf
(
static_cast
<
float
>
(
x
)));
return
static_cast
<
half_t
>
(
::
tanhf
(
static_cast
<
float
>
(
x
)));
};
};
static
inline
__device__
float
tanh
(
float
x
)
{
return
::
tanhf
(
x
);
};
static
inline
__device__
float
tanh
(
float
x
)
{
return
::
tanhf
(
x
);
};
...
@@ -192,3 +187,5 @@ static inline __device__ double tanh(double x) { return ::tanh(x); };
...
@@ -192,3 +187,5 @@ static inline __device__ double tanh(double x) { return ::tanh(x); };
}
// namespace math
}
// namespace math
}
// namespace ck
}
// namespace ck
#pragma clang diagnostic pop
include/ck/utility/multi_index.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -10,3 +13,5 @@
...
@@ -10,3 +13,5 @@
#else
#else
#include "statically_indexed_array_multi_index.hpp"
#include "statically_indexed_array_multi_index.hpp"
#endif
#endif
#pragma clang diagnostic pop
include/ck/utility/number.hpp
View file @
253f942b
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Weverything"
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
...
@@ -16,3 +19,5 @@ using LongNumber = integral_constant<long_index_t, N>;
...
@@ -16,3 +19,5 @@ using LongNumber = integral_constant<long_index_t, N>;
}
// namespace ck
}
// namespace ck
#endif
#endif
#pragma clang diagnostic pop
Prev
1
…
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment