Commit 8ea1c974 authored by mtgu0705's avatar mtgu0705
Browse files

change the custom int4 to uint8 for verification

parent d201acc0
......@@ -208,7 +208,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i).data;
int i4x2 = b_k_n_permute(j + k * 2, i);
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
......@@ -303,9 +303,9 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck::pk_i4_t i4x2 = b_k_n(k, n);
int8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2.data >> 0) & 0xf;
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2.data >> 4) & 0xf;
i4 = (i4x2 >> 4) & 0xf;
i4 = i4 - 8;
v_b = ck::type_convert<float>(i4);
......
......@@ -324,11 +324,11 @@ struct PassThrough final : public UnaryOpBase
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
template <>
__host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const
{
y = x;
}
// template <>
// __host__ __device__ void operator()<pk_i4_t, pk_i4_t>(pk_i4_t& y, const pk_i4_t& x) const
// {
// y = x;
// }
template <>
__host__ __device__ void operator()<float, double>(float& y, const double& x) const
......
......@@ -11,15 +11,16 @@ namespace ck {
using bhalf_t = ushort;
using half_t = _Float16;
using int4_t = _BitInt(4);
using pk_i4_t = uint8_t;
// custom data type - pack int4 data
struct pk_i4_t
{
using type = uint8_t;
type data;
__host__ __device__ constexpr pk_i4_t() : data{type{}} {}
__host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
};
// struct pk_i4_t
// {
// using type = uint8_t;
// type data;
// __host__ __device__ constexpr pk_i4_t() : data{type{}} {}
// __host__ __device__ constexpr pk_i4_t(type init) : data{init} {}
// };
inline constexpr auto next_pow2(uint32_t x)
{
......@@ -174,12 +175,12 @@ struct scalar_type<int4_t>
};
#endif
template <>
struct scalar_type<pk_i4_t>
{
using type = pk_i4_t;
static constexpr index_t vector_size = 1;
};
// template <>
// struct scalar_type<pk_i4_t>
// {
// using type = pk_i4_t;
// static constexpr index_t vector_size = 1;
// };
template <>
struct scalar_type<f8_fnuz_t>
......@@ -1060,11 +1061,11 @@ struct nnvb_data_t_selector<bf8_ocp_t>
using type = bf8_ocp_t::data_type;
};
template <>
struct nnvb_data_t_selector<pk_i4_t>
{
using type = pk_i4_t::type;
};
// template <>
// struct nnvb_data_t_selector<pk_i4_t>
// {
// using type = pk_i4_t::type;
// };
template <typename T, index_t N>
struct non_native_vector_base<
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment