"examples/community/composable_stable_diffusion.py" did not exist on "f7ebe56921f69c05a9273dae4755490a9c51ce12"
Commit 6f26696f authored by Adam Osewski's avatar Adam Osewski
Browse files

Introduce int4 data type.

parent bac7df8f
...@@ -21,6 +21,14 @@ rocm_setup_version(VERSION 0.2.0) ...@@ -21,6 +21,14 @@ rocm_setup_version(VERSION 0.2.0)
include(TargetFlags) include(TargetFlags)
list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip) list(APPEND CMAKE_PREFIX_PATH ${CMAKE_INSTALL_PREFIX} ${CMAKE_INSTALL_PREFIX}/llvm ${CMAKE_INSTALL_PREFIX}/hip /opt/rocm /opt/rocm/llvm /opt/rocm/hip)
option(USE_BITINT_EXTENSION_INT4, "Whether to enable clang's BitInt extension to provide int4 data type." OFF)
if(USE_BITINT_EXTENSION_INT4)
add_compile_definitions(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
add_compile_options(-Wno-bit-int-extension)
message("CK compiled with USE_BITINT_EXTENSION_INT4 set to ${USE_BITINT_EXTENSION_INT4}")
endif()
## C++ ## C++
enable_language(CXX) enable_language(CXX)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
......
...@@ -62,6 +62,14 @@ struct PassThrough ...@@ -62,6 +62,14 @@ struct PassThrough
{ {
y = type_convert<int8_t>(x); y = type_convert<int8_t>(x);
} }
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
__host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const
{
y = x;
}
#endif
}; };
struct UnaryConvert struct UnaryConvert
...@@ -111,9 +119,13 @@ struct UnarySquare ...@@ -111,9 +119,13 @@ struct UnarySquare
template <typename T> template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const __host__ __device__ void operator()(T& y, const T& x) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value, static_assert(is_same_v<T, float> || is_same_v<T, double> || is_same_v<T, int32_t> ||
is_same_v<T, int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| is_same_v<T, int4_t>
#endif
,
"Data type is not supported by this operation!"); "Data type is not supported by this operation!");
y = x * x; y = x * x;
}; };
}; };
......
...@@ -9,6 +9,9 @@ namespace ck { ...@@ -9,6 +9,9 @@ namespace ck {
using bhalf_t = ushort; using bhalf_t = ushort;
using half_t = _Float16; using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4);
#endif
// vector_type // vector_type
template <typename T, index_t N> template <typename T, index_t N>
...@@ -130,6 +133,15 @@ struct scalar_type<int8_t> ...@@ -130,6 +133,15 @@ struct scalar_type<int8_t>
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct scalar_type<int4_t>
{
using type = int4_t;
static constexpr index_t vector_size = 1;
};
#endif
// //
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
...@@ -1030,4 +1042,16 @@ struct NumericLimits<half_t> ...@@ -1030,4 +1042,16 @@ struct NumericLimits<half_t>
__host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); } __host__ __device__ static constexpr half_t QuietNaN() { return bit_cast<half_t>(binary_qnan); }
}; };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<int4_t>
{
__host__ __device__ static constexpr int4_t Min() { return int4_t(-7); }
__host__ __device__ static constexpr int4_t Max() { return int4_t(7); }
__host__ __device__ static constexpr int4_t Lowest() { return int4_t(-7); }
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
} // namespace ck } // namespace ck
...@@ -42,6 +42,14 @@ static inline __host__ half_t abs(half_t x) ...@@ -42,6 +42,14 @@ static inline __host__ half_t abs(half_t x)
return abs_x; return abs_x;
}; };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ int4_t abs(int4_t x)
{
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn;
}
#endif
static inline __host__ bool isnan(float x) { return std::isnan(x); }; static inline __host__ bool isnan(float x) { return std::isnan(x); };
static inline __host__ bool isnan(double x) { return std::isnan(x); }; static inline __host__ bool isnan(double x) { return std::isnan(x); };
...@@ -65,6 +73,14 @@ static inline __host__ bool isnan(half_t x) ...@@ -65,6 +73,14 @@ static inline __host__ bool isnan(half_t x)
return (xx & 0x7FFF) > 0x7C00; return (xx & 0x7FFF) > 0x7C00;
}; };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __host__ bool isnan(int4_t x)
{
(void)x;
return false;
};
#endif
static inline __host__ float sqrt(float x) { return std::sqrt(x); }; static inline __host__ float sqrt(float x) { return std::sqrt(x); };
static inline __host__ double sqrt(double x) { return std::sqrt(x); }; static inline __host__ double sqrt(double x) { return std::sqrt(x); };
...@@ -89,6 +105,15 @@ static inline __device__ int32_t abs(int32_t x) ...@@ -89,6 +105,15 @@ static inline __device__ int32_t abs(int32_t x)
return (x ^ sgn) - sgn; return (x ^ sgn) - sgn;
}; };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __device__ int4_t abs(int4_t x)
{
int4_t sgn = x >> (4 - 1);
return (x ^ sgn) - sgn;
};
#endif
static inline __device__ half_t abs(half_t x) { return ::__habs(x); }; static inline __device__ half_t abs(half_t x) { return ::__habs(x); };
static inline __device__ bool isnan(float x) { return ::isnan(x); }; static inline __device__ bool isnan(float x) { return ::isnan(x); };
...@@ -107,6 +132,14 @@ static inline __device__ bool isnan(int32_t x) ...@@ -107,6 +132,14 @@ static inline __device__ bool isnan(int32_t x)
return false; return false;
}; };
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
static inline __device__ bool isnan(int4_t x)
{
(void)x;
return false;
};
#endif
static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); }; static inline __device__ bool isnan(half_t x) { return ::__hisnan(x); };
static inline __device__ float sqrt(float x) { return ::sqrtf(x); }; static inline __device__ float sqrt(float x) { return ::sqrtf(x); };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment