// -----------------------------------------------------------------------------
// Review notes: scaled narrow-type conversion helpers (namespace aiter).
//
// NOTE(review): this text looks like it went through a filter that dropped
// everything between '<' and '>'. The three `#include` directives have no
// target, `template , bool> = true>` has lost its parameter list, and
// `static_cast(w)`, `array`, `number` have lost their arguments. Restore those
// from the upstream file before compiling; the notes below describe only what
// is still visible.
//
// Visible behavior:
//  * pk_mul_f32: one `v_pk_mul_f32` in inline asm; returns the result register
//    as fp32x2_t.
//  * fp32_to_fp8_scaled_x2: multiplies both lanes by inverted_scale via
//    pk_mul_f32, clamps each lane with `v_med3_f32` against (lo, hi), then
//    packs with `v_cvt_pk_fp8_f32`. (lo, hi) is (-240, 240) when __gfx942__ is
//    defined and (-448, 448) otherwise.
//  * fp32_to_bf8_scaled_x2: same instruction sequence with bounds
//    (-57344, 57344) and `v_cvt_pk_bf8_f32`.
//  * fp32_to_fp8_scaled_x4 / fp32_to_bf8_scaled_x4: call the matching x2 helper
//    on lanes {0,1} and {2,3} and concatenate the four results.
//  * fp32_to_i8_scaled_x2 / _x4: pk_mul_f32 followed by a static_cast per lane.
//    No clamp instruction is issued on this path.
//  * fp16_to_fp4_scaled_x2/x4/x8: under __gfx950__ each input pair goes through
//    __builtin_amdgcn_cvt_scalef32_pk_fp4_f16 (x4 uses selectors 0,1; x8 uses
//    0..3; x2 uses `sel`). Without __gfx950__ they return a value-initialized
//    array and ignore their arguments.
//    NOTE(review): `u32_t w;` is read by the first builtin call before it is
//    assigned. Consider zero-initializing it.
//    NOTE(review): the __gfx950__ x2 overload takes a third `number = {}`
//    parameter that the fallback x2 overload does not. Confirm this is intended.
//  * bf16_to_fp4_scaled_x4 / _x8: call bf16_to_fp4_packed_x2 once per input
//    pair and return the results collected in an array.
// -----------------------------------------------------------------------------
// SPDX-License-Identifier: MIT #pragma once #include "hip_reduce.h" #include "opus/opus.hpp" // todo: remove this to use aiterTensor dtype #include #include #include namespace aiter { using namespace opus; #define RT 0 #define GROUP_NT 3 using index_t = int; ///////////////////////////////////////////////////////////////////////////////////////////////////////// // scaled type conversion: v_pk_mul_f32 + v_med3_f32 + v_cvt_pk_{fp8,bf8}_f32 // Identical ISA to ck_tile::vec_convert for performance parity OPUS_D fp32x2_t pk_mul_f32(fp32x2_t a, fp32x2_t b) { fp32x2_t c; asm volatile("v_pk_mul_f32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b)); return c; } // fp32x2 -> fp8x2 with scale + saturation clamp (E4M3) // ISA: v_pk_mul_f32 + v_med3_f32 x2 + v_cvt_pk_fp8_f32 template , bool> = true> OPUS_D decltype(auto) fp32_to_fp8_scaled_x2(const S& s, float inverted_scale) { fp32x2_t tmp = pk_mul_f32(s, fp32x2_t{inverted_scale, inverted_scale}); #if defined(__gfx942__) constexpr float hi = 240.0f, lo = -240.0f; #else constexpr float hi = 448.0f, lo = -448.0f; #endif float a = tmp[0], b = tmp[1]; int w; asm volatile("v_med3_f32 %1, %1, %3, %4\n" "v_med3_f32 %2, %2, %3, %4\n" "v_cvt_pk_fp8_f32 %0, %1, %2" : "=v"(w), "+v"(a), "+v"(b) : "v"(lo), "v"(hi)); return __builtin_bit_cast(fp8x2_t, static_cast(w)); } template , bool> = true> OPUS_D decltype(auto) fp32_to_fp8_scaled_x4(const S& s, float inverted_scale) { auto lo = fp32_to_fp8_scaled_x2(fp32x2_t{s[0], s[1]}, inverted_scale); auto hi = fp32_to_fp8_scaled_x2(fp32x2_t{s[2], s[3]}, inverted_scale); return fp8x4_t{lo[0], lo[1], hi[0], hi[1]}; } // fp32x2 -> bf8x2 with scale + saturation clamp (E5M2) // ISA: v_pk_mul_f32 + v_med3_f32 x2 + v_cvt_pk_bf8_f32 template , bool> = true> OPUS_D decltype(auto) fp32_to_bf8_scaled_x2(const S& s, float inverted_scale) { fp32x2_t tmp = pk_mul_f32(s, fp32x2_t{inverted_scale, inverted_scale}); constexpr float hi = 57344.0f, lo = -57344.0f; float a = tmp[0], b = tmp[1]; int w; asm volatile("v_med3_f32 
%1, %1, %3, %4\n" "v_med3_f32 %2, %2, %3, %4\n" "v_cvt_pk_bf8_f32 %0, %1, %2" : "=v"(w), "+v"(a), "+v"(b) : "v"(lo), "v"(hi)); return __builtin_bit_cast(bf8x2_t, static_cast(w)); } template , bool> = true> OPUS_D decltype(auto) fp32_to_bf8_scaled_x4(const S& s, float inverted_scale) { auto lo = fp32_to_bf8_scaled_x2(fp32x2_t{s[0], s[1]}, inverted_scale); auto hi = fp32_to_bf8_scaled_x2(fp32x2_t{s[2], s[3]}, inverted_scale); return bf8x4_t{lo[0], lo[1], hi[0], hi[1]}; } // fp32x2 -> i8x2 with scale // ISA: v_pk_mul_f32 + v_cvt_i32_f32 x2 template , bool> = true> OPUS_D decltype(auto) fp32_to_i8_scaled_x2(const S& s, float inverted_scale) { fp32x2_t tmp = pk_mul_f32(s, fp32x2_t{inverted_scale, inverted_scale}); return i8x2_t{static_cast(tmp[0]), static_cast(tmp[1])}; } template , bool> = true> OPUS_D decltype(auto) fp32_to_i8_scaled_x4(const S& s, float inverted_scale) { fp32x2_t tmp0 = pk_mul_f32(fp32x2_t{s[0], s[1]}, fp32x2_t{inverted_scale, inverted_scale}); fp32x2_t tmp1 = pk_mul_f32(fp32x2_t{s[2], s[3]}, fp32x2_t{inverted_scale, inverted_scale}); return i8x4_t{static_cast(tmp0[0]), static_cast(tmp0[1]), static_cast(tmp1[0]), static_cast(tmp1[1])}; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// // fp16x2 -> fp4 with scale (v_cvt_scalef32_pk_fp4_f16, gfx950 only) // opus.hpp has fp32->fp4 and bf16->fp4 but NOT fp16->fp4 #if defined(__gfx950__) template , bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x2(const S& s, float scale, number = {}) { u32_t w; w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, s, scale, sel); return __builtin_bit_cast(array, static_cast(w)); } template , bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x4(const S& s, float scale) { u32_t w; w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[0], s[1]}, scale, 0); w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[2], s[3]}, scale, 1); return __builtin_bit_cast(array, static_cast(w)); 
} template , bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x8(const S& s, float scale) { u32_t w; w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[0], s[1]}, scale, 0); w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[2], s[3]}, scale, 1); w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[4], s[5]}, scale, 2); w = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(w, fp16x2_t{s[6], s[7]}, scale, 3); return __builtin_bit_cast(array, w); } #else template , bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x2(const S&, float) { return array{}; } template , bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x4(const S&, float) { return array{}; } template , bool> = true> OPUS_D constexpr decltype(auto) fp16_to_fp4_scaled_x8(const S&, float) { return array{}; } #endif // bf16 -> fp4 larger vectors (bf16x4/x8) using opus bf16_to_fp4_packed_x2 template , bool> = true> OPUS_D constexpr decltype(auto) bf16_to_fp4_scaled_x4(const S& s, float scale) { auto lo = bf16_to_fp4_packed_x2(bf16x2_t{s[0], s[1]}, scale); auto hi = bf16_to_fp4_packed_x2(bf16x2_t{s[2], s[3]}, scale); return array{lo, hi}; } template , bool> = true> OPUS_D constexpr decltype(auto) bf16_to_fp4_scaled_x8(const S& s, float scale) { auto a = bf16_to_fp4_packed_x2(bf16x2_t{s[0], s[1]}, scale); auto b = bf16_to_fp4_packed_x2(bf16x2_t{s[2], s[3]}, scale); auto c = bf16_to_fp4_packed_x2(bf16x2_t{s[4], s[5]}, scale); auto d = bf16_to_fp4_packed_x2(bf16x2_t{s[6], s[7]}, scale); return array{a, b, c, d}; } // fp4 -> fp32/bf16/fp16 dequant helpers. Input fp4_t stores two packed fp4 values. 
// -----------------------------------------------------------------------------
// Review notes: fp4 dequant helpers and the scaled_cast overload set.
//
// NOTE(review): template parameter lists and the arguments of std::is_same_v,
// is_any_of_v, static_cast, vector_t, array, size and static_for are missing
// from this text (e.g. `template >, bool> = true>`). The overload-selection
// conditions therefore cannot be read from here; restore them from upstream.
//
// Visible behavior:
//  * fp4_to_fp32_scaled_x2/x4/x8 forward to fp4_to_fp32_packed_x2/x4/x8.
//  * fp4_to_bf16_scaled_x2: under __gfx950__ it bit-casts `s` (or `s[0]`,
//    chosen by an `if constexpr`) to u8_t and passes the widened value to
//    __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(packed, scale, 0). Otherwise it
//    dequantizes through fp4_to_fp32_scaled_x2 and static_casts each lane.
//  * fp4_to_bf16_scaled_x4/x8 call the x2 helper on s[0], s[1] (and s[2], s[3])
//    and concatenate the lanes in order.
//  * fp4_to_fp16_scaled_x2/x4/x8 dequantize through the fp32 helper of the same
//    width and static_cast each lane.
//  * fp4_to_fp32_scaled / fp4_to_bf16_scaled / fp4_to_fp16_scaled: generic
//    width; a static_for calls the x2 helper on s[i] and writes the two results
//    to out[2*i] and out[2*i + 1].
//  * scaled_cast (fixed-width overloads): each one is a single forwarding call
//    to the helper named in its body.
//  * scaled_cast (auto-fold, 8-bit targets): static_asserts N % 2 == 0, feeds
//    fp32x2_t{s[2*i], s[2*i+1]} to scaled_cast and copies the returned pair to
//    out[2*i], out[2*i+1].
//  * scaled_cast (auto-fold, fp4 target): same pairing, but stores packed[0] at
//    out[i], one output element per input pair.
//  * scaled_cast (two-hop overloads): static_cast every element into fp32_vec,
//    then call scaled_cast on that vector.
//  * scaled_cast (general fp32 fallback): multiplies every element by
//    inverted_scale into `tmp`, then returns either `tmp` or cast(tmp),
//    selected by an `if constexpr`.
// -----------------------------------------------------------------------------
template >, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp32_scaled_x2(const S& s, float scale) { return fp4_to_fp32_packed_x2(s, scale); } template >, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp32_scaled_x4(const S& s, float scale) { return fp4_to_fp32_packed_x4(s, scale); } template >, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp32_scaled_x8(const S& s, float scale) { return fp4_to_fp32_packed_x8(s, scale); } template >, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_bf16_scaled_x2(const S& s, float scale) { #if defined(__gfx950__) u32_t packed; if constexpr(std::is_same_v) { packed = static_cast(__builtin_bit_cast(u8_t, s)); } else { packed = static_cast(__builtin_bit_cast(u8_t, s[0])); } return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(packed, scale, 0); #else auto x = fp4_to_fp32_scaled_x2(s, scale); return bf16x2_t{static_cast(x[0]), static_cast(x[1])}; #endif } template >, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_bf16_scaled_x4(const S& s, float scale) { auto lo = fp4_to_bf16_scaled_x2(s[0], scale); auto hi = fp4_to_bf16_scaled_x2(s[1], scale); return bf16x4_t{lo[0], lo[1], hi[0], hi[1]}; } template >, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_bf16_scaled_x8(const S& s, float scale) { auto a = fp4_to_bf16_scaled_x2(s[0], scale); auto b = fp4_to_bf16_scaled_x2(s[1], scale); auto c = fp4_to_bf16_scaled_x2(s[2], scale); auto d = fp4_to_bf16_scaled_x2(s[3], scale); return bf16x8_t{a[0], a[1], b[0], b[1], c[0], c[1], d[0], d[1]}; } template >, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp16_scaled_x2(const S& s, float scale) { auto x = fp4_to_fp32_scaled_x2(s, scale); return fp16x2_t{static_cast(x[0]), static_cast(x[1])}; } template >, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp16_scaled_x4(const S& s, float scale) { auto x = fp4_to_fp32_scaled_x4(s, scale); return fp16x4_t{static_cast(x[0]), static_cast(x[1]), static_cast(x[2]), static_cast(x[3])}; } template >, bool> = true> 
OPUS_D constexpr decltype(auto) fp4_to_fp16_scaled_x8(const S& s, float scale) { auto x = fp4_to_fp32_scaled_x8(s, scale); return fp16x8_t{static_cast(x[0]), static_cast(x[1]), static_cast(x[2]), static_cast(x[3]), static_cast(x[4]), static_cast(x[5]), static_cast(x[6]), static_cast(x[7])}; } template && std::is_same_v, fp4_t> && !is_any_of_v, array, array>, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp32_scaled(const S& s, float scale) { constexpr index_t N = size(); vector_t out; static_for([&](auto i) { auto x = fp4_to_fp32_scaled_x2(s[i.value], scale); out[i.value * 2] = x[0]; out[i.value * 2 + 1] = x[1]; }); return out; } template && std::is_same_v, fp4_t> && !is_any_of_v, array, array>, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_bf16_scaled(const S& s, float scale) { constexpr index_t N = size(); vector_t out; static_for([&](auto i) { auto x = fp4_to_bf16_scaled_x2(s[i.value], scale); out[i.value * 2] = x[0]; out[i.value * 2 + 1] = x[1]; }); return out; } template && std::is_same_v, fp4_t> && !is_any_of_v, array, array>, bool> = true> OPUS_D constexpr decltype(auto) fp4_to_fp16_scaled(const S& s, float scale) { constexpr index_t N = size(); vector_t out; static_for([&](auto i) { auto x = fp4_to_fp16_scaled_x2(s[i.value], scale); out[i.value * 2] = x[0]; out[i.value * 2 + 1] = x[1]; }); return out; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// // scaled_cast: type conversion with scale multiplication (ck_tile::vec_convert equivalent) // Usage: aiter::scaled_cast(fp32_vec, inverted_scale) // --- 8-bit targets (fp8, bf8, i8): fp32 source x2/x4 --- template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return fp32_to_fp8_scaled_x2(s, inverted_scale); } template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return fp32_to_bf8_scaled_x2(s, inverted_scale); } template && 
std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return fp32_to_i8_scaled_x2(s, inverted_scale); } template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return fp32_to_fp8_scaled_x4(s, inverted_scale); } template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return fp32_to_bf8_scaled_x4(s, inverted_scale); } template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return fp32_to_i8_scaled_x4(s, inverted_scale); } // --- fp4 target: fp32 source (delegates to opus cast) --- template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return fp32_to_fp4_packed_x2(s, inverted_scale); } template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return fp32_to_fp4_packed_x4(s, inverted_scale); } template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return fp32_to_fp4_packed_x8(s, inverted_scale); } // --- fp4 target: bf16 source --- template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return bf16_to_fp4_packed_x2(s, inverted_scale); } template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return bf16_to_fp4_scaled_x4(s, inverted_scale); } template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return bf16_to_fp4_scaled_x8(s, inverted_scale); } // --- fp4 target: fp16 source --- template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return fp16_to_fp4_scaled_x2(s, inverted_scale); } template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, 
float inverted_scale) { return fp16_to_fp4_scaled_x4(s, inverted_scale); } template && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { return fp16_to_fp4_scaled_x8(s, inverted_scale); } // --- fp4 source: dequant to fp32 --- template > && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float scale) { return fp4_to_fp32_scaled_x2(s, scale); } template > && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float scale) { return fp4_to_fp32_scaled_x4(s, scale); } template > && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float scale) { return fp4_to_fp32_scaled_x8(s, scale); } template && std::is_same_v, fp4_t> && !is_any_of_v, array, array> && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float scale) { return fp4_to_fp32_scaled(s, scale); } // --- fp4 source: dequant to bf16 --- template > && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float scale) { return fp4_to_bf16_scaled_x2(s, scale); } template > && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float scale) { return fp4_to_bf16_scaled_x4(s, scale); } template > && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float scale) { return fp4_to_bf16_scaled_x8(s, scale); } template && std::is_same_v, fp4_t> && !is_any_of_v, array, array> && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float scale) { return fp4_to_bf16_scaled(s, scale); } // --- fp4 source: dequant to fp16 --- template > && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float scale) { return fp4_to_fp16_scaled_x2(s, scale); } template > && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float scale) { return fp4_to_fp16_scaled_x4(s, scale); } template > && std::is_same_v, bool> = true> OPUS_D decltype(auto) 
scaled_cast(const S& s, float scale) { return fp4_to_fp16_scaled_x8(s, scale); } template && std::is_same_v, fp4_t> && !is_any_of_v, array, array> && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float scale) { return fp4_to_fp16_scaled(s, scale); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// // auto-fold: build flat output vector using x2 primitives in a loop // 8-bit targets (fp8, bf8, i8): any fp32 vector size via x2 loop template && std::is_same_v, fp32_t> && !is_any_of_v && (std::is_same_v || std::is_same_v || std::is_same_v), bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { constexpr index_t N = size(); static_assert(N % 2 == 0); vector_t out; static_for([&](auto i) { auto pair = scaled_cast(fp32x2_t{s[i.value * 2], s[i.value * 2 + 1]}, inverted_scale); out[i.value * 2] = pair[0]; out[i.value * 2 + 1] = pair[1]; }); return out; } // two-hop: non-fp32 source -> convert to fp32 via static_cast -> scaled_cast to 8-bit target // Uses static_cast instead of opus::cast to handle _Float16/__fp16 mismatch template && !std::is_same_v, fp32_t> && (std::is_same_v || std::is_same_v || std::is_same_v), bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { constexpr index_t N = size(); vector_t fp32_vec; static_for([&](auto i) { fp32_vec[i.value] = static_cast(s[i.value]); }); return scaled_cast(fp32_vec, inverted_scale); } // fp4 target: any fp32 vector size via x2 loop template < typename D, typename S, std::enable_if_t && std::is_same_v, fp32_t> && !is_any_of_v && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { constexpr index_t N = size(); static_assert(N % 2 == 0); array out; static_for([&](auto i) { auto packed = scaled_cast(fp32x2_t{s[i.value * 2], s[i.value * 2 + 1]}, inverted_scale); out[i.value] = packed[0]; }); return out; } // fp4 target: non-fp32 source 
-> convert to fp32 via static_cast -> scaled_cast to fp4 template && !std::is_same_v, fp32_t> && !is_any_of_v && std::is_same_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { constexpr index_t N = size(); vector_t fp32_vec; static_for([&](auto i) { fp32_vec[i.value] = static_cast(s[i.value]); }); return scaled_cast(fp32_vec, inverted_scale); } // general fallback: fp32 source -> any non-quantized target with scale template && std::is_same_v, fp32_t> && !is_any_of_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { constexpr index_t N = size(); S tmp; static_for([&](auto i) { tmp[i.value] = s[i.value] * inverted_scale; }); if constexpr(std::is_same_v) { return tmp; } else { return cast(tmp); } } // general fallback: non-fp32 source -> any non-quantized target with scale (two-hop via fp32) template && !std::is_same_v, fp32_t> && !is_any_of_v, bool> = true> OPUS_D decltype(auto) scaled_cast(const S& s, float inverted_scale) { constexpr index_t N = size(); vector_t fp32_vec; static_for([&](auto i) { fp32_vec[i.value] = static_cast(s[i.value]); }); return scaled_cast(fp32_vec, inverted_scale); } // Load a large vector (vec_size elements of type T) from gmem buffer in chunks. // Each chunk issues one buffer_load instruction of chunk_bytes bytes (4/8/16 -> // dword/dwordx2/dwordx4). Total loads = vec_size * sizeof(T) / chunk_bytes. // // interleave=false: chunks are contiguous in GMEM. // GMEM layout (per thread): // base + row_offset // |<-- chunk_bytes -->|<-- chunk_bytes -->|<-- chunk_bytes -->|<-- chunk_bytes -->| // [ chunk 0 ][ chunk 1 ][ chunk 2 ][ chunk 3 ] // // interleave=true: chunks are strided by interleave_thread_size * chunk_bytes in GMEM. // GMEM layout (thread 0 loads marked with *, other threads fill the gaps): // base + row_offset // |<- chunk_bytes ->|<- (interleave_thread_size-1)*chunk_bytes gap ->|<- chunk_bytes ->|... 
// ---- review notes for load_vector_nbytes (defined on the following line) ----
// NOTE(review): the template parameter list and the arguments of vector_t,
// gmem, reinterpret_cast, static_for and `buffer.template load` are missing
// from this text; restore them from upstream before compiling.
//
// Visible behavior of load_vector_nbytes(buffer, row_offset):
//  * static_asserts that vec_size * sizeof(T) is a multiple of chunk_bytes.
//  * num_chunks          = vec_size * sizeof(T) / chunk_bytes
//    chunk_size_elements = chunk_bytes / sizeof(T)
//    interleave_bytes    = interleave_thread_size * chunk_bytes
//  * For each chunk index i (compile-time static_for) the byte offset is
//    i * interleave_bytes when `interleave` is true, else i * chunk_bytes.
//    It is divided by sizeof(T) and passed to buffer.load together with
//    row_offset.
//  * The loaded chunk is written into `result` at element
//    i * chunk_size_elements through a reinterpret_cast'ed pointer, so chunks
//    are packed back-to-back in the returned vector in both layouts.
//  * Returns `result` by value.
// ------------------------------------------------------------------------------
// [ *chunk 0 (t0)* ][ chunk 0 (t1) ]...[ chunk 0 (tN-1) ] [ *chunk 1 (t0)* ]... // // Each thread's chunks are interleaved with other threads' data, // stride = interleave_thread_size * chunk_bytes bytes between chunks. // // Example: T=bf16(2B), vec_size=32, chunk_bytes=16, interleave_thread_size=256 // total = 64B -> 4x buffer_load_dwordx4, each loading 8 bf16 elements. // interleave stride = 256 * 16 = 4096 bytes between chunks. template __device__ opus::vector_t load_vector_nbytes(opus::gmem& buffer, int row_offset) { static_assert(vec_size * sizeof(T) % chunk_bytes == 0, "vec_size * sizeof(T) must be a multiple of chunk_bytes"); static constexpr index_t num_chunks = vec_size * sizeof(T) / chunk_bytes; constexpr index_t chunk_size_elements = chunk_bytes / sizeof(T); constexpr index_t interleave_bytes = interleave_thread_size * chunk_bytes; opus::vector_t result; T* result_ptr = reinterpret_cast(&result); opus::static_for([&](auto i) { constexpr index_t chunk_offset_bytes = interleave ? i.value * interleave_bytes : i.value * chunk_bytes; constexpr index_t chunk_offset_elements = chunk_offset_bytes / sizeof(T); opus::vector_t* chunk_ptr = reinterpret_cast*>( result_ptr + i.value * chunk_size_elements); *chunk_ptr = buffer.template load(row_offset, chunk_offset_elements); }); return result; } // Store a vector (vec_size elements of DTYPE_I) to gmem buffer in chunks, with optional type // conversion. Mirror of load_vector_nbytes but for writing. Each chunk issues one buffer_store of // chunk_bytes bytes. // // Template params: // T : buffer element type (storage type in GMEM) // DTYPE_I : input element type in registers (e.g. float) // vec_size : number of input elements // chunk_bytes: bytes per buffer_store instruction (4/8/16 -> dword/dwordx2/dwordx4) // T_R : target conversion type before storing (default = T) // if T_R != DTYPE_I, data is converted per-chunk before store. 
// ---- review notes for store_vector_nbytes / store_vector / t2opus / hip2opus ----
// NOTE(review): template parameter lists and the arguments of std::is_same_v,
// vector_t, gmem, reinterpret_cast, static_for, cast, scaled_cast and
// `buffer.template store` are missing from this text, as are the
// specialization keys of t2opus / hip2opus (only `hip2opus<__half>` survives).
// Restore them from upstream before compiling.
//
// Visible behavior of store_vector_nbytes(buffer, vec, row_offset, inverted_scale):
//  * store_vec_size is vec_size / 2 when an is_same_v test holds, else
//    vec_size (the compared types are not visible here).
//  * static_asserts that store_vec_size * sizeof(T) is a multiple of
//    chunk_bytes, then derives
//      num_chunks                = store_vec_size * sizeof(T) / chunk_bytes
//      chunk_size_elements       = vec_size / num_chunks
//      store_chunk_size_elements = store_vec_size / num_chunks
//      interleave_bytes          = interleave_thread_size * chunk_bytes
//    NOTE(review): there is no static_assert that vec_size is divisible by
//    num_chunks. Confirm callers guarantee it.
//  * Per chunk i the byte offset is i * interleave_bytes when `interleave` is
//    true, else i * chunk_bytes; it is divided by sizeof(T) before being
//    passed to buffer.store with row_offset.
//  * When the outer `if constexpr` selects conversion, one of three branches
//    runs: a per-element opus::cast loop, or one of two scaled_cast calls
//    using inverted_scale. The converted chunk is reinterpret_cast to
//    store_type and stored, followed by `s_nop 0`. The s_nop is emitted only
//    on the conversion paths.
//  * Otherwise the input chunk is reinterpret_cast to store_type and stored
//    directly; inverted_scale is unused on that path.
//
// Visible behavior of store_vector(buffer, vec, row_offset, inverted_scale):
//  * num_store_repeat is num_repeat when `interleave` is true, else 1.
//  * Picks the first of 16 / 8 / 4 that divides
//    store_vec_size * sizeof(T) / num_store_repeat and forwards all arguments
//    to store_vector_nbytes.
//  * NOTE(review): the final branch uses `static_assert(false, ...)` inside a
//    discarded `if constexpr` arm. Compilers that do not implement CWG2518
//    reject this even when the arm is never selected; a condition that
//    depends on a template parameter would avoid that. Confirm the minimum
//    supported compiler.
//
// t2opus / hip2opus are trait structs whose specializations each expose a
// nested `type` alias (float, opus::fp16_t, opus::bf16_t, int32_t, opus::i8_t,
// opus::fp32_t, opus::fp8_t as listed).
// ---------------------------------------------------------------------------------
// interleave : same strided layout as load_vector_nbytes // (stride = interleave_thread_size * chunk_bytes) // // interleave=false: chunks are contiguous in GMEM. // GMEM layout (per thread): // base + row_offset // |<-- chunk_bytes -->|<-- chunk_bytes -->|<-- chunk_bytes -->|<-- chunk_bytes -->| // [ chunk 0 ][ chunk 1 ][ chunk 2 ][ chunk 3 ] // // interleave=true: chunks are strided by interleave_thread_size * chunk_bytes in GMEM. // GMEM layout (thread 0 stores marked with *, other threads fill the gaps): // base + row_offset // |<- chunk_bytes ->|<- (interleave_thread_size-1)*chunk_bytes gap ->|<- chunk_bytes ->|... // [ *chunk 0 (t0)* ][ chunk 0 (t1) ]...[ chunk 0 (tN-1) ] [ *chunk 1 (t0)* ]... // // Each thread's chunks are interleaved with other threads' data, // stride = interleave_thread_size * chunk_bytes bytes between chunks. // // Conversion paths (when T_R != DTYPE_I): // - T_R is bf16/fp16: per-element type_convert (scalar loop) // - otherwise: vec_convert with inverted_scale (e.g. float -> fp8/fp4) // When T_R == DTYPE_I: direct store, no conversion. template __device__ void store_vector_nbytes(opus::gmem& buffer, const opus::vector_t& vec, int row_offset, float inverted_scale = 1.0f) { static constexpr int32_t store_vec_size = std::is_same_v ? vec_size / 2 : vec_size; static_assert(store_vec_size * sizeof(T) % chunk_bytes == 0, "store_vec_size * sizeof(T) must be a multiple of chunk_bytes"); static constexpr index_t num_chunks = store_vec_size * sizeof(T) / chunk_bytes; static constexpr index_t chunk_size_elements = vec_size / num_chunks; static constexpr index_t store_chunk_size_elements = store_vec_size / num_chunks; static constexpr index_t interleave_bytes = interleave_thread_size * chunk_bytes; const DTYPE_I* vec_ptr = reinterpret_cast(&vec); using chunk_type = opus::vector_t; using store_type = opus::vector_t; opus::static_for([&](auto i) { constexpr index_t chunk_offset_bytes = interleave ? 
i.value * interleave_bytes : i.value * chunk_bytes; constexpr index_t chunk_offset_elements = chunk_offset_bytes / sizeof(T); const chunk_type* chunk_ptr = reinterpret_cast(vec_ptr + i.value * chunk_size_elements); if constexpr(!std::is_same_v) { if constexpr(std::is_same_v || std::is_same_v) { opus::vector_t chunk_convert; for(int j = 0; j < chunk_size_elements; j++) { chunk_convert[j] = opus::cast((*chunk_ptr)[j]); } store_type& chunk_store = reinterpret_cast(chunk_convert); buffer.template store( chunk_store, row_offset, chunk_offset_elements); } else if constexpr(std::is_same_v) { auto chunk_convert = scaled_cast(*chunk_ptr, inverted_scale); store_type& chunk_store = reinterpret_cast(chunk_convert); buffer.template store( chunk_store, row_offset, chunk_offset_elements); } else { opus::vector_t chunk_convert; chunk_convert = scaled_cast(*chunk_ptr, inverted_scale); store_type& chunk_store = reinterpret_cast(chunk_convert); buffer.template store( chunk_store, row_offset, chunk_offset_elements); } // Workaround: compiler may not insert s_nop after the last buffer_store, causing a // WAR hazard where vdata VGPRs are overwritten before buffer_store finishes reading // them. asm volatile("s_nop 0"); } else { const store_type* chunk_store_ptr = reinterpret_cast(chunk_ptr); buffer.template store( *chunk_store_ptr, row_offset, chunk_offset_elements); } }); } // High-level store API: automatically selects the best chunk_bytes (16/8/4) for // store_vector_nbytes. Picks the largest chunk size that evenly divides the total store bytes. // // When interleave=true, num_repeat controls how many interleaved repeats per thread, // which affects the effective store size used to choose chunk_bytes. template __device__ void store_vector(opus::gmem& buffer, const opus::vector_t& vec, int row_offset, float inverted_scale = 1.0f) { static constexpr int32_t num_store_repeat = interleave ? num_repeat : 1; static constexpr int32_t store_vec_size = std::is_same_v ? 
vec_size / 2 : vec_size; if constexpr((store_vec_size * sizeof(T) / num_store_repeat) % 16 == 0) { store_vector_nbytes( buffer, vec, row_offset, inverted_scale); } else if constexpr((store_vec_size * sizeof(T) / num_store_repeat) % 8 == 0) { store_vector_nbytes( buffer, vec, row_offset, inverted_scale); } else if constexpr((store_vec_size * sizeof(T) / num_store_repeat) % 4 == 0) { store_vector_nbytes( buffer, vec, row_offset, inverted_scale); } else { static_assert(false, "vec_size * sizeof(T) must be a multiple of 16, 8, or 4"); } } // todo: edit this to use aiterTensor dtype template struct t2opus; template <> struct t2opus { using type = float; }; template <> struct t2opus { using type = opus::fp16_t; }; template <> struct t2opus { using type = opus::bf16_t; }; template <> struct t2opus { using type = int32_t; }; template <> struct t2opus { using type = opus::i8_t; }; // HIP native type -> opus type mapping template struct hip2opus; template <> struct hip2opus { using type = opus::fp32_t; }; template <> struct hip2opus<__half> { using type = opus::fp16_t; }; template <> struct hip2opus { using type = opus::bf16_t; }; template <> struct hip2opus { using type = opus::fp8_t; }; template <> struct hip2opus { using type = opus::i8_t; }; template <> struct hip2opus { using type = int32_t; }; } // namespace aiter