Commit 041ac4c9 authored by Jing Zhang's avatar Jing Zhang
Browse files

add pk_i4_t as a struct

parent 23f99eb4
......@@ -166,10 +166,7 @@ struct StaticTensorTupleOfVectorBuffer
// Get X
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
typename Idx>
__host__ __device__ constexpr X GetAsType(Idx) const
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
......@@ -200,10 +197,7 @@ struct StaticTensorTupleOfVectorBuffer
// Set X
// Idx is for S, not X. Idx should be aligned with X
template <typename X,
typename Idx,
typename enable_if<has_same_scalar_type<S, X>::value &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
typename Idx>
__host__ __device__ constexpr void SetAsType(Idx, X x)
{
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
......
......@@ -39,7 +39,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
{
#if 0
#if 1
uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
......@@ -118,7 +118,7 @@ struct PassThroughPack8
__host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
{
#if 0
#if 1
vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
......@@ -252,6 +252,12 @@ struct PassThrough final : public UnaryOpBase
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
{
......
......@@ -13,7 +13,14 @@ using half_t = _Float16;
using int4_t = _BitInt(4);
using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8);
using pk_i4_t = uint8_t;
//using pk_i4_t = uint8_t;
struct pk_i4_t
{
using type = int8_t;
type data;
__host__ __device__ constexpr pk_i4_t() : data{type{}} {}
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
};
inline constexpr auto next_pow2(uint32_t x)
{
......@@ -168,6 +175,13 @@ struct scalar_type<int4_t>
};
#endif
template <>
struct scalar_type<pk_i4_t>
{
using type = pk_i4_t;
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<f8_fnuz_t>
{
......@@ -1047,6 +1061,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using type = bf8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<pk_i4_t>
{
using type = pk_i4_t::type;
};
template <typename T, index_t N>
struct non_native_vector_base<
T,
......@@ -1166,6 +1186,14 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
static constexpr index_t vector_size = N;
};
template <index_t N>
struct scalar_type<non_native_vector_base<pk_i4_t, N>>
{
using type = typename non_native_vector_base<pk_i4_t, N>::data_t;
static constexpr index_t vector_size = N;
};
// non-native vector_type implementation
template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
......@@ -1868,6 +1896,7 @@ using bf8x64_t = bf8x64_fnuz_t;
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
// u8
// using uint8x2_t = typename vector_type<uint8_t, 2>::type;
......
......@@ -59,10 +59,7 @@ struct DynamicBuffer
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
template <typename X>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{
// X contains multiple T
......@@ -204,10 +201,7 @@ struct DynamicBuffer
element_space_size_ / PackedSize);
}
template <typename X,
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
template <typename X>
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{
// X contains multiple T
......
......@@ -115,8 +115,7 @@ struct StaticBufferTupleOfVector
// Get X
// i is offset of S, not X. i should be aligned to X
template <typename X,
index_t I,
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
index_t I>
__host__ __device__ constexpr auto GetAsType(Number<I> i) const
{
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
......@@ -133,8 +132,7 @@ struct StaticBufferTupleOfVector
// Set X
// i is offset of S, not X. i should be aligned to X
template <typename X,
index_t I,
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
index_t I>
__host__ __device__ constexpr void SetAsType(Number<I> i, X x)
{
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
......
......@@ -76,7 +76,7 @@ struct ReferenceGemm : public device::BaseOperator
}
else if constexpr(is_same_v<ADataType, pk_i4_t>)
{
pk_i4_t i4x2 = arg.a_m_k_(m, k);
uint8_t i4x2 = arg.a_m_k_(m, k).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
......@@ -97,7 +97,7 @@ struct ReferenceGemm : public device::BaseOperator
}
else if constexpr(is_same_v<BDataType, pk_i4_t>)
{
pk_i4_t i4x2 = arg.b_k_n_(k, n);
uint8_t i4x2 = arg.b_k_n_(k, n).data;
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment