Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7a3b49e5
Commit
7a3b49e5
authored
Jun 25, 2022
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into contraction
parents
e07b3d8e
d3051d75
Changes
592
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
378 additions
and
241 deletions
+378
-241
include/ck/utility/is_known_at_compile_time.hpp
include/ck/utility/is_known_at_compile_time.hpp
+5
-4
include/ck/utility/magic_division.hpp
include/ck/utility/magic_division.hpp
+5
-5
include/ck/utility/math.hpp
include/ck/utility/math.hpp
+21
-5
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+66
-15
include/ck/utility/multi_index.hpp
include/ck/utility/multi_index.hpp
+4
-4
include/ck/utility/number.hpp
include/ck/utility/number.hpp
+3
-0
include/ck/utility/print.hpp
include/ck/utility/print.hpp
+3
-0
include/ck/utility/reduction_common.hpp
include/ck/utility/reduction_common.hpp
+6
-31
include/ck/utility/reduction_enums.hpp
include/ck/utility/reduction_enums.hpp
+5
-30
include/ck/utility/reduction_functions_accumulate.hpp
include/ck/utility/reduction_functions_accumulate.hpp
+40
-54
include/ck/utility/reduction_operator.hpp
include/ck/utility/reduction_operator.hpp
+126
-62
include/ck/utility/sequence.hpp
include/ck/utility/sequence.hpp
+17
-4
include/ck/utility/sequence_helper.hpp
include/ck/utility/sequence_helper.hpp
+3
-0
include/ck/utility/static_buffer.hpp
include/ck/utility/static_buffer.hpp
+3
-0
include/ck/utility/statically_indexed_array.hpp
include/ck/utility/statically_indexed_array.hpp
+3
-0
include/ck/utility/statically_indexed_array_multi_index.hpp
include/ck/utility/statically_indexed_array_multi_index.hpp
+3
-0
include/ck/utility/synchronization.hpp
include/ck/utility/synchronization.hpp
+5
-4
include/ck/utility/thread_group.hpp
include/ck/utility/thread_group.hpp
+3
-0
include/ck/utility/transpose_vectors.hpp
include/ck/utility/transpose_vectors.hpp
+5
-4
include/ck/utility/tuple.hpp
include/ck/utility/tuple.hpp
+52
-19
No files found.
include/ck/utility/is_known_at_compile_time.hpp
View file @
7a3b49e5
#ifndef IS_KNOWN_AT_COMPILE_TIME_HPP
#define IS_KNOWN_AT_COMPILE_TIME_HPP
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp"
#pragma once
#include "ck/ck.hpp"
#include "integral_constant.hpp"
#include "sequence.hpp"
#include "tuple.hpp"
...
...
@@ -52,4 +54,3 @@ struct is_known_at_compile_time<Tuple<Ts...>>
};
}
// namespace ck
#endif
include/ck/utility/magic_division.hpp
View file @
7a3b49e5
#ifndef CK_MAGIC_DIVISION_HPP
#define CK_MAGIC_DIVISION_HPP
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp"
#pragma once
#include "ck/ck.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
...
...
@@ -156,5 +158,3 @@ struct MagicDivision
};
}
// namespace ck
#endif
include/ck/utility/math.hpp
View file @
7a3b49e5
#ifndef CK_MATH_HPP
#define CK_MATH_HPP
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp"
#pragma once
#include "ck/ck.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "type.hpp"
...
...
@@ -142,6 +144,22 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
return
min
(
x
,
min
(
ys
...));
}
// disallow implicit type casting
template
<
typename
T
>
__device__
T
exp
(
T
x
);
template
<
>
__device__
float
exp
<
float
>
(
float
x
)
{
return
__expf
(
x
);
}
template
<
>
__device__
double
exp
<
double
>
(
double
x
)
{
return
exp
(
x
);
}
// greatest common divisor, aka highest common factor
__host__
__device__
constexpr
index_t
gcd
(
index_t
x
,
index_t
y
)
{
...
...
@@ -212,5 +230,3 @@ struct less
}
// namespace math
}
// namespace ck
#endif
include/ck/utility/math_v2.hpp
View file @
7a3b49e5
#ifndef CK_MATH_V2_HPP
#define CK_MATH_V2_HPP
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cmath>
#include "data_type.hpp"
#include "half.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
namespace
ck
{
namespace
math
{
// math functions for the host, some are implemented by calling C++ std functions
static
inline
__host__
float
abs
(
float
x
)
{
return
std
::
abs
(
x
);
};
static
inline
__host__
double
abs
(
double
x
)
{
return
std
::
abs
(
x
);
};
...
...
@@ -28,26 +33,26 @@ static inline __host__ int32_t abs(int32_t x)
static
inline
__host__
half_t
abs
(
half_t
x
)
{
half_float
::
half
xx
=
*
reinterpret_cast
<
half_float
::
half
*
>
(
&
x
);
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
half_float
::
half
abs_xx
=
half_float
::
abs
(
xx
)
;
uint16_t
abs_xx
=
xx
&
0x7fff
;
half_t
abs_x
=
*
reinterpre
t_cast
<
half_t
*
>
(
&
abs_xx
);
half_t
abs_x
=
ck
::
bi
t_cast
<
half_t
>
(
abs_xx
);
return
abs_x
;
};
static
inline
__host__
float
isnan
(
float
x
)
{
return
std
::
isnan
(
x
);
};
static
inline
__host__
bool
isnan
(
float
x
)
{
return
std
::
isnan
(
x
);
};
static
inline
__host__
double
isnan
(
double
x
)
{
return
std
::
isnan
(
x
);
};
static
inline
__host__
bool
isnan
(
double
x
)
{
return
std
::
isnan
(
x
);
};
static
inline
__host__
int8_t
isnan
(
int8_t
x
)
static
inline
__host__
bool
isnan
(
int8_t
x
)
{
(
void
)
x
;
return
false
;
};
static
inline
__host__
int32_t
isnan
(
int32_t
x
)
static
inline
__host__
bool
isnan
(
int32_t
x
)
{
(
void
)
x
;
return
false
;
...
...
@@ -55,12 +60,58 @@ static inline __host__ int32_t isnan(int32_t x)
static
inline
__host__
bool
isnan
(
half_t
x
)
{
half_float
::
half
xx
=
*
reinterpret_cast
<
half_float
::
half
*
>
(
&
x
);
uint16_t
xx
=
ck
::
bit_cast
<
uint16_t
>
(
x
);
return
half_float
::
isnan
(
xx
)
;
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
static
inline
__host__
float
sqrt
(
float
x
)
{
return
std
::
sqrt
(
x
);
};
static
inline
__host__
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static
inline
__device__
float
abs
(
float
x
)
{
return
::
abs
(
x
);
};
static
inline
__device__
double
abs
(
double
x
)
{
return
::
abs
(
x
);
};
static
inline
__device__
int8_t
abs
(
int8_t
x
)
{
int8_t
sgn
=
x
>>
(
8
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
static
inline
__device__
int32_t
abs
(
int32_t
x
)
{
int32_t
sgn
=
x
>>
(
32
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
static
inline
__device__
half_t
abs
(
half_t
x
)
{
return
::
__habs
(
x
);
};
static
inline
__device__
bool
isnan
(
float
x
)
{
return
::
isnan
(
x
);
};
static
inline
__device__
bool
isnan
(
double
x
)
{
return
::
isnan
(
x
);
};
static
inline
__device__
bool
isnan
(
int8_t
x
)
{
(
void
)
x
;
return
false
;
};
static
inline
__device__
bool
isnan
(
int32_t
x
)
{
(
void
)
x
;
return
false
;
};
static
inline
__device__
bool
isnan
(
half_t
x
)
{
return
::
__hisnan
(
x
);
};
static
inline
__device__
float
sqrt
(
float
x
)
{
return
::
sqrtf
(
x
);
};
static
inline
__device__
double
sqrt
(
double
x
)
{
return
::
sqrt
(
x
);
};
}
// namespace math
}
// namespace ck
#endif
include/ck/utility/multi_index.hpp
View file @
7a3b49e5
#ifndef CK_MULTI_INDEX_HPP
#define CK_MULTI_INDEX_HPP
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "common_header.hpp"
...
...
@@ -8,5 +10,3 @@
#else
#include "statically_indexed_array_multi_index.hpp"
#endif
#endif
include/ck/utility/number.hpp
View file @
7a3b49e5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_NUMBER_HPP
#define CK_NUMBER_HPP
...
...
include/ck/utility/print.hpp
View file @
7a3b49e5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_PRINT_HPP
#define CK_PRINT_HPP
...
...
include/ck/utility/reduction_common.hpp
View file @
7a3b49e5
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_COMMON_HPP
#define CK_REDUCTION_COMMON_HPP
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "reduction_enums.hpp"
#pragma once
#include "ck/utility/reduction_enums.hpp"
namespace
ck
{
...
...
@@ -60,6 +37,4 @@ constexpr __device__ index_t get_shift<1>()
return
(
0
);
}
};
// end of namespace ck
#endif
}
// namespace ck
include/ck/utility/reduction_enums.hpp
View file @
7a3b49e5
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_ENUMS_HPP
#define CK_REDUCTION_ENUMS_HPP
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace
ck
{
...
...
@@ -61,6 +38,4 @@ enum struct IndicesType
INDICES_8BIT
=
3
,
};
};
// end of namespace ck
#endif
}
// namespace ck
include/ck/utility/reduction_functions_accumulate.hpp
View file @
7a3b49e5
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_FUNCTIONS_BINOP_HPP
#define CK_REDUCTION_FUNCTIONS_BINOP_HPP
#include "data_type.hpp"
#include "reduction_common.hpp"
#include "reduction_operator.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/reduction_common.hpp"
#include "ck/utility/reduction_operator.hpp"
namespace
ck
{
namespace
detail
{
template
<
typename
T
>
static
inline
__device__
bool
is_nan
(
T
x
)
{
return
(
isnan
(
x
));
};
template
<
>
inline
__device__
bool
is_nan
<
half_t
>
(
half_t
x
)
// Check for NaN; guarantee NaNs are NOT propagated to result (i.e., ignore NaNs)
template
<
typename
ReduceOperation
,
typename
AccDataType
>
struct
AccumulateWithNanIgnore
{
return
(
__hisnan
(
x
));
__device__
static
inline
void
Calculate
(
AccDataType
&
accuVal
,
AccDataType
currVal
)
{
if
(
!
isnan
(
currVal
))
{
ReduceOperation
{}(
accuVal
,
currVal
);
}
};
};
template
<
bool
PropagateNan
,
typename
ReduceOperation
,
typename
AccDataType
>
struct
AccumulateWithNanCheck
;
// Does not check for NaN; does not guarantee NaNs be propagated to result
// e.g., given that max(a, b) = a > b ? a : b
// then max(NaN, 1) returns 1
// max(1, NaN) returns NaN
// since any comparison involving NaNs returns false
template
<
typename
ReduceOperation
,
typename
AccDataType
>
struct
AccumulateWithNanCheck
<
false
,
ReduceOperation
,
AccDataType
>
{
// cppcheck-suppress constParameter
__device__
static
inline
void
Calculate
(
AccDataType
&
accuVal
,
AccDataType
currVal
)
__host__
__device__
static
inline
void
Calculate
(
AccDataType
&
accuVal
,
AccDataType
currVal
)
{
ReduceOperation
{}(
accuVal
,
currVal
);
};
};
// Check for NaN; guarantees NaNs be propagated to result
template
<
typename
ReduceOperation
,
typename
AccDataType
>
struct
AccumulateWithNanCheck
<
true
,
ReduceOperation
,
AccDataType
>
{
__device__
static
inline
void
Calculate
(
AccDataType
&
accuVal
,
AccDataType
currVal
)
__host__
__device__
static
inline
void
Calculate
(
AccDataType
&
accuVal
,
AccDataType
currVal
)
{
if
(
is_nan
(
currVal
))
using
ck
::
math
::
isnan
;
if
(
isnan
(
currVal
))
{
accuVal
=
currVal
;
}
...
...
@@ -81,7 +67,7 @@ struct AccumulateWithIndexAndNanCheck;
template
<
typename
ReduceOperation
,
typename
AccDataType
,
typename
IndexDataType
>
struct
AccumulateWithIndexAndNanCheck
<
false
,
ReduceOperation
,
AccDataType
,
IndexDataType
>
{
__device__
static
inline
void
__host__
__device__
static
inline
void
// cppcheck-suppress constParameter
Calculate
(
AccDataType
&
accuVal
,
AccDataType
currVal
,
...
...
@@ -101,12 +87,14 @@ template <typename ReduceOperation, typename AccDataType, typename IndexDataType
struct
AccumulateWithIndexAndNanCheck
<
true
,
ReduceOperation
,
AccDataType
,
IndexDataType
>
{
// The method is called when the ReduceOperation is indexable and the user asked for indices
__device__
static
inline
void
Calculate
(
AccDataType
&
accuVal
,
AccDataType
currVal
,
IndexDataType
&
accuIndex
,
IndexDataType
currIndex
)
__host__
__device__
static
inline
void
Calculate
(
AccDataType
&
accuVal
,
AccDataType
currVal
,
IndexDataType
&
accuIndex
,
IndexDataType
currIndex
)
{
if
(
is_nan
(
currVal
))
using
ck
::
math
::
isnan
;
if
(
isnan
(
currVal
))
{
accuVal
=
currVal
;
accuIndex
=
currIndex
;
...
...
@@ -123,7 +111,5 @@ struct AccumulateWithIndexAndNanCheck<true, ReduceOperation, AccDataType, IndexD
};
};
};
// namespace detail
};
// end of namespace ck
#endif
}
// namespace detail
}
// namespace ck
include/ck/utility/reduction_operator.hpp
View file @
7a3b49e5
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2020 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#ifndef CK_REDUCTION_OPERATOR_HPP
#define CK_REDUCTION_OPERATOR_HPP
#include "config.hpp"
#include "data_type.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type.hpp"
namespace
ck
{
...
...
@@ -36,7 +14,7 @@ namespace reduce {
// Every binary operator used in reduction is represented by a templated functor class. Each functor
// class must provide at least
// three members:
// 1) Get
ReductionZero
Val() -- the interface to return the "identity element" for the binary
// 1) Get
Identity
Val
ue
() -- the interface to return the "identity element" for the binary
// operator, "identity element" is the unique
// element in the algebraic space that doesn't affect the value of other elements
// when operated against them, and the concept is similar to zero vector in
...
...
@@ -54,64 +32,92 @@ namespace reduce {
// accumulated index also need be
// changed.
template
<
class
T
>
struct
Add
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetReductionZeroVal
()
{
return
static_cast
<
T
>
(
0.0
f
);
};
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
0.0
f
);
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
return
operation
==
InMemoryDataOperationEnum
::
AtomicAdd
||
operation
==
InMemoryDataOperationEnum
::
Set
;
};
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
a
=
a
+
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"The data type is not supported by the Add accumulator!"
);
a
=
a
+
b
;
}
};
template
<
class
T
>
struct
Mul
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetReductionZeroVal
()
{
return
static_cast
<
T
>
(
1.0
f
);
};
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
1.0
f
);
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
a
=
a
*
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
int32_t
>::
value
,
"The data type is not supported by the Mul accumulator!"
);
a
=
a
*
b
;
}
};
template
<
class
T
>
struct
Max
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetReductionZeroVal
()
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
NumericLimits
<
T
>::
Lowest
();
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
// ToChange: atomic_max to be added
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Max accumulator!"
);
if
(
a
<
b
)
{
a
=
b
;
...
...
@@ -120,31 +126,41 @@ struct Max
}
};
template
<
class
T
>
struct
Min
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetReductionZeroVal
()
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
NumericLimits
<
T
>::
Max
();
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
// ToChange: atomic_min to be added
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Min accumulator!"
);
if
(
a
>
b
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the Min accumulator!"
);
if
(
a
>
b
)
{
a
=
b
;
...
...
@@ -153,28 +169,41 @@ struct Min
}
};
template
<
class
T
>
struct
AMax
{
using
dataType
=
T
;
__host__
__device__
static
constexpr
T
GetReductionZeroVal
()
{
return
static_cast
<
T
>
(
0.0
f
);
};
template
<
typename
T
>
__host__
__device__
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
0.0
f
);
};
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
IsCompatibleInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
// ToChange: atomic_max to be added
return
operation
==
InMemoryDataOperationEnum
::
Set
;
};
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the AMax accumulator!"
);
if
(
a
<
b
)
a
=
b
;
}
template
<
typename
T
>
__host__
__device__
inline
constexpr
void
operator
()(
T
&
a
,
T
b
,
bool
&
changed
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"The data type is not supported by the AMax accumulator!"
);
if
(
a
<
b
)
{
a
=
b
;
...
...
@@ -184,7 +213,7 @@ struct AMax
};
template
<
typename
T
>
T
GetReductionZero
ValueForInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
constexpr
T
GetIdentity
ValueForInMemoryDataOperation
(
InMemoryDataOperationEnum
operation
)
{
T
result
=
ck
::
type_convert
<
T
>
(
0.0
f
);
...
...
@@ -194,8 +223,43 @@ T GetReductionZeroValueForInMemoryDataOperation(InMemoryDataOperationEnum operat
return
(
result
);
};
};
// end of namespace reduce
template
<
InMemoryDataOperationEnum
Operation
,
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
{
static
constexpr
bool
value
=
false
;
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
AtomicAdd
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
;
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
AtomicMax
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
;
};
}
// end of namespace ck
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
Set
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
bhalf_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
};
template
<
typename
DataType
>
struct
InMemoryDataOperatonSupportedOnDataType
<
InMemoryDataOperationEnum
::
Add
,
DataType
>
{
static
constexpr
bool
value
=
is_same
<
DataType
,
float
>::
value
||
is_same
<
DataType
,
double
>::
value
||
is_same
<
DataType
,
half_t
>::
value
||
is_same
<
DataType
,
int8_t
>::
value
||
is_same
<
DataType
,
int32_t
>::
value
;
};
#endif
}
// namespace reduce
}
// namespace ck
include/ck/utility/sequence.hpp
View file @
7a3b49e5
#ifndef CK_SEQUENCE_HPP
#define CK_SEQUENCE_HPP
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "integral_constant.hpp"
#include "type.hpp"
...
...
@@ -241,7 +243,13 @@ struct arithmetic_sequence_gen
}
};
using
type
=
typename
sequence_gen
<
(
IEnd
-
IBegin
)
/
Increment
,
F
>::
type
;
using
type0
=
typename
sequence_gen
<
(
IEnd
-
IBegin
)
/
Increment
,
F
>::
type
;
using
type1
=
Sequence
<>
;
static
constexpr
bool
kHasContent
=
(
Increment
>
0
&&
IBegin
<
IEnd
)
||
(
Increment
<
0
&&
IBegin
>
IEnd
);
using
type
=
typename
conditional
<
kHasContent
,
type0
,
type1
>::
type
;
};
// uniform sequence
...
...
@@ -882,5 +890,10 @@ __host__ __device__ constexpr bool sequence_all_of(Seq, F f)
return
flag
;
}
template
<
typename
Sx
,
typename
Sy
>
using
sequence_merge_t
=
typename
sequence_merge
<
Sx
,
Sy
>::
type
;
template
<
index_t
NSize
,
index_t
I
>
using
uniform_sequence_gen_t
=
typename
uniform_sequence_gen
<
NSize
,
I
>::
type
;
}
// namespace ck
#endif
include/ck/utility/sequence_helper.hpp
View file @
7a3b49e5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "tuple.hpp"
...
...
include/ck/utility/static_buffer.hpp
View file @
7a3b49e5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATIC_BUFFER_HPP
#define CK_STATIC_BUFFER_HPP
...
...
include/ck/utility/statically_indexed_array.hpp
View file @
7a3b49e5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATICALLY_INDEXED_ARRAY_HPP
#define CK_STATICALLY_INDEXED_ARRAY_HPP
...
...
include/ck/utility/statically_indexed_array_multi_index.hpp
View file @
7a3b49e5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
...
...
include/ck/utility/synchronization.hpp
View file @
7a3b49e5
#ifndef CK_SYNCHRONIZATION_AMD_HPP
#define CK_SYNCHRONIZATION_AMD_HPP
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp"
#pragma once
#include "ck/ck.hpp"
namespace
ck
{
...
...
@@ -18,4 +20,3 @@ __device__ void block_sync_lds()
}
}
// namespace ck
#endif
include/ck/utility/thread_group.hpp
View file @
7a3b49e5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "get_id.hpp"
...
...
include/ck/utility/transpose_vectors.hpp
View file @
7a3b49e5
#ifndef CK_TRANSPOSE_VECTORS_AMD_HPP
#define CK_TRANSPOSE_VECTORS_AMD_HPP
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "config.hpp"
#pragma once
#include "ck/ck.hpp"
#include "statically_indexed_array.hpp"
#include "data_type.hpp"
...
...
@@ -165,4 +167,3 @@ struct transpose_vectors<int8_t, NX, NY>
};
}
// namespace ck
#endif
include/ck/utility/tuple.hpp
View file @
7a3b49e5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "integral_constant.hpp"
...
...
@@ -16,14 +19,18 @@ struct TupleElementKey
};
template
<
typename
Key
,
typename
Data
>
struct
TupleElement
struct
TupleElement
KeyData
{
__host__
__device__
constexpr
TupleElement
()
=
default
;
template
<
typename
T
,
typename
enable_if
<!
is_same
<
remove_cvref_t
<
T
>,
TupleElement
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleElement
(
T
&&
v
)
:
mData
(
std
::
forward
<
T
>
(
v
))
#if 0 // workaround compiler complaint about implicitly-deleted default constructor
__host__ __device__ constexpr TupleElementKeyData() = default;
#else
__host__
__device__
constexpr
TupleElementKeyData
()
:
mData
{}
{}
#endif
template
<
typename
T
,
typename
enable_if
<!
is_same
<
remove_cvref_t
<
T
>,
TupleElementKeyData
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleElementKeyData
(
T
&&
v
)
:
mData
(
std
::
forward
<
T
>
(
v
))
{
}
...
...
@@ -31,20 +38,21 @@ struct TupleElement
};
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
const
Data
&
get_tuple_element
(
const
TupleElement
<
Key
,
Data
>&
x
)
__host__
__device__
constexpr
const
Data
&
get_tuple_element_data
(
const
TupleElementKeyData
<
Key
,
Data
>&
x
)
{
return
static_cast
<
const
Data
&>
(
x
.
mData
);
}
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
Data
&
get_tuple_element
(
TupleElement
<
Key
,
Data
>&
x
)
__host__
__device__
constexpr
Data
&
get_tuple_element
_data
(
TupleElement
KeyData
<
Key
,
Data
>&
x
)
{
return
x
.
mData
;
}
// TODO: not sure the use of reference is correct
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
Data
&&
get_tuple_element
(
TupleElement
<
Key
,
Data
>&&
x
)
__host__
__device__
constexpr
Data
&&
get_tuple_element
_data
(
TupleElement
KeyData
<
Key
,
Data
>&&
x
)
{
return
static_cast
<
Data
&&>
(
x
.
mData
);
}
...
...
@@ -53,7 +61,7 @@ template <typename Indices, typename... Xs>
struct
TupleImpl
;
template
<
index_t
...
Is
,
typename
...
Xs
>
struct
TupleImpl
<
Sequence
<
Is
...
>
,
Xs
...
>
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
...
struct
TupleImpl
<
Sequence
<
Is
...
>
,
Xs
...
>
:
TupleElement
KeyData
<
TupleElementKey
<
Is
>
,
Xs
>
...
{
__host__
__device__
constexpr
TupleImpl
()
=
default
;
...
...
@@ -62,13 +70,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
!
is_same
<
remove_cvref_t
<
Y
>,
TupleImpl
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleImpl
(
Y
&&
y
)
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Y
>
(
y
))...
:
TupleElement
KeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Y
>
(
y
))...
{
}
template
<
typename
...
Ys
,
typename
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleImpl
(
Ys
&&
...
ys
)
:
TupleElement
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Ys
>
(
ys
))...
:
TupleElement
KeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Ys
>
(
ys
))...
{
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
Xs
)
&&
sizeof
...(
Is
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
...
...
@@ -77,15 +85,15 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
__host__
__device__
static
constexpr
index_t
Size
()
{
return
sizeof
...(
Xs
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
const
auto
&
GetElementByKey
(
TupleElementKey
<
I
>
)
const
__host__
__device__
constexpr
const
auto
&
GetElement
Data
ByKey
(
TupleElementKey
<
I
>
)
const
{
return
get_tuple_element
<
TupleElementKey
<
I
>>
(
*
this
);
return
get_tuple_element
_data
<
TupleElementKey
<
I
>>
(
*
this
);
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
&
GetElementByKey
(
TupleElementKey
<
I
>
)
__host__
__device__
constexpr
auto
&
GetElement
Data
ByKey
(
TupleElementKey
<
I
>
)
{
return
get_tuple_element
<
TupleElementKey
<
I
>>
(
*
this
);
return
get_tuple_element
_data
<
TupleElementKey
<
I
>>
(
*
this
);
}
};
...
...
@@ -120,7 +128,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__
__device__
constexpr
const
auto
&
At
(
Number
<
I
>
)
const
{
static_assert
(
I
<
base
::
Size
(),
"wrong! out of range"
);
return
base
::
GetElementByKey
(
detail
::
TupleElementKey
<
I
>
{});
return
base
::
GetElement
Data
ByKey
(
detail
::
TupleElementKey
<
I
>
{});
}
// write access
...
...
@@ -128,7 +136,7 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__
__device__
constexpr
auto
&
At
(
Number
<
I
>
)
{
static_assert
(
I
<
base
::
Size
(),
"wrong! out of range"
);
return
base
::
GetElementByKey
(
detail
::
TupleElementKey
<
I
>
{});
return
base
::
GetElement
Data
ByKey
(
detail
::
TupleElementKey
<
I
>
{});
}
// read access
...
...
@@ -158,6 +166,31 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
};
template
<
>
struct
Tuple
<>
{
__host__
__device__
constexpr
Tuple
()
=
default
;
__host__
__device__
static
constexpr
index_t
Size
()
{
return
0
;
}
template
<
typename
T
>
__host__
__device__
constexpr
auto
operator
=
(
const
T
&
)
{
return
*
this
;
}
__host__
__device__
static
constexpr
bool
IsStaticBuffer
()
{
return
true
;
}
};
template
<
index_t
I
,
typename
TTuple
>
struct
tuple_element
{
using
type
=
decltype
(
TTuple
{}.
At
(
Number
<
I
>
{}));
};
template
<
index_t
I
,
typename
TTuple
>
using
tuple_element_t
=
typename
tuple_element
<
I
,
TTuple
>::
type
;
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_tuple
(
Xs
&&
...
xs
)
{
...
...
Prev
1
…
8
9
10
11
12
13
14
15
16
…
30
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment