// SPDX-License-Identifier: MIT
#pragma once
#include "aiter_hip_common.h"

// Packed vector conversions (fp32 / fp16 / bf16 -> fp8 / bf8 / int8 / fp4) that apply a
// per-call `inverted_scale`, built on ck_tile vector types and AMDGPU inline assembly.
//
// NOTE(review): this header was recovered from a copy in which every `<...>` template
// argument list had been stripped. The template arguments below were reconstructed from
// the surrounding code (alias names, macro token pasting, loop/size arithmetic) and
// ck_tile conventions — confirm against the upstream file.

namespace ck_tile {

// Vector type used by every conversion in this header.
template <typename T, index_t N>
using vec_t = thread_buffer<T, N>;
// using vec_t = ext_vector_t<T, N>;

using int8x2_v = vec_t<int8_t, 2>;
using fp8x2_v  = vec_t<fp8_t, 2>;
using fp16x2_v = vec_t<fp16_t, 2>;
using bf16x2_v = vec_t<bf16_t, 2>;
using fp32x2_v = vec_t<fp32_t, 2>;

// Two fp4 values packed into one byte, stored as a raw uint8_t. The nibble layout is the
// one produced by the v_cvt_scalef32_pk_fp4_* helpers further down.
struct fp4x2_t
{
    using type = uint8_t;
    type data;
    __host__ __device__ constexpr fp4x2_t() : data{type{}} {}
    __host__ __device__ constexpr fp4x2_t(type init) : data{init} {}
};
// The fp4 helpers bit_cast a single byte into fp4x2_t, so it must be exactly one byte.
static_assert(sizeof(fp4x2_t) == 1, "fp4x2_t must occupy exactly one byte");

using fp4x2x2_v = vec_t<fp4x2_t, 2>;
using fp4x2x4_v = vec_t<fp4x2_t, 4>;
using fp4x2x8_v = vec_t<fp4x2_t, 8>;

// fp4x2_t is treated as a scalar (vector_size 1) whose underlying storage is uint8_t.
template <>
struct vector_traits<fp4x2_t>
{
    using scalar_type                    = uint8_t;
    static constexpr index_t vector_size = 1;
};

template <>
struct numeric<fp4x2_t>
{
    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr fp32_t max() { return 6.0f; }
};

// Element-wise product of two fp32 pairs: returns {a[0] * b[0], a[1] * b[1]}.
// Uses packed v_pk_mul_f32 on the targets listed below, two scalar v_mul_f32 otherwise.
CK_TILE_DEVICE fp32x2_v amd_assembly_pk_mul_f32(fp32x2_v a, fp32x2_t b)
{
    fp32x2_v c;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx936__)
    asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b));
#else
    // v_result0 is written by the first instruction while v_a1/v_b1 are still to be read
    // by the second one, so it must be early-clobber ("=&v"). Without '&' the register
    // allocator may give v_result0 the same VGPR as v_a1 or v_b1 and corrupt c[1].
    asm volatile("v_mul_f32 %[v_result0], %[v_a0], %[v_b0]\n\t"
                 "v_mul_f32 %[v_result1], %[v_a1], %[v_b1]\n\t"
                 : [v_result0] "=&v"(c[0]), [v_result1] "=v"(c[1])
                 : [v_a0] "v"(a[0]), [v_a1] "v"(a[1]), [v_b0] "v"(b[0]), [v_b1] "v"(b[1]));
#endif
    return c;
}

// The packed fp32 -> fp8/bf8 path (the asm helpers, fp32x2_t_to_fp8x2_t and the fp8
// vec_convert specializations) is compiled under ONE condition, so callers are only
// visible where the helpers they call are declared.
#if defined(ENABLE_FP8) && (defined(__gfx938__) || defined(__gfx946__))
#define AITER_VEC_CVT_HAS_PK_FP8 1
#else
#define AITER_VEC_CVT_HAS_PK_FP8 0
#endif

#if AITER_VEC_CVT_HAS_PK_FP8
// Packs fp32 values a and b into two fp8 bytes with v_cvt_pk_fp8_f32; only the low 16
// bits of the destination register are returned.
// NOTE(review): `c` is also passed (uninitialized) as the %3 input operand — confirm intent.
CK_TILE_DEVICE fp8x2_v amd_assembly_cvt_pk_fp8_f32(fp32_t a, fp32_t b)
{
    int16x2_t c;
    asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2, %3 op_sel:[0,0,0,0]"
                 : "=v"(c)
                 : "v"(a), "v"(b), "v"(c));
    return bit_cast<fp8x2_v>(c[0]);
}

// Same as amd_assembly_cvt_pk_fp8_f32 but emits v_cvt_pk_bf8_f32.
CK_TILE_DEVICE fp8x2_v amd_assembly_cvt_pk_bf8_f32(fp32_t a, fp32_t b)
{
    int16x2_t c;
    asm volatile("v_cvt_pk_bf8_f32 %0, %1, %2, %3 op_sel:[0,0,0,0]"
                 : "=v"(c)
                 : "v"(a), "v"(b), "v"(c));
    return bit_cast<fp8x2_v>(c[0]);
}
#endif

#if defined(__gfx946__)
// Converts two fp32 values to one packed fp4x2_t using the hardware scale operand.
// Only byte 0 of the destination register is returned.
CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_f32(fp32_t a, fp32_t b, fp32_t scale)
{
    int16x2_t c;
    // permute high bits and low bits to match the order of the original vector
    asm volatile("v_cvt_scalef32_pk_fp4_f32 %0, %1, %2, %3"
                 : "=v"(c)
                 : "v"(b), "v"(a), "v"(scale));
    return bit_cast<fp4x2_t>(bit_cast<int8x2_v>(c[0])[0]);
}

// Converts a packed fp16 pair to one fp4x2_t using the hardware scale operand.
// NOTE(review): unlike the fp32 variant, the source is passed unpermuted — confirm the
// resulting nibble order matches amd_assembly_cvt_scalef32_pk_fp4_f32.
CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_f16(fp16x2_v a, fp32_t scale)
{
    int16x2_t c;
    asm volatile("v_cvt_scalef32_pk_fp4_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(scale));
    return bit_cast<fp4x2_t>(bit_cast<int8x2_v>(c[0])[0]);
}

// Converts a packed bf16 pair to one fp4x2_t using the hardware scale operand.
// NOTE(review): source passed unpermuted, as in the fp16 variant — confirm nibble order.
CK_TILE_DEVICE fp4x2_t amd_assembly_cvt_scalef32_pk_fp4_bf16(bf16x2_v a, fp32_t scale)
{
    int16x2_t c;
    asm volatile("v_cvt_scalef32_pk_fp4_bf16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(scale));
    return bit_cast<fp4x2_t>(bit_cast<int8x2_v>(c[0])[0]);
}
#endif

// convert any to fp32x?_t one by one (Y must be fp32_t), via type_convert per element.
template <typename Y,
          typename X,
          index_t N,
          std::enable_if_t<(std::is_same_v<Y, fp32_t>), bool> = false>
CK_TILE_HOST_DEVICE constexpr vec_t<Y, N> vec_convert(vec_t<X, N> x)
{
    using fp32xX_t = vec_t<Y, N>;
    fp32xX_t tmp;
    for(index_t i = 0; i < N; i++)
    {
        tmp[i] = type_convert<Y>(x[i]);
    }
    return tmp;
}

// Scaled conversion of an even-length vector to any non-fp4 target type Y.
// A non-fp32 source is first widened to fp32 and then dispatched to the fp32 -> Y
// conversion. Supported fp32 -> Y combinations are provided as explicit specializations
// further down (CK_TILE_TYPE_CONVERT).
// NOTE: the `else` branch only runs when NO specialization exists for <Y, fp32_t, N>, and
// it then calls itself with identical arguments (infinite recursion if executed).
template <typename Y,
          typename X,
          index_t N,
          std::enable_if_t<(N % 2 == 0), bool>                   = false,
          std::enable_if_t<(!(std::is_same_v<Y, fp4x2_t>)), bool> = false>
CK_TILE_HOST_DEVICE constexpr vec_t<Y, N> vec_convert(vec_t<X, N> x, fp32_t inverted_scale)
{
    if constexpr(!std::is_same_v<X, fp32_t>)
    {
        using fp32xX_t = vec_t<fp32_t, N>;
        fp32xX_t tmp   = vec_convert<fp32_t, X, N>(x);
        return vec_convert<Y, fp32_t, N>(tmp, inverted_scale);
    }
    else
    {
        // fp32->??
        return vec_convert<Y, fp32_t, N>(x, inverted_scale);
    }
}

#if AITER_VEC_CVT_HAS_PK_FP8
// fp32x2 -> fp8x2: multiply both lanes by inverted_scale, then pack with the fp8 instruction
// when numeric_traits<fp8_t>::f8_interpret is an E4M3 variant, or the bf8 instruction otherwise.
CK_TILE_HOST_DEVICE constexpr fp8x2_v fp32x2_t_to_fp8x2_t(fp32x2_v x, fp32_t inverted_scale)
{
    constexpr auto interpret = numeric_traits<fp8_t>::f8_interpret;
    fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
    if constexpr((interpret == fp8_interpretation::E4M3_FNUZ) ||
                 (interpret == fp8_interpretation::E4M3_OCP))
    {
        return amd_assembly_cvt_pk_fp8_f32(tmp[0], tmp[1]);
    }
    else
    {
        return amd_assembly_cvt_pk_bf8_f32(tmp[0], tmp[1]);
    }
}
#endif

// fp32x2 -> int8x2: multiply both lanes by inverted_scale, then static_cast (which
// truncates toward zero).
// NOTE(review): float -> int8_t static_cast is undefined for values outside [-128, 127];
// presumably callers pick inverted_scale so the scaled values fit — confirm.
CK_TILE_HOST_DEVICE constexpr int8x2_v fp32x2_t_to_int8x2_t(fp32x2_v x, fp32_t inverted_scale)
{
    fp32x2_v tmp = amd_assembly_pk_mul_f32(x, fp32x2_t{inverted_scale, inverted_scale});
    int8x2_v out;
    out[0] = static_cast<int8_t>(tmp[0]);
    out[1] = static_cast<int8_t>(tmp[1]);
    return out;
}

#if defined(__gfx946__)
// fp32x2 -> fp4x2
CK_TILE_HOST_DEVICE constexpr fp4x2_t fp32x2_t_to_fp4x2_t(fp32x2_v x, fp32_t inverted_scale)
{
    return amd_assembly_cvt_scalef32_pk_fp4_f32(x[0], x[1], inverted_scale);
}
// fp16x2 -> fp4x2
CK_TILE_HOST_DEVICE constexpr fp4x2_t fp16x2_t_to_fp4x2_t(fp16x2_v x, fp32_t inverted_scale)
{
    return amd_assembly_cvt_scalef32_pk_fp4_f16(x, inverted_scale);
}
// bf16x2 -> fp4x2
CK_TILE_HOST_DEVICE constexpr fp4x2_t bf16x2_t_to_fp4x2_t(bf16x2_v x, fp32_t inverted_scale)
{
    return amd_assembly_cvt_scalef32_pk_fp4_bf16(x, inverted_scale);
}
#endif

// Defines the explicit specialization vec_convert<dtype_t, stype_t, vec_size_>, which walks
// the input two lanes at a time through stype_x2_t_to_dtypex2_t. Output length equals input length.
#define CK_TILE_TYPE_CONVERT(dtype_, stype_, vec_size_)                                       \
    template <>                                                                                \
    CK_TILE_HOST_DEVICE constexpr vec_t<dtype_##_t, vec_size_>                                 \
    vec_convert<dtype_##_t, stype_##_t, vec_size_>(vec_t<stype_##_t, vec_size_> x,             \
                                                   fp32_t inverted_scale)                      \
    {                                                                                          \
        constexpr int iter_num = vec_size_ / 2;                                                \
        vec_t<dtype_##_t, vec_size_> out;                                                      \
        using vec_i2 = vec_t<stype_##_t, 2>;                                                   \
        using vec_o2 = vec_t<dtype_##_t, 2>;                                                   \
        _Pragma("unroll") for(index_t i = 0; i < iter_num; i++)                                \
        {                                                                                      \
            vec_o2 tmp = stype_##x2##_t_to_##dtype_##x2##_t(x.template get_as<vec_i2>()(i),    \
                                                            inverted_scale);                   \
            out.template get_as<vec_o2>()(i) = tmp;                                            \
        }                                                                                      \
        return out;                                                                            \
    }

#if AITER_VEC_CVT_HAS_PK_FP8
CK_TILE_TYPE_CONVERT(fp8, fp32, 2)
CK_TILE_TYPE_CONVERT(fp8, fp32, 4)
CK_TILE_TYPE_CONVERT(fp8, fp32, 8)
CK_TILE_TYPE_CONVERT(fp8, fp32, 16)
CK_TILE_TYPE_CONVERT(fp8, fp32, 32)
#endif
CK_TILE_TYPE_CONVERT(int8, fp32, 2)
CK_TILE_TYPE_CONVERT(int8, fp32, 4)
CK_TILE_TYPE_CONVERT(int8, fp32, 8)
CK_TILE_TYPE_CONVERT(int8, fp32, 16)
CK_TILE_TYPE_CONVERT(int8, fp32, 32)
#undef CK_TILE_TYPE_CONVERT

// 4 bit vec convert
// N source elements -> N / 2 packed fp4x2_t bytes. Only declared here; every supported
// combination is provided by the explicit specializations below.
#if defined(__gfx946__)
template <typename Y,
          typename X,
          index_t N,
          std::enable_if_t<(N % 2 == 0), bool>                 = false,
          std::enable_if_t<((std::is_same_v<Y, fp4x2_t>)), bool> = false>
CK_TILE_HOST_DEVICE constexpr vec_t<Y, N / 2> vec_convert(vec_t<X, N> x, fp32_t inverted_scale);

// Defines vec_convert<fp4x2_t, stype_t, vec_size_>: each source pair becomes one fp4x2_t
// through stype_x2_t_to_fp4x2_t, so the result holds vec_size_ / 2 elements.
#define CK_TILE_TYPE_CONVERT(dtype_, stype_, vec_size_)                                       \
    template <>                                                                                \
    CK_TILE_HOST_DEVICE constexpr vec_t<dtype_##_t, vec_size_ / 2>                             \
    vec_convert<dtype_##_t, stype_##_t, vec_size_>(vec_t<stype_##_t, vec_size_> x,             \
                                                   fp32_t inverted_scale)                      \
    {                                                                                          \
        constexpr int iter_num = vec_size_ / 2;                                                \
        vec_t<dtype_##_t, vec_size_ / 2> out;                                                  \
        using vec_i2 = vec_t<stype_##_t, 2>;                                                   \
        using vec_o2 = dtype_##_t;                                                             \
        _Pragma("unroll") for(index_t i = 0; i < iter_num; i++)                                \
        {                                                                                      \
            vec_o2 tmp =                                                                       \
                stype_##x2##_t_to_##dtype_##_t(x.template get_as<vec_i2>()(i), inverted_scale); \
            out.template get_as<vec_o2>()(i) = tmp;                                            \
        }                                                                                      \
        return out;                                                                            \
    }

CK_TILE_TYPE_CONVERT(fp4x2, fp32, 4)
CK_TILE_TYPE_CONVERT(fp4x2, fp32, 8)
CK_TILE_TYPE_CONVERT(fp4x2, fp32, 16)
CK_TILE_TYPE_CONVERT(fp4x2, fp32, 32)
CK_TILE_TYPE_CONVERT(fp4x2, fp16, 4)
CK_TILE_TYPE_CONVERT(fp4x2, fp16, 8)
CK_TILE_TYPE_CONVERT(fp4x2, fp16, 16)
CK_TILE_TYPE_CONVERT(fp4x2, fp16, 32)
CK_TILE_TYPE_CONVERT(fp4x2, bf16, 4)
CK_TILE_TYPE_CONVERT(fp4x2, bf16, 8)
CK_TILE_TYPE_CONVERT(fp4x2, bf16, 16)
CK_TILE_TYPE_CONVERT(fp4x2, bf16, 32)
#endif
#undef CK_TILE_TYPE_CONVERT
#undef AITER_VEC_CVT_HAS_PK_FP8

} // namespace ck_tile