Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
546a764e
Commit
546a764e
authored
Oct 24, 2023
by
Artur Wojcik
Browse files
Merge branch 'migraphx' into uif2-migraphx
parents
8da3dfff
57cdd70b
Changes
47
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
581 additions
and
48 deletions
+581
-48
include/ck/utility/container_helper.hpp
include/ck/utility/container_helper.hpp
+2
-2
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+100
-9
include/ck/utility/debug.hpp
include/ck/utility/debug.hpp
+1
-1
include/ck/utility/enable_if.hpp
include/ck/utility/enable_if.hpp
+16
-1
include/ck/utility/functional.hpp
include/ck/utility/functional.hpp
+2
-2
include/ck/utility/functional4.hpp
include/ck/utility/functional4.hpp
+5
-5
include/ck/utility/integral_constant.hpp
include/ck/utility/integral_constant.hpp
+5
-0
include/ck/utility/magic_division.hpp
include/ck/utility/magic_division.hpp
+2
-0
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+2
-1
include/ck/utility/random_gen.hpp
include/ck/utility/random_gen.hpp
+7
-6
include/ck/utility/tuple.hpp
include/ck/utility/tuple.hpp
+7
-7
include/ck/utility/tuple_helper.hpp
include/ck/utility/tuple_helper.hpp
+1
-1
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+133
-7
include/ck/utility/type_convert.hpp
include/ck/utility/type_convert.hpp
+3
-3
library/CMakeLists.txt
library/CMakeLists.txt
+6
-3
library/src/jit_library/CMakeLists.txt
library/src/jit_library/CMakeLists.txt
+48
-0
library/src/jit_library/include/ck/host/common.hpp
library/src/jit_library/include/ck/host/common.hpp
+36
-0
library/src/jit_library/include/ck/host/device_batched_gemm_softmax_gemm.hpp
...rary/include/ck/host/device_batched_gemm_softmax_gemm.hpp
+110
-0
library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp
...rc/jit_library/include/ck/host/device_gemm_multiple_d.hpp
+59
-0
library/src/jit_library/src/common.cpp
library/src/jit_library/src/common.cpp
+36
-0
No files found.
include/ck/utility/container_helper.hpp
View file @
546a764e
...
...
@@ -326,14 +326,14 @@ template <typename T, index_t NX, index_t NY>
__host__
__device__
constexpr
auto
container_concat
(
const
Array
<
T
,
NX
>&
ax
,
const
Array
<
T
,
NY
>&
ay
)
{
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
make_array
(
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
ax
,
ay
);
[
&
](
auto
&&
...
zs
)
{
return
make_array
(
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
ax
,
ay
);
}
template
<
typename
...
X
,
typename
...
Y
>
__host__
__device__
constexpr
auto
container_concat
(
const
Tuple
<
X
...
>&
tx
,
const
Tuple
<
Y
...
>&
ty
)
{
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
make_tuple
(
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
tx
,
ty
);
[
&
](
auto
&&
...
zs
)
{
return
make_tuple
(
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...);
},
tx
,
ty
);
}
template
<
typename
Container
>
...
...
include/ck/utility/data_type.hpp
View file @
546a764e
...
...
@@ -5,7 +5,22 @@
#include "ck/utility/statically_indexed_array.hpp"
#ifdef __HIPCC_RTC__
/// Definitions from <cstdint>, <cmath> conflict with
/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h.
using
int8_t
=
signed
char
;
using
uint8_t
=
unsigned
char
;
using
int16_t
=
signed
short
;
using
uint16_t
=
unsigned
short
;
using
float_t
=
float
;
#endif // __HIPCC_RTC__
namespace
ck
{
#ifdef __HIPCC_RTC__
using
byte
=
unsigned
char
;
#else
using
std
::
byte
;
#endif
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
...
...
@@ -974,20 +989,96 @@ using bf8x32_t = typename vector_type<bf8_t, 32>::type;
using
bf8x64_t
=
typename
vector_type
<
bf8_t
,
64
>::
type
;
template
<
typename
T
>
struct
NumericLimits
struct
NumericLimits
;
template
<
>
struct
NumericLimits
<
int32_t
>
{
__host__
__device__
static
constexpr
T
Min
()
{
return
std
::
numeric_limits
<
T
>::
min
()
;
}
__host__
__device__
static
constexpr
int32_t
Lowest
()
noexcept
{
return
-
2147483647
-
1
;
}
__host__
__device__
static
constexpr
T
Max
()
{
return
std
::
numeric_limits
<
T
>::
max
()
;
}
__host__
__device__
static
constexpr
int32_t
Min
()
noexcept
{
return
-
2147483647
-
1
;
}
__host__
__device__
static
constexpr
T
Lowest
()
{
return
std
::
numeric_limits
<
T
>::
lowest
()
;
}
__host__
__device__
static
constexpr
int32_t
Max
()
noexcept
{
return
2147483647
;
}
__host__
__device__
static
constexpr
T
QuietNaN
()
{
return
std
::
numeric_limits
<
T
>::
quiet_NaN
();
}
__host__
__device__
static
constexpr
int32_t
Infinity
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
int32_t
QuietNaN
()
{
return
0
;
}
};
template
<
>
struct
NumericLimits
<
int16_t
>
{
__host__
__device__
static
constexpr
int16_t
Lowest
()
noexcept
{
return
-
32768
;
}
__host__
__device__
static
constexpr
int16_t
Min
()
noexcept
{
return
-
32768
;
}
__host__
__device__
static
constexpr
int16_t
Max
()
noexcept
{
return
32767
;
}
__host__
__device__
static
constexpr
int16_t
Infinity
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
int16_t
QuietNaN
()
{
return
0
;
}
};
template
<
>
struct
NumericLimits
<
int8_t
>
{
__host__
__device__
static
constexpr
int8_t
Lowest
()
noexcept
{
return
-
128
;
}
__host__
__device__
static
constexpr
int8_t
Min
()
noexcept
{
return
-
128
;
}
__host__
__device__
static
constexpr
int8_t
Max
()
noexcept
{
return
127
;
}
__host__
__device__
static
constexpr
int8_t
Infinity
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
int8_t
QuietNaN
()
{
return
0
;
}
};
template
<
>
struct
NumericLimits
<
uint32_t
>
{
__host__
__device__
static
constexpr
uint32_t
Lowest
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint32_t
Min
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint32_t
Max
()
noexcept
{
return
4294967295U
;
}
__host__
__device__
static
constexpr
uint32_t
Infinity
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint32_t
QuietNaN
()
{
return
0
;
}
};
template
<
>
struct
NumericLimits
<
uint16_t
>
{
__host__
__device__
static
constexpr
uint16_t
Lowest
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint16_t
Min
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint16_t
Max
()
noexcept
{
return
65535U
;
}
__host__
__device__
static
constexpr
uint16_t
Infinity
()
noexcept
{
return
0
;
}
__host__
__device__
static
constexpr
uint16_t
QuietNaN
()
{
return
0
;
}
};
template
<
>
struct
NumericLimits
<
float
>
{
static
constexpr
unsigned
int
binary_min
=
0x00800000
;
static
constexpr
unsigned
int
binary_max
=
0x7F7FFFFF
;
static
constexpr
unsigned
int
binary_lowest
=
0xFF7FFFFF
;
static
constexpr
unsigned
int
binary_qnan
=
0xFFC00001
;
static
constexpr
unsigned
int
binary_inf
=
0x7F8000000
;
__host__
__device__
static
constexpr
float
Min
()
{
return
bit_cast
<
float
>
(
binary_min
);
}
__host__
__device__
static
constexpr
float
Max
()
{
return
bit_cast
<
float
>
(
binary_max
);
}
__host__
__device__
static
constexpr
float
Lowest
()
{
return
bit_cast
<
float
>
(
binary_lowest
);
}
__host__
__device__
static
constexpr
float
QuietNaN
()
{
return
bit_cast
<
float
>
(
binary_qnan
);
}
__host__
__device__
static
constexpr
T
Infinity
()
{
return
std
::
numeric_limits
<
T
>::
infinity
(
);
}
__host__
__device__
static
constexpr
float
Infinity
()
{
return
bit_cast
<
float
>
(
binary_inf
);
}
};
template
<
>
...
...
include/ck/utility/debug.hpp
View file @
546a764e
...
...
@@ -3,7 +3,7 @@
#ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP
#include "type.hpp"
namespace
ck
{
namespace
debug
{
...
...
include/ck/utility/enable_if.hpp
View file @
546a764e
...
...
@@ -4,11 +4,26 @@
#pragma once
namespace
ck
{
#ifdef __HIPCC_RTC__
template
<
bool
B
,
class
T
=
void
>
struct
enable_if
{
};
template
<
class
T
>
struct
enable_if
<
true
,
T
>
{
using
type
=
T
;
};
template
<
bool
B
,
class
T
=
void
>
using
enable_if_t
=
typename
enable_if
<
B
,
T
>::
type
;
#else
template
<
bool
B
,
typename
T
=
void
>
using
enable_if
=
std
::
enable_if
<
B
,
T
>
;
template
<
bool
B
,
typename
T
=
void
>
using
enable_if_t
=
typename
std
::
enable_if
<
B
,
T
>::
type
;
#endif
}
// namespace ck
include/ck/utility/functional.hpp
View file @
546a764e
...
...
@@ -120,11 +120,11 @@ constexpr auto conditional_expr(X&& x, Y&& y)
{
if
constexpr
(
predicate
)
{
return
std
::
forward
<
X
>
(
x
);
return
ck
::
forward
<
X
>
(
x
);
}
else
{
return
std
::
forward
<
Y
>
(
y
);
return
ck
::
forward
<
Y
>
(
y
);
}
}
...
...
include/ck/utility/functional4.hpp
View file @
546a764e
...
...
@@ -21,7 +21,7 @@ struct unpack_impl<Sequence<Is...>>
template
<
typename
F
,
typename
X
>
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&&
x
)
const
{
return
std
::
forward
<
F
>
(
f
)(
std
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...);
return
ck
::
forward
<
F
>
(
f
)(
ck
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...);
}
};
...
...
@@ -35,8 +35,8 @@ struct unpack2_impl<Sequence<Is...>, Sequence<Js...>>
template
<
typename
F
,
typename
X
,
typename
Y
>
__host__
__device__
constexpr
auto
operator
()(
F
&&
f
,
X
&&
x
,
Y
&&
y
)
const
{
return
std
::
forward
<
F
>
(
f
)(
std
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...,
std
::
forward
<
Y
>
(
y
).
At
(
Number
<
Js
>
{})...);
return
ck
::
forward
<
F
>
(
f
)(
ck
::
forward
<
X
>
(
x
).
At
(
Number
<
Is
>
{})...,
ck
::
forward
<
Y
>
(
y
).
At
(
Number
<
Js
>
{})...);
}
};
...
...
@@ -47,7 +47,7 @@ __host__ __device__ constexpr auto unpack(F&& f, X&& x)
{
using
X_
=
remove_reference_t
<
X
>
;
return
detail
::
unpack_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
>
{}(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
X
>
(
x
));
ck
::
forward
<
F
>
(
f
),
ck
::
forward
<
X
>
(
x
));
}
// TODO: properly implement unpack that takes any number of containers
...
...
@@ -58,7 +58,7 @@ __host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y)
using
Y_
=
remove_reference_t
<
Y
>
;
return
detail
::
unpack2_impl
<
typename
arithmetic_sequence_gen
<
0
,
X_
::
Size
(),
1
>::
type
,
typename
arithmetic_sequence_gen
<
0
,
Y_
::
Size
(),
1
>::
type
>
{}(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
X
>
(
x
),
std
::
forward
<
Y
>
(
y
));
ck
::
forward
<
F
>
(
f
),
ck
::
forward
<
X
>
(
x
),
ck
::
forward
<
Y
>
(
y
));
}
}
// namespace ck
...
...
include/ck/utility/integral_constant.hpp
View file @
546a764e
...
...
@@ -48,4 +48,9 @@ __host__ __device__ constexpr auto operator%(integral_constant<TX, X>, integral_
return
integral_constant
<
decltype
(
X
%
Y
),
X
%
Y
>
{};
}
template
<
bool
B
>
using
bool_constant
=
integral_constant
<
bool
,
B
>
;
using
true_type
=
bool_constant
<
true
>
;
using
false_type
=
bool_constant
<
false
>
;
}
// namespace ck
include/ck/utility/magic_division.hpp
View file @
546a764e
...
...
@@ -9,6 +9,8 @@
#include "type.hpp"
#include "tuple.hpp"
#define INT32_MAX 2147483647
namespace
ck
{
// magic number division
...
...
include/ck/utility/math_v2.hpp
View file @
546a764e
...
...
@@ -14,6 +14,7 @@
namespace
ck
{
namespace
math
{
#ifndef __HIPCC_RTC__
// math functions for the host, some are implemented by calling C++ std functions
static
inline
__host__
float
abs
(
float
x
)
{
return
std
::
abs
(
x
);
};
...
...
@@ -183,7 +184,7 @@ inline __host__ double expm1<double>(double x)
{
return
std
::
expm1
(
x
);
}
#endif
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
static
inline
__device__
float
abs
(
float
x
)
{
return
::
abs
(
x
);
};
...
...
include/ck/utility/random_gen.hpp
View file @
546a764e
...
...
@@ -2,12 +2,13 @@
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck/utility/ignore.hpp>
namespace
ck
{
// Pseudo random number generator
// version for fp32
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<
std
::
is_same
<
float
,
T
>{},
bool
>
=
false
>
template
<
typename
T
,
uint32_t
seed_t
,
ck
::
enable_if_t
<
std
::
is_same
<
float
,
T
>{},
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
uint32_t
x
=
*
(
reinterpret_cast
<
uint32_t
*>
(
&
val
));
...
...
@@ -23,7 +24,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
}
// version for fp16
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<
std
::
is_same
<
half_t
,
T
>{},
bool
>
=
false
>
template
<
typename
T
,
uint32_t
seed_t
,
ck
::
enable_if_t
<
std
::
is_same
<
half_t
,
T
>{},
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
index_t
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
uint16_t
x
=
*
(
reinterpret_cast
<
uint16_t
*>
(
&
val
));
...
...
@@ -40,12 +41,12 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed =
// return 0 if data is not fp16 or fp32
template
<
typename
T
,
uint32_t
seed_t
,
std
::
enable_if_t
<!
(
std
::
is_same
<
float
,
T
>{}
||
std
::
is_same
<
half_t
,
T
>
{}),
bool
>
=
false
>
ck
::
enable_if_t
<!
(
std
::
is_same
<
float
,
T
>{}
||
std
::
is_same
<
half_t
,
T
>
{}),
bool
>
=
false
>
__host__
__device__
uint32_t
prand_generator
(
int
id
,
T
val
,
uint32_t
seed
=
seed_t
)
{
std
::
ignore
=
id
;
std
::
ignore
=
val
;
std
::
ignore
=
seed
;
ck
::
ignore
=
id
;
ck
::
ignore
=
val
;
ck
::
ignore
=
seed
;
return
0
;
}
...
...
include/ck/utility/tuple.hpp
View file @
546a764e
...
...
@@ -32,7 +32,7 @@ struct TupleElementKeyData
template
<
typename
T
,
typename
enable_if
<!
is_same
<
remove_cvref_t
<
T
>,
TupleElementKeyData
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleElementKeyData
(
T
&&
v
)
:
mData
(
std
::
forward
<
T
>
(
v
))
__host__
__device__
constexpr
TupleElementKeyData
(
T
&&
v
)
:
mData
(
ck
::
forward
<
T
>
(
v
))
{
}
...
...
@@ -67,7 +67,7 @@ get_tuple_element_data_reference(TupleElementKeyData<Key, Data>&& x)
template
<
typename
Key
,
typename
Data
>
__host__
__device__
constexpr
Data
get_tuple_element_data
(
const
TupleElementKeyData
<
Key
,
Data
>&
x
)
{
return
std
::
forward
(
x
.
mData
);
return
ck
::
forward
(
x
.
mData
);
}
template
<
typename
Indices
,
typename
...
Xs
>
...
...
@@ -83,13 +83,13 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElementKeyData<TupleElementKey<I
!
is_same
<
remove_cvref_t
<
Y
>,
TupleImpl
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleImpl
(
Y
&&
y
)
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Y
>
(
y
))...
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
ck
::
forward
<
Y
>
(
y
))...
{
}
template
<
typename
...
Ys
,
typename
enable_if
<
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
TupleImpl
(
Ys
&&
...
ys
)
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
std
::
forward
<
Ys
>
(
ys
))...
:
TupleElementKeyData
<
TupleElementKey
<
Is
>
,
Xs
>
(
ck
::
forward
<
Ys
>
(
ys
))...
{
static_assert
(
sizeof
...(
Is
)
==
sizeof
...(
Xs
)
&&
sizeof
...(
Is
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
...
...
@@ -123,14 +123,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
template
<
typename
Y
,
typename
enable_if
<
sizeof
...(
Xs
)
==
1
&&
!
is_same
<
remove_cvref_t
<
Y
>,
Tuple
>::
value
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
Tuple
(
Y
&&
y
)
:
base
(
std
::
forward
<
Y
>
(
y
))
__host__
__device__
constexpr
Tuple
(
Y
&&
y
)
:
base
(
ck
::
forward
<
Y
>
(
y
))
{
}
template
<
typename
...
Ys
,
typename
enable_if
<
sizeof
...(
Ys
)
==
sizeof
...(
Xs
)
&&
sizeof
...(
Ys
)
>
=
2
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
Tuple
(
Ys
&&
...
ys
)
:
base
(
std
::
forward
<
Ys
>
(
ys
)...)
__host__
__device__
constexpr
Tuple
(
Ys
&&
...
ys
)
:
base
(
ck
::
forward
<
Ys
>
(
ys
)...)
{
}
...
...
@@ -210,7 +210,7 @@ using tuple_element_t = typename tuple_element<I, TTuple>::type;
template
<
typename
...
Xs
>
__host__
__device__
constexpr
auto
make_tuple
(
Xs
&&
...
xs
)
{
return
Tuple
<
remove_cvref_t
<
Xs
>
...
>
(
std
::
forward
<
Xs
>
(
xs
)...);
return
Tuple
<
remove_cvref_t
<
Xs
>
...
>
(
ck
::
forward
<
Xs
>
(
xs
)...);
}
// https://en.cppreference.com/w/cpp/utility/tuple/tie
...
...
include/ck/utility/tuple_helper.hpp
View file @
546a764e
...
...
@@ -28,7 +28,7 @@ __host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>&
const
Tuple
<
Y
&
...
>&
ty
)
{
return
unpack2
(
[
&
](
auto
&&
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
std
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
[
&
](
auto
&&
...
zs
)
{
return
Tuple
<
decltype
(
zs
)...
>
{
ck
::
forward
<
decltype
(
zs
)
>
(
zs
)...};
},
tx
,
ty
);
}
...
...
include/ck/utility/type.hpp
View file @
546a764e
...
...
@@ -4,10 +4,122 @@
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/integral_constant.hpp"
#include "ck/utility/enable_if.hpp"
#include "ck/utility/integral_constant.hpp"
namespace
ck
{
#ifdef __HIPCC_RTC__
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
struct name : bool_constant<__##name(T)> \
{ \
}
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAIT2(name) \
template <class T, class U> \
struct name : bool_constant<__##name(T, U)> \
{ \
}
// NOLINTNEXTLINE
#define CK_BUILTIN_TYPE_TRAITN(name) \
template <class... Ts> \
struct name : bool_constant<__##name(Ts...)> \
{ \
}
CK_BUILTIN_TYPE_TRAIT1
(
is_class
);
CK_BUILTIN_TYPE_TRAIT1
(
is_pointer
);
CK_BUILTIN_TYPE_TRAIT1
(
is_reference
);
CK_BUILTIN_TYPE_TRAIT1
(
is_trivially_copyable
);
CK_BUILTIN_TYPE_TRAIT1
(
is_unsigned
);
CK_BUILTIN_TYPE_TRAIT2
(
is_base_of
);
template
<
class
T
>
struct
remove_cv
{
using
type
=
T
;
};
template
<
class
T
>
struct
remove_cv
<
const
T
>
:
remove_cv
<
T
>
{
};
template
<
class
T
>
struct
remove_cv
<
volatile
T
>
:
remove_cv
<
T
>
{
};
template
<
class
T
>
struct
remove_reference
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_reference
<
T
&>
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_reference
<
T
&&>
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_pointer
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_pointer
<
T
*>
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_pointer
<
T
*
const
>
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_pointer
<
T
*
volatile
>
{
typedef
T
type
;
};
template
<
class
T
>
struct
remove_pointer
<
T
*
const
volatile
>
{
typedef
T
type
;
};
template
<
typename
T
>
constexpr
T
&&
forward
(
typename
remove_reference
<
T
>::
type
&
t_
)
noexcept
{
return
static_cast
<
T
&&>
(
t_
);
}
template
<
typename
T
>
constexpr
T
&&
forward
(
typename
remove_reference
<
T
>::
type
&&
t_
)
noexcept
{
return
static_cast
<
T
&&>
(
t_
);
}
#else
#include <utility>
#include <type_traits>
using
std
::
forward
;
using
std
::
is_base_of
;
using
std
::
is_class
;
using
std
::
is_pointer
;
using
std
::
is_reference
;
using
std
::
is_trivially_copyable
;
using
std
::
is_unsigned
;
using
std
::
remove_cv
;
using
std
::
remove_pointer
;
using
std
::
remove_reference
;
#endif
template
<
typename
X
,
typename
Y
>
struct
is_same
:
public
integral_constant
<
bool
,
false
>
...
...
@@ -19,25 +131,39 @@ struct is_same<X, X> : public integral_constant<bool, true>
{
};
template
<
typename
T
>
inline
constexpr
bool
is_reference_v
=
is_reference
<
T
>::
value
;
template
<
typename
X
,
typename
Y
>
inline
constexpr
bool
is_same_v
=
is_same
<
X
,
Y
>::
value
;
template
<
typename
X
,
typename
Y
>
inline
constexpr
bool
is_base_of_v
=
is_base_of
<
X
,
Y
>::
value
;
template
<
typename
T
>
inline
constexpr
bool
is_unsigned_v
=
is_unsigned
<
T
>::
value
;
template
<
typename
T
>
using
remove_reference_t
=
typename
remove_reference
<
T
>::
type
;
template
<
typename
T
>
using
remove_reference_t
=
typename
std
::
remove_reference
<
T
>::
type
;
using
remove_reference_t
=
typename
remove_reference
<
T
>::
type
;
template
<
typename
T
>
using
remove_cv_t
=
typename
std
::
remove_cv
<
T
>::
type
;
using
remove_cv_t
=
typename
remove_cv
<
T
>::
type
;
template
<
typename
T
>
using
remove_cvref_t
=
remove_cv_t
<
std
::
remove_reference_t
<
T
>>
;
using
remove_cvref_t
=
remove_cv_t
<
remove_reference_t
<
T
>>
;
template
<
typename
T
>
using
remove_pointer_t
=
typename
std
::
remove_pointer
<
T
>::
type
;
using
remove_pointer_t
=
typename
remove_pointer
<
T
>::
type
;
template
<
typename
T
>
inline
constexpr
bool
is_pointer_v
=
std
::
is_pointer
<
T
>::
value
;
inline
constexpr
bool
is_pointer_v
=
is_pointer
<
T
>::
value
;
template
<
typename
Y
,
typename
X
,
typename
enable_if
<
sizeof
(
X
)
==
sizeof
(
Y
),
bool
>
::
type
=
false
>
template
<
typename
Y
,
typename
X
,
typename
ck
::
enable_if
<
sizeof
(
X
)
==
sizeof
(
Y
),
bool
>
::
type
=
false
>
__host__
__device__
constexpr
Y
bit_cast
(
const
X
&
x
)
{
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
...
...
include/ck/utility/type_convert.hpp
View file @
546a764e
...
...
@@ -15,7 +15,7 @@ template <typename Y,
std
::
enable_if_t
<!
(
std
::
is_const_v
<
Y
>
||
std
::
is_const_v
<
X
>
),
bool
>
=
false
>
__host__
__device__
constexpr
Y
type_convert
(
X
x
)
{
static_assert
(
!
std
::
is_reference_v
<
Y
>
&&
!
std
::
is_reference_v
<
X
>
);
static_assert
(
!
ck
::
is_reference_v
<
Y
>
&&
!
ck
::
is_reference_v
<
X
>
);
return
static_cast
<
Y
>
(
x
);
}
...
...
@@ -356,7 +356,7 @@ template <>
inline
__host__
__device__
bf8_t
f8_convert_sr
<
bf8_t
,
float
>
(
float
x
)
{
constexpr
int
seed
=
42
;
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
uintptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
float
,
seed
>
(
reinterpret_cast
<
size
_t
>
(
&
x
),
x
);
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
union
{
...
...
@@ -392,7 +392,7 @@ inline __host__ __device__ bf8_t f8_convert_sr<bf8_t, half_t>(half_t x)
constexpr
f8_rounding_mode
rm
=
f8_rounding_mode
::
stochastic
;
constexpr
int
seed
=
42
;
// as thread id is not available on host, use 0 for prn generation
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
uintptr
_t
>
(
&
x
),
x
);
uint32_t
rng
=
prand_generator
<
half_t
,
seed
>
(
reinterpret_cast
<
size
_t
>
(
&
x
),
x
);
return
utils
::
cast_to_f8
<
half_t
,
bf8_t
,
negative_zero_nan
,
clip
,
(
rm
==
f8_rounding_mode
::
stochastic
)
>
(
x
,
rng
);
...
...
library/CMakeLists.txt
View file @
546a764e
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
add_subdirectory
(
src/tensor_operation_instance/gpu
)
add_subdirectory
(
src/utility
)
if
(
CK_BUILD_JIT_LIB
)
add_subdirectory
(
src/jit_library
)
else
()
add_subdirectory
(
src/tensor_operation_instance/gpu
)
add_subdirectory
(
src/utility
)
endif
()
library/src/jit_library/CMakeLists.txt
0 → 100644
View file @
546a764e
include
(
Embed
)
file
(
GLOB_RECURSE KERNEL_FILES CONFIGURE_DEPENDS
${
PROJECT_SOURCE_DIR
}
/include/ck/*.hpp
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
message
(
STATUS
"RELATIVE:
${
PROJECT_SOURCE_DIR
}
/include"
)
add_embed_library
(
ck_headers
${
KERNEL_FILES
}
RELATIVE
${
PROJECT_SOURCE_DIR
}
/include
)
execute_process
(
COMMAND python3
${
CMAKE_CURRENT_SOURCE_DIR
}
/util/make_instance_strings.py
${
PROJECT_SOURCE_DIR
}
/library/src/tensor_operation_instance/gpu
${
CMAKE_CURRENT_BINARY_DIR
}
/solution_instances
WORKING_DIRECTORY
${
CMAKE_CURRENT_SOURCE_DIR
}
/../tensor_operation_instance/gpu/
)
add_library
(
jit_library STATIC
src/device_batched_gemm_softmax_gemm.cpp
src/device_gemm_multiple_d.cpp
src/common.cpp
)
add_library
(
composable_kernel::jit_library ALIAS jit_library
)
set_target_properties
(
jit_library PROPERTIES LINKER_LANGUAGE CXX
)
target_include_directories
(
jit_library SYSTEM PRIVATE
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_SOURCE_DIR
}
/include>
$<BUILD_INTERFACE:
${
PROJECT_SOURCE_DIR
}
/library/src/jit_library/solution_instances>
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_BINARY_DIR
}
/solution_instances>
$<BUILD_INTERFACE:
${
CMAKE_CURRENT_BINARY_DIR
}
/embed/ck_headers/include>
)
target_link_libraries
(
jit_library PRIVATE $<BUILD_INTERFACE:ck_headers>
)
rocm_install
(
TARGETS jit_library
EXPORT jit_libraryTargets
)
rocm_install
(
DIRECTORY include/ck DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
)
rocm_install
(
DIRECTORY
${
PROJECT_SOURCE_DIR
}
/include/ck DESTINATION
${
CMAKE_INSTALL_INCLUDEDIR
}
)
rocm_install
(
EXPORT jit_libraryTargets
FILE composable_kerneljit_libraryTargets.cmake
NAMESPACE composable_kernel::
DESTINATION
${
CMAKE_INSTALL_LIBDIR
}
/cmake/composable_kernel
)
library/src/jit_library/include/ck/host/common.hpp
0 → 100644
View file @
546a764e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <string_view>
#include <utility>
#include <unordered_map>
namespace
ck
{
namespace
host
{
struct
Solution
{
std
::
string
template_str
;
std
::
size_t
block_size
;
std
::
size_t
grid_size
;
};
enum
class
DataType
{
Half
,
Float
,
Int8
,
Int32
};
std
::
string
ToString
(
DataType
dt
);
std
::
unordered_map
<
std
::
string_view
,
std
::
string_view
>
GetHeaders
();
std
::
size_t
integer_divide_ceil
(
std
::
size_t
x
,
std
::
size_t
y
);
}
// namespace host
}
// namespace ck
library/src/jit_library/include/ck/host/device_batched_gemm_softmax_gemm.hpp
0 → 100644
View file @
546a764e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
namespace
ck
{
namespace
host
{
namespace
device_batched_gemm_softmax_gemm
{
struct
Problem
{
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
std
::
size_t
O
=
0
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransB1
=
false
;
bool
TransC
=
false
;
DataType
ADataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
B1DataType
=
DataType
::
Half
;
DataType
CDataType
=
DataType
::
Half
;
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
B1ElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
AccElementOp
=
"ck::tensor_operation::element_wise::Scale"
;
std
::
string
GetIncludeHeader
()
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
private:
std
::
vector
<
std
::
string
>
GetInstances
(
const
std
::
string
&
arch
)
const
;
Solution
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
;
static
const
std
::
size_t
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle_idx
=
0
;
static
const
std
::
size_t
ALayout_idx
=
1
;
static
const
std
::
size_t
B0Layout_idx
=
2
;
static
const
std
::
size_t
B1Layout_idx
=
3
;
static
const
std
::
size_t
CLayout_idx
=
4
;
static
const
std
::
size_t
ADataType_idx
=
5
;
static
const
std
::
size_t
B0DataType_idx
=
6
;
static
const
std
::
size_t
B1DataType_idx
=
7
;
static
const
std
::
size_t
CDataType_idx
=
8
;
static
const
std
::
size_t
AccDataType_idx
=
9
;
static
const
std
::
size_t
CShuffleDataType_idx
=
10
;
static
const
std
::
size_t
AElementwiseOperation_idx
=
11
;
static
const
std
::
size_t
B0ElementwiseOperation_idx
=
12
;
static
const
std
::
size_t
Acc0ElementwiseOperation_idx
=
13
;
static
const
std
::
size_t
B1ElementwiseOperation_idx
=
14
;
static
const
std
::
size_t
CElementwiseOperation_idx
=
15
;
static
const
std
::
size_t
GEMMSpecialization_idx
=
16
;
static
const
std
::
size_t
NumGemmKPrefetchStage_idx
=
17
;
static
const
std
::
size_t
BlockSize_idx
=
18
;
static
const
std
::
size_t
Gemm01MPerBlock_idx
=
19
;
static
const
std
::
size_t
Gemm0NPerBlock_idx
=
20
;
static
const
std
::
size_t
Gemm0KPerBlock_idx
=
21
;
static
const
std
::
size_t
Gemm1NPerBlock_idx
=
22
;
static
const
std
::
size_t
Gemm1KPerBlock_idx
=
23
;
static
const
std
::
size_t
AK1_idx
=
24
;
static
const
std
::
size_t
BK1_idx
=
25
;
static
const
std
::
size_t
B1K1_idx
=
26
;
static
const
std
::
size_t
MPerXDL_idx
=
27
;
static
const
std
::
size_t
NPerXDL_idx
=
28
;
static
const
std
::
size_t
Gemm0MXdlPerWave_idx
=
29
;
static
const
std
::
size_t
Gemm0NXdlPerWave_idx
=
30
;
static
const
std
::
size_t
Gemm1NXdlPerWave_idx
=
31
;
static
const
std
::
size_t
ABlockTransferThreadClusterLengths_K0_M_K1_idx
=
32
;
static
const
std
::
size_t
ABlockTransferThreadClusterArrangeOrder_idx
=
33
;
static
const
std
::
size_t
ABlockTransferSrcAccessOrder_idx
=
34
;
static
const
std
::
size_t
ABlockTransferSrcVectorDim_idx
=
35
;
static
const
std
::
size_t
ABlockTransferSrcScalarPerVector_idx
=
36
;
static
const
std
::
size_t
ABlockTransferDstScalarPerVector_K1_idx
=
37
;
static
const
std
::
size_t
ABlockLdsAddExtraM_idx
=
38
;
static
const
std
::
size_t
B0BlockTransferThreadClusterLengths_K0_N_K1_idx
=
39
;
static
const
std
::
size_t
B0BlockTransferThreadClusterArrangeOrder_idx
=
40
;
static
const
std
::
size_t
B0BlockTransferSrcAccessOrder_idx
=
41
;
static
const
std
::
size_t
B0BlockTransferSrcVectorDim_idx
=
42
;
static
const
std
::
size_t
B0BlockTransferSrcScalarPerVector_idx
=
43
;
static
const
std
::
size_t
B0BlockTransferDstScalarPerVector_K1_idx
=
44
;
static
const
std
::
size_t
B0BlockLdsAddExtraN_idx
=
45
;
static
const
std
::
size_t
B1BlockTransferThreadClusterLengths_K0_N_K1_idx
=
46
;
static
const
std
::
size_t
B1BlockTransferThreadClusterArrangeOrder_idx
=
47
;
static
const
std
::
size_t
B1BlockTransferSrcAccessOrder_idx
=
48
;
static
const
std
::
size_t
B1BlockTransferSrcVectorDim_idx
=
49
;
static
const
std
::
size_t
B1BlockTransferSrcScalarPerVector_idx
=
50
;
static
const
std
::
size_t
B1BlockTransferDstScalarPerVector_K1_idx
=
51
;
static
const
std
::
size_t
B1BlockLdsAddExtraN_idx
=
52
;
static
const
std
::
size_t
CShuffleMXdlPerWavePerShuffle_idx
=
53
;
static
const
std
::
size_t
CShuffleNXdlPerWavePerShuffle_idx
=
54
;
static
const
std
::
size_t
CBlockTransferClusterLengths_MBlock_MWaveMPerXdl_NBlock_NWaveNPerXdl_idx
=
55
;
static
const
std
::
size_t
CBlockTransferScalarPerVector_NWaveNPerXdl_idx
=
56
;
static
const
std
::
size_t
MaskOutUpperTriangle_idx
=
57
;
};
}
// namespace device_batched_gemm_softmax_gemm
}
// namespace host
}
// namespace ck
library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp
0 → 100644
View file @
546a764e
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include <sstream>
#include <iterator>
#include <numeric>
#include "ck/host/common.hpp"
namespace
ck
{
namespace
host
{
namespace
device_gemm_multiple_d
{
struct
Problem
{
std
::
size_t
M
=
0
;
std
::
size_t
N
=
0
;
std
::
size_t
K
=
0
;
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransE
=
false
;
std
::
vector
<
bool
>
DsTrans
=
{};
DataType
ADataType
=
DataType
::
Half
;
DataType
BDataType
=
DataType
::
Half
;
DataType
EDataType
=
DataType
::
Half
;
std
::
vector
<
DataType
>
DsDataType
=
{};
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CDEElementOp
=
"ck::Tuple<>"
;
static
const
std
::
size_t
ds_layout_idx
=
3
;
static
const
std
::
size_t
ds_data_type_idx
=
9
;
static
const
std
::
size_t
e_data_type_idx
=
10
;
static
const
std
::
size_t
a_elementwise_op_idx
=
11
;
static
const
std
::
size_t
b_elementwise_op_idx
=
12
;
static
const
std
::
size_t
ds_elementwise_op_idx
=
13
;
static
const
std
::
size_t
gemm_spec_idx
=
14
;
static
const
std
::
size_t
block_size_idx
=
16
;
static
const
std
::
size_t
m_per_block_idx
=
17
;
static
const
std
::
size_t
n_per_block_idx
=
18
;
static
const
std
::
size_t
k_per_block_idx
=
19
;
std
::
string
GetIncludeHeader
()
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
private:
std
::
vector
<
std
::
string
>
GetInstances
(
const
std
::
string
&
arch
)
const
;
Solution
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
;
};
}
// namespace device_gemm_multiple_d
}
// namespace host
}
// namespace ck
library/src/jit_library/src/common.cpp
0 → 100644
View file @
546a764e
#include "ck/host/common.hpp"
#include "ck_headers.hpp"
#include <stdexcept>
#include <algorithm>
namespace
ck
{
namespace
host
{
std
::
string
ToString
(
DataType
dt
)
{
switch
(
dt
)
{
case
DataType
::
Float
:
return
"float"
;
case
DataType
::
Half
:
return
"ck::half_t"
;
case
DataType
::
Int8
:
return
"int8_t"
;
case
DataType
::
Int32
:
return
"int32_t"
;
}
throw
std
::
runtime_error
(
"Incorrect data type"
);
}
std
::
unordered_map
<
std
::
string_view
,
std
::
string_view
>
GetHeaders
()
{
auto
headers
=
ck_headers
();
headers
.
insert
(
{
"ck/config.h"
,
""
});
return
headers
;
}
std
::
size_t
integer_divide_ceil
(
std
::
size_t
x
,
std
::
size_t
y
)
{
return
(
x
+
y
-
std
::
size_t
{
1
})
/
y
;
}
}
// namespace host
}
// namespace ck
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment