"...composable_kernel.git" did not exist on "f46a6ffad83ab0245b8087df602f5fcfceb37ad2"
Commit 041ac4c9 authored by Jing Zhang
Browse files

add pk_i4_t as a struct

parent 23f99eb4
...@@ -166,10 +166,7 @@ struct StaticTensorTupleOfVectorBuffer ...@@ -166,10 +166,7 @@ struct StaticTensorTupleOfVectorBuffer
// Get X // Get X
// Idx is for S, not X. Idx should be aligned with X // Idx is for S, not X. Idx should be aligned with X
template <typename X, template <typename X,
typename Idx, typename Idx>
typename enable_if<has_same_scalar_type<S, X>::value &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr X GetAsType(Idx) const __host__ __device__ constexpr X GetAsType(Idx) const
{ {
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
...@@ -200,10 +197,7 @@ struct StaticTensorTupleOfVectorBuffer ...@@ -200,10 +197,7 @@ struct StaticTensorTupleOfVectorBuffer
// Set X // Set X
// Idx is for S, not X. Idx should be aligned with X // Idx is for S, not X. Idx should be aligned with X
template <typename X, template <typename X,
typename Idx, typename Idx>
typename enable_if<has_same_scalar_type<S, X>::value &&
is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
bool>::type = false>
__host__ __device__ constexpr void SetAsType(Idx, X x) __host__ __device__ constexpr void SetAsType(Idx, X x)
{ {
constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
......
...@@ -39,7 +39,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q) ...@@ -39,7 +39,7 @@ __host__ __device__ inline half4_t pki4_to_half4(int q)
__host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q) __host__ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
{ {
#if 0 #if 1
uint8_t x_u8 = ck::bit_cast<uint8_t>(q); uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4); uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
...@@ -118,7 +118,7 @@ struct PassThroughPack8 ...@@ -118,7 +118,7 @@ struct PassThroughPack8
__host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const
{ {
#if 0 #if 1
vector_type<half_t, 8> result; vector_type<half_t, 8> result;
result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x)); result.template AsType<half4_t>()(Number<0>{}) = pki4_to_half4(bit_cast<int>(x));
...@@ -252,6 +252,12 @@ struct PassThrough final : public UnaryOpBase ...@@ -252,6 +252,12 @@ struct PassThrough final : public UnaryOpBase
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const; __host__ __device__ void operator()(Y& y, const X& x) const;
// PassThrough specialization for pk_i4_t -> pk_i4_t: the output receives an
// unmodified copy of the input (no numeric conversion is performed).
//   y: destination element, overwritten.
//   x: source element, left untouched.
template <>
__host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const
{
// Plain copy-assignment of the packed value.
y = x;
}
template <> template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const __host__ __device__ void operator()<float, double>(float& y, const double& x) const
{ {
......
...@@ -13,7 +13,14 @@ using half_t = _Float16; ...@@ -13,7 +13,14 @@ using half_t = _Float16;
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
using f8_t = _BitInt(8); using f8_t = _BitInt(8);
using bf8_t = unsigned _BitInt(8); using bf8_t = unsigned _BitInt(8);
using pk_i4_t = uint8_t; //using pk_i4_t = uint8_t;
// Distinct wrapper type for packed int4 data, stored in a single int8_t.
// Because it is a struct rather than an alias of an integer type, traits and
// operators can be specialized on pk_i4_t itself.
// NOTE(review): users in this change extract 4-bit nibbles from `data`, which
// suggests two int4 values per byte -- confirm the nibble order at call sites.
struct pk_i4_t
{
    // Underlying storage type of the packed value.
    using type = int8_t;

    // Default construction zero-initializes the storage.
    __host__ __device__ constexpr pk_i4_t() : data(type(0)) {}

    // Wraps an already-packed raw value. The constructor is non-explicit, so a
    // raw `type` converts to pk_i4_t implicitly.
    __host__ __device__ constexpr pk_i4_t(type init) : data(init) {}

    // Raw packed bits.
    type data;
};
inline constexpr auto next_pow2(uint32_t x) inline constexpr auto next_pow2(uint32_t x)
{ {
...@@ -168,6 +175,13 @@ struct scalar_type<int4_t> ...@@ -168,6 +175,13 @@ struct scalar_type<int4_t>
}; };
#endif #endif
// scalar_type trait for pk_i4_t: the type is its own scalar type and is
// reported with a vector size of 1.
template <>
struct scalar_type<pk_i4_t>
{
    // A single pk_i4_t counts as one element.
    static constexpr index_t vector_size = 1;

    // pk_i4_t is treated as a scalar in its own right.
    using type = pk_i4_t;
};
template <> template <>
struct scalar_type<f8_fnuz_t> struct scalar_type<f8_fnuz_t>
{ {
...@@ -1047,6 +1061,12 @@ struct nnvb_data_t_selector<bf8_ocp_t> ...@@ -1047,6 +1061,12 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using type = bf8_ocp_t::data_type; using type = bf8_ocp_t::data_type;
}; };
// nnvb_data_t_selector specialization for pk_i4_t: selects the struct's
// underlying storage type (pk_i4_t::type) as the selector's `type`.
template <>
struct nnvb_data_t_selector<pk_i4_t>
{
using type = pk_i4_t::type;
};
template <typename T, index_t N> template <typename T, index_t N>
struct non_native_vector_base< struct non_native_vector_base<
T, T,
...@@ -1166,6 +1186,14 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>> ...@@ -1166,6 +1186,14 @@ struct scalar_type<non_native_vector_base<bf8_ocp_t, N>>
static constexpr index_t vector_size = N; static constexpr index_t vector_size = N;
}; };
// scalar_type trait for a non-native vector of N pk_i4_t elements: the scalar
// type is the vector base's data_t and the vector size is N.
template <index_t N>
struct scalar_type<non_native_vector_base<pk_i4_t, N>>
{
    // Number of pk_i4_t elements held by the vector.
    static constexpr index_t vector_size = N;

    // Element type as declared by non_native_vector_base.
    using type = typename non_native_vector_base<pk_i4_t, N>::data_t;
};
// non-native vector_type implementation // non-native vector_type implementation
template <typename T> template <typename T>
struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>> struct vector_type<T, 1, typename std::enable_if_t<!is_native_type<T>()>>
...@@ -1868,6 +1896,7 @@ using bf8x64_t = bf8x64_fnuz_t; ...@@ -1868,6 +1896,7 @@ using bf8x64_t = bf8x64_fnuz_t;
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type; using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type; using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;
// Vector type holding eight pk_i4_t elements.
using pk_i4x8_t = typename vector_type<pk_i4_t, 8>::type;
// u8 // u8
// using uint8x2_t = typename vector_type<uint8_t, 2>::type; // using uint8x2_t = typename vector_type<uint8_t, 2>::type;
......
...@@ -59,10 +59,7 @@ struct DynamicBuffer ...@@ -59,10 +59,7 @@ struct DynamicBuffer
__host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }
template <typename X, template <typename X>
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
__host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
{ {
// X contains multiple T // X contains multiple T
...@@ -204,10 +201,7 @@ struct DynamicBuffer ...@@ -204,10 +201,7 @@ struct DynamicBuffer
element_space_size_ / PackedSize); element_space_size_ / PackedSize);
} }
template <typename X, template <typename X>
typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
typename scalar_type<remove_cvref_t<T>>::type>::value,
bool>::type = false>
__host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
{ {
// X contains multiple T // X contains multiple T
......
...@@ -115,8 +115,7 @@ struct StaticBufferTupleOfVector ...@@ -115,8 +115,7 @@ struct StaticBufferTupleOfVector
// Get X // Get X
// i is offset of S, not X. i should be aligned to X // i is offset of S, not X. i should be aligned to X
template <typename X, template <typename X,
index_t I, index_t I>
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
__host__ __device__ constexpr auto GetAsType(Number<I> i) const __host__ __device__ constexpr auto GetAsType(Number<I> i) const
{ {
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{}; constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
...@@ -133,8 +132,7 @@ struct StaticBufferTupleOfVector ...@@ -133,8 +132,7 @@ struct StaticBufferTupleOfVector
// Set X // Set X
// i is offset of S, not X. i should be aligned to X // i is offset of S, not X. i should be aligned to X
template <typename X, template <typename X,
index_t I, index_t I>
typename enable_if<has_same_scalar_type<S, X>::value, bool>::type = false>
__host__ __device__ constexpr void SetAsType(Number<I> i, X x) __host__ __device__ constexpr void SetAsType(Number<I> i, X x)
{ {
constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{}; constexpr auto s_per_x = Number<scalar_type<remove_cvref_t<X>>::vector_size>{};
......
...@@ -76,7 +76,7 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -76,7 +76,7 @@ struct ReferenceGemm : public device::BaseOperator
} }
else if constexpr(is_same_v<ADataType, pk_i4_t>) else if constexpr(is_same_v<ADataType, pk_i4_t>)
{ {
pk_i4_t i4x2 = arg.a_m_k_(m, k); uint8_t i4x2 = arg.a_m_k_(m, k).data;
int8_t i4 = 0; int8_t i4 = 0;
if(k % 2 == 1) if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf; i4 = (i4x2 >> 0) & 0xf;
...@@ -97,7 +97,7 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -97,7 +97,7 @@ struct ReferenceGemm : public device::BaseOperator
} }
else if constexpr(is_same_v<BDataType, pk_i4_t>) else if constexpr(is_same_v<BDataType, pk_i4_t>)
{ {
pk_i4_t i4x2 = arg.b_k_n_(k, n); uint8_t i4x2 = arg.b_k_n_(k, n).data;
int8_t i4 = 0; int8_t i4 = 0;
if(k % 2 == 1) if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf; i4 = (i4x2 >> 0) & 0xf;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment