Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6f26696f
Commit
6f26696f
authored
Aug 18, 2022
by
Adam Osewski
Browse files
Introduce int4 data type.
parent
bac7df8f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
79 additions
and
2 deletions
+79
-2
CMakeLists.txt
CMakeLists.txt
+8
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+14
-2
include/ck/utility/data_type.hpp
include/ck/utility/data_type.hpp
+24
-0
include/ck/utility/math_v2.hpp
include/ck/utility/math_v2.hpp
+33
-0
No files found.
CMakeLists.txt
View file @
6f26696f
...
@@ -21,6 +21,14 @@ rocm_setup_version(VERSION 0.2.0)
...
@@ -21,6 +21,14 @@ rocm_setup_version(VERSION 0.2.0)
include
(
TargetFlags
)
include
(
TargetFlags
)
list
(
APPEND CMAKE_PREFIX_PATH
${
CMAKE_INSTALL_PREFIX
}
${
CMAKE_INSTALL_PREFIX
}
/llvm
${
CMAKE_INSTALL_PREFIX
}
/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip
)
list
(
APPEND CMAKE_PREFIX_PATH
${
CMAKE_INSTALL_PREFIX
}
${
CMAKE_INSTALL_PREFIX
}
/llvm
${
CMAKE_INSTALL_PREFIX
}
/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip
)
option
(
USE_BITINT_EXTENSION_INT4,
"Whether to enable clang's BitInt extension to provide int4 data type."
OFF
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_compile_definitions
(
CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
)
add_compile_options
(
-Wno-bit-int-extension
)
message
(
"CK compiled with USE_BITINT_EXTENSION_INT4 set to
${
USE_BITINT_EXTENSION_INT4
}
"
)
endif
()
## C++
## C++
enable_language
(
CXX
)
enable_language
(
CXX
)
set
(
CMAKE_CXX_STANDARD 17
)
set
(
CMAKE_CXX_STANDARD 17
)
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
6f26696f
...
@@ -62,6 +62,14 @@ struct PassThrough
...
@@ -62,6 +62,14 @@ struct PassThrough
{
{
y
=
type_convert
<
int8_t
>
(
x
);
y
=
type_convert
<
int8_t
>
(
x
);
}
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
__host__
__device__
void
operator
()
<
int4_t
,
int4_t
>
(
int4_t
&
y
,
const
int4_t
&
x
)
const
{
y
=
x
;
}
#endif
};
};
struct
UnaryConvert
struct
UnaryConvert
...
@@ -111,9 +119,13 @@ struct UnarySquare
...
@@ -111,9 +119,13 @@ struct UnarySquare
template
<
typename
T
>
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
,
static_assert
(
is_same_v
<
T
,
float
>
||
is_same_v
<
T
,
double
>
||
is_same_v
<
T
,
int32_t
>
||
is_same_v
<
T
,
int8_t
>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
||
is_same_v
<
T
,
int4_t
>
#endif
,
"Data type is not supported by this operation!"
);
"Data type is not supported by this operation!"
);
y
=
x
*
x
;
y
=
x
*
x
;
};
};
};
};
...
...
include/ck/utility/data_type.hpp
View file @
6f26696f
...
@@ -9,6 +9,9 @@ namespace ck {
...
@@ -9,6 +9,9 @@ namespace ck {
using
bhalf_t
=
ushort
;
using
bhalf_t
=
ushort
;
using
half_t
=
_Float16
;
using
half_t
=
_Float16
;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using
int4_t
=
_BitInt
(
4
);
#endif
// vector_type
// vector_type
template
<
typename
T
,
index_t
N
>
template
<
typename
T
,
index_t
N
>
...
@@ -130,6 +133,15 @@ struct scalar_type<int8_t>
...
@@ -130,6 +133,15 @@ struct scalar_type<int8_t>
static
constexpr
index_t
vector_size
=
1
;
static
constexpr
index_t
vector_size
=
1
;
};
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
struct
scalar_type
<
int4_t
>
{
using
type
=
int4_t
;
static
constexpr
index_t
vector_size
=
1
;
};
#endif
//
//
template
<
typename
T
>
template
<
typename
T
>
struct
vector_type
<
T
,
1
>
struct
vector_type
<
T
,
1
>
...
@@ -1030,4 +1042,16 @@ struct NumericLimits<half_t>
...
@@ -1030,4 +1042,16 @@ struct NumericLimits<half_t>
__host__
__device__
static
constexpr
half_t
QuietNaN
()
{
return
bit_cast
<
half_t
>
(
binary_qnan
);
}
__host__
__device__
static
constexpr
half_t
QuietNaN
()
{
return
bit_cast
<
half_t
>
(
binary_qnan
);
}
};
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
struct
NumericLimits
<
int4_t
>
{
__host__
__device__
static
constexpr
int4_t
Min
()
{
return
int4_t
(
-
7
);
}
__host__
__device__
static
constexpr
int4_t
Max
()
{
return
int4_t
(
7
);
}
__host__
__device__
static
constexpr
int4_t
Lowest
()
{
return
int4_t
(
-
7
);
}
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
}
// namespace ck
}
// namespace ck
include/ck/utility/math_v2.hpp
View file @
6f26696f
...
@@ -42,6 +42,14 @@ static inline __host__ half_t abs(half_t x)
...
@@ -42,6 +42,14 @@ static inline __host__ half_t abs(half_t x)
return
abs_x
;
return
abs_x
;
};
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__host__
int4_t
abs
(
int4_t
x
)
{
int4_t
sgn
=
x
>>
(
4
-
1
);
return
(
x
^
sgn
)
-
sgn
;
}
#endif
static
inline
__host__
bool
isnan
(
float
x
)
{
return
std
::
isnan
(
x
);
};
static
inline
__host__
bool
isnan
(
float
x
)
{
return
std
::
isnan
(
x
);
};
static
inline
__host__
bool
isnan
(
double
x
)
{
return
std
::
isnan
(
x
);
};
static
inline
__host__
bool
isnan
(
double
x
)
{
return
std
::
isnan
(
x
);
};
...
@@ -65,6 +73,14 @@ static inline __host__ bool isnan(half_t x)
...
@@ -65,6 +73,14 @@ static inline __host__ bool isnan(half_t x)
return
(
xx
&
0x7FFF
)
>
0x7C00
;
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__host__
bool
isnan
(
int4_t
x
)
{
(
void
)
x
;
return
false
;
};
#endif
static
inline
__host__
float
sqrt
(
float
x
)
{
return
std
::
sqrt
(
x
);
};
static
inline
__host__
float
sqrt
(
float
x
)
{
return
std
::
sqrt
(
x
);
};
static
inline
__host__
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
static
inline
__host__
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
...
@@ -89,6 +105,15 @@ static inline __device__ int32_t abs(int32_t x)
...
@@ -89,6 +105,15 @@ static inline __device__ int32_t abs(int32_t x)
return
(
x
^
sgn
)
-
sgn
;
return
(
x
^
sgn
)
-
sgn
;
};
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__device__
int4_t
abs
(
int4_t
x
)
{
int4_t
sgn
=
x
>>
(
4
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
#endif
static
inline
__device__
half_t
abs
(
half_t
x
)
{
return
::
__habs
(
x
);
};
static
inline
__device__
half_t
abs
(
half_t
x
)
{
return
::
__habs
(
x
);
};
static
inline
__device__
bool
isnan
(
float
x
)
{
return
::
isnan
(
x
);
};
static
inline
__device__
bool
isnan
(
float
x
)
{
return
::
isnan
(
x
);
};
...
@@ -107,6 +132,14 @@ static inline __device__ bool isnan(int32_t x)
...
@@ -107,6 +132,14 @@ static inline __device__ bool isnan(int32_t x)
return
false
;
return
false
;
};
};
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static
inline
__device__
bool
isnan
(
int4_t
x
)
{
(
void
)
x
;
return
false
;
};
#endif
static
inline
__device__
bool
isnan
(
half_t
x
)
{
return
::
__hisnan
(
x
);
};
static
inline
__device__
bool
isnan
(
half_t
x
)
{
return
::
__hisnan
(
x
);
};
static
inline
__device__
float
sqrt
(
float
x
)
{
return
::
sqrtf
(
x
);
};
static
inline
__device__
float
sqrt
(
float
x
)
{
return
::
sqrtf
(
x
);
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment