





#pragma once
#include <ck/ck.hpp>

namespace ck {

/// Primary template, intentionally left without members.
///
/// Only the (half_t, uint8_t, N) specializations in this header define
/// result_type / source_type / convert(). Any other (T, S, N) combination has
/// no members, so attempting to convert with it fails to compile.
template<typename T, typename S, int N>
struct FastInterleavedAndBiasedNumericArrayConverter {};


template<>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, 4> {
    /// Converts 4 biased uint8 values (each byte b encodes the value b - 128)
    /// into 4 fp16 values.
    ///
    /// Technique: v_perm_b32 places each source byte in the low byte of a
    /// 16-bit lane whose high byte is 0x64. Interpreted as fp16, 0x64XX equals
    /// exactly 1024 + XX (exponent 2^10, one ulp == 1.0). Subtracting
    /// 1152 (= 1024 + 128) from BOTH lanes then yields exactly XX - 128.
    ///
    /// Output order produced by the permute masks (source bytes s0..s3):
    ///   result = { s0 - 128, s2 - 128, s1 - 128, s3 - 128 }
    /// The source is presumably pre-interleaved by the weight-packing code so
    /// that the final order is natural. NOTE(review): confirm against the packer.
    ///
    /// Assumes result_type is 4 contiguous 16-bit half_t and source_type is 4
    /// contiguous bytes (both are reinterpreted as uint32_t below). Uses packed
    /// fp16 math (v_pk_add_f16), i.e. requires a gfx9+ target.
    using result_type = Array<half_t, 4>;
    using source_type = Array<uint8_t, 4>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        result_type result;
        uint32_t*      h   = reinterpret_cast<uint32_t*>(&result);
        uint32_t const i8s = reinterpret_cast<uint32_t const&>(source);

        // v_perm_b32 D, S0, S1, SEL: each byte of SEL picks one byte of the
        // 8-byte pool {S0:S1}. Values 0-3 select bytes of S1 (here i8s);
        // 4-7 select bytes of S0 (here 0x64646464). Reading masks low byte first:
        //   0x05020500 -> h[0] bytes = { s0, 0x64, s2, 0x64 }
        //   0x05030501 -> h[1] bytes = { s1, 0x64, s3, 0x64 }
        static constexpr uint32_t mask_for_elt_01     = 0x05020500;
        static constexpr uint32_t mask_for_elt_23     = 0x05030501;
        static constexpr uint32_t start_byte_for_fp16 = 0x64646464;

        asm volatile("v_perm_b32 %0,%1,%2,%3;\n" : "=v"(h[0]) : "v"(start_byte_for_fp16), "v"(i8s), "v"(mask_for_elt_01));
        asm volatile("v_perm_b32 %0,%1,%2,%3;\n" : "=v"(h[1]) : "v"(start_byte_for_fp16), "v"(i8s), "v"(mask_for_elt_23));

        // fp16 -1152.0 (0xE480) replicated into BOTH 16-bit lanes. A single-lane
        // constant such as 0x6480 would leave the high lane (outputs 1 and 3)
        // unbiased at 1024 + b. AMDGPU has no packed fp16 subtract, so compute
        // x - 1152 as x + (-1152). The result b - 128 is exactly representable,
        // so no rounding occurs.
        static constexpr uint32_t I8s_TO_F16s_NEG_MAGIC_NUM = 0xE480E480;
        asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(h[0]) : "v"(h[0]), "v"(I8s_TO_F16s_NEG_MAGIC_NUM));
        asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(h[1]) : "v"(h[1]), "v"(I8s_TO_F16s_NEG_MAGIC_NUM));

        return result;
    }

    /// Function-object form; forwards to convert().
    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return convert(s);
    }
};

/// Converts N biased uint8 values into N fp16 values by running the 4-wide
/// specialization over each consecutive group of 4 elements, in order.
///
/// N must be a multiple of 4 (enforced at compile time). The element order
/// inside each group is whatever the 4-wide converter produces.
/// Both arrays are reinterpreted as arrays of 4-element chunks, which assumes
/// Array<T, N> is laid out as N contiguous T. NOTE(review): confirm for the
/// Array type in use.
template<int N>
struct FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, N> {
    static constexpr int VEC_WIDTH = 4;
    static_assert(N % VEC_WIDTH == 0, "N must be multiple of 4.");

    using result_type = Array<half_t, N>;
    using source_type = Array<uint8_t, N>;

    CUTLASS_DEVICE
    static result_type convert(source_type const& source)
    {
        // Chunk types and converter handled by the 4-wide specialization.
        using chunk_out       = Array<half_t, VEC_WIDTH>;
        using chunk_in        = Array<uint8_t, VEC_WIDTH>;
        using chunk_converter = FastInterleavedAndBiasedNumericArrayConverter<half_t, uint8_t, VEC_WIDTH>;

        constexpr int num_chunks = N / VEC_WIDTH;

        result_type     out;
        chunk_out*      dst = reinterpret_cast<chunk_out*>(&out);
        chunk_in const* src = reinterpret_cast<chunk_in const*>(&source);

        CUTLASS_PRAGMA_UNROLL
        for (int c = 0; c < num_chunks; ++c) {
            dst[c] = chunk_converter::convert(src[c]);
        }

        return out;
    }

    /// Function-object form; forwards to convert().
    CUTLASS_DEVICE
    result_type operator()(source_type const& s)
    {
        return FastInterleavedAndBiasedNumericArrayConverter::convert(s);
    }
};

} // namespace ck































