Commit 8ea1c974 authored by mtgu0705's avatar mtgu0705
Browse files

change the custom int4 to uint8 for verification

parent d201acc0
...@@ -208,7 +208,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -208,7 +208,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
for(int k = 0; k < 4; k++) for(int k = 0; k < 4; k++)
{ {
int i4x2 = b_k_n_permute(j + k * 2, i).data; int i4x2 = b_k_n_permute(j + k * 2, i);
input[k * 2 + 0] = (i4x2 >> 4) & 0xf; input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf; input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
} }
...@@ -303,9 +303,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -303,9 +303,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::pk_i4_t i4x2 = b_k_n(k, n); ck::pk_i4_t i4x2 = b_k_n(k, n);
int8_t i4 = 0; int8_t i4 = 0;
if(k % 2 == 1) if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf; i4 = (i4x2 >> 0) & 0xf;
else else
i4 = (i4x2.data >> 4) & 0xf; i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8; i4 = i4 - 8;
v_b = ck::type_convert<float>(i4); v_b = ck::type_convert<float>(i4);
......
...@@ -324,11 +324,11 @@ struct PassThrough final : public UnaryOpBase ...@@ -324,11 +324,11 @@ struct PassThrough final : public UnaryOpBase
template <typename Y, typename X> template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const; __host__ __device__ void operator()(Y& y, const X& x) const;
template <> // template <>
__host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const // __host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const
{ // {
y = x; // y = x;
} // }
template <> template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const __host__ __device__ void operator()<float, double>(float& y, const double& x) const
......
...@@ -11,15 +11,16 @@ namespace ck { ...@@ -11,15 +11,16 @@ namespace ck {
using bhalf_t = ushort; using bhalf_t = ushort;
using half_t = _Float16; using half_t = _Float16;
using int4_t = _BitInt(4); using int4_t = _BitInt(4);
using pk_i4_t = uint8_t;
// custom data type - pack int4 data // custom data type - pack int4 data
struct pk_i4_t // struct pk_i4_t
{ // {
using type = uint8_t; // using type = uint8_t;
type data; // type data;
__host__ __device__ constexpr pk_i4_t() : data{type{}} {} // __host__ __device__ constexpr pk_i4_t() : data{type{}} {}
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {} // __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
}; // };
inline constexpr auto next_pow2(uint32_t x) inline constexpr auto next_pow2(uint32_t x)
{ {
...@@ -174,12 +175,12 @@ struct scalar_type<int4_t> ...@@ -174,12 +175,12 @@ struct scalar_type<int4_t>
}; };
#endif #endif
template <> // template <>
struct scalar_type<pk_i4_t> // struct scalar_type<pk_i4_t>
{ // {
using type = pk_i4_t; // using type = pk_i4_t;
static constexpr index_t vector_size = 1; // static constexpr index_t vector_size = 1;
}; // };
template <> template <>
struct scalar_type<f8_fnuz_t> struct scalar_type<f8_fnuz_t>
...@@ -1060,11 +1061,11 @@ struct nnvb_data_t_selector<bf8_ocp_t> ...@@ -1060,11 +1061,11 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using type = bf8_ocp_t::data_type; using type = bf8_ocp_t::data_type;
}; };
template <> // template <>
struct nnvb_data_t_selector<pk_i4_t> // struct nnvb_data_t_selector<pk_i4_t>
{ // {
using type = pk_i4_t::type; // using type = pk_i4_t::type;
}; // };
template <typename T, index_t N> template <typename T, index_t N>
struct non_native_vector_base< struct non_native_vector_base<
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment