#pragma once

// NOTE(review): header names restored from the types used below
// (__nv_bfloat162 -> cuda_bf16.h, __nv_fp8x4_e4m3 -> cuda_fp8.h) - confirm.
#include <cuda_bf16.h>
#include <cuda_fp8.h>

// Eight FP8 (E4M3) values, stored as two packed groups of four.
struct fp8x8 {
    __nv_fp8x4_e4m3 lo;
    __nv_fp8x4_e4m3 hi;
};

// Sixteen FP8 (E4M3) values, stored as two fp8x8 halves.
struct fp8x16 {
    fp8x8 lo;
    fp8x8 hi;
};

// Eight bfloat16 values, stored as four packed pairs.
struct bf16x8 {
    __nv_bfloat162 a, b, c, d;
};

// Dequantize eight FP8 (E4M3) values to bfloat16 and multiply each by `scale`.
//
// Each packed group of four FP8 values is first widened to float4, then
// converted pairwise to bfloat16 (round-to-nearest) and multiplied by the
// scale. Note that `scale` itself is rounded to bfloat16 before the multiply,
// so the product is computed in bfloat16, not in fp32.
//
// Output layout:
//   result.a = inputs.lo[0..1], result.b = inputs.lo[2..3],
//   result.c = inputs.hi[0..1], result.d = inputs.hi[2..3].
__device__ __forceinline__
bf16x8 cvt_fp8x8_bf16x8(const fp8x8 &inputs, const float &scale) {
    // Broadcast the scale into both lanes of a bfloat162.
    __nv_bfloat162 scale_bf162 = __float2bfloat162_rn(scale);

    // Converts one packed fp8x4 into two scaled bfloat162 values.
    #define DEQUANT_FP8x4(OUTPUT_BF16_LO, OUTPUT_BF16_HI, FP8x4) \
    { \
        float4 fp32x4 = (float4)(FP8x4); \
        OUTPUT_BF16_LO = __float22bfloat162_rn({fp32x4.x, fp32x4.y})*scale_bf162; \
        OUTPUT_BF16_HI = __float22bfloat162_rn({fp32x4.z, fp32x4.w})*scale_bf162; \
    }

    bf16x8 result;
    DEQUANT_FP8x4(result.a, result.b, inputs.lo);
    DEQUANT_FP8x4(result.c, result.d, inputs.hi);

    // Keep the helper macro local to this function so it does not leak into
    // every translation unit that includes this header.
    #undef DEQUANT_FP8x4

    return result;
}

// L1 cache policy selected for load_128b_from_gmem; each enumerator maps to
// the corresponding `.L1::<policy>` qualifier of the PTX `ld` instruction.
enum class L1CacheHint {
    NO_ALLOCATE,
    EVICT_FIRST,
    EVICT_NORMAL,
    EVICT_LAST
};

// L2 prefetch granularity selected for load_128b_from_gmem; each enumerator
// maps to the corresponding `.L2::<size>` qualifier of the PTX `ld` instruction.
enum class L2PrefetchHint {
    B64,
    B128,
    B256
};

// Load 128 bits from global memory with explicit L1 / L2 cache hints and
// return them reinterpreted as `T`.
//
// Template parameters:
//   T                 - result type; must be exactly 16 bytes.
//   l1_cache_hint     - L1 cache policy qualifier to emit.
//   l2_prefetch_hint  - L2 prefetch size qualifier to emit.
//
// Parameters:
//   addr - address in global memory to read from. The load is a single
//          `.v4.s32` vector access, which per the PTX ISA needs `addr` to be
//          16-byte aligned. The `.nc` (non-coherent) path is used, which is
//          meant for data that is not written while the kernel runs -
//          callers are expected to guarantee both.
//
// The hint combination is resolved at compile time with `if constexpr`, so
// exactly one `ld.global.nc` instruction is emitted per instantiation.
template<
    typename T,
    L1CacheHint l1_cache_hint,
    L2PrefetchHint l2_prefetch_hint
>
__device__ __forceinline__
T load_128b_from_gmem(const void* addr) {
    static_assert(sizeof(T) == 128/8, "load_128b_from_gmem: T must be exactly 128 bits (16 bytes) wide");
    int4 ret;

    // Emits the load for one concrete (L1 hint, L2 hint) pair. The hint
    // strings are spliced into the instruction text via literal concatenation.
    #define EXEC(L1_HINT_STR, L2_HINT_STR) { \
        asm volatile("ld.global.nc.L1::" L1_HINT_STR ".L2::" L2_HINT_STR ".v4.s32 {%0, %1, %2, %3}, [%4];" \
            : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) \
            : "l"(addr)); \
    }

    // Selects the L2 prefetch size for an already chosen L1 hint.
    #define DISPATCH_L2(L1_HINT_STR) { \
        if constexpr(l2_prefetch_hint == L2PrefetchHint::B64) \
            EXEC(L1_HINT_STR, "64B") \
        else if constexpr(l2_prefetch_hint == L2PrefetchHint::B128) \
            EXEC(L1_HINT_STR, "128B") \
        else if constexpr(l2_prefetch_hint == L2PrefetchHint::B256) \
            EXEC(L1_HINT_STR, "256B") \
    }

    // Every enumerator of both enums is handled below, so `ret` is always
    // written for any valid template argument.
    if constexpr(l1_cache_hint == L1CacheHint::NO_ALLOCATE)
        DISPATCH_L2("no_allocate")
    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_FIRST)
        DISPATCH_L2("evict_first")
    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_NORMAL)
        DISPATCH_L2("evict_normal")
    else if constexpr(l1_cache_hint == L1CacheHint::EVICT_LAST)
        DISPATCH_L2("evict_last")

    #undef EXEC
    #undef DISPATCH_L2

    // Hand the raw 16 bytes back as the caller's type.
    return *reinterpret_cast<T*>(&ret);
}