// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "data_type.hpp"

namespace ck {

template <typename T>
union BufferResource
{
    __device__ constexpr BufferResource() : content{} {}

    // 128 bit SGPRs to supply buffer resource in buffer instructions
    // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
    int32x4_t content;
    StaticallyIndexedArray<T*, 2> address;
    StaticallyIndexedArray<int32_t, 4> range;
    StaticallyIndexedArray<int32_t, 4> config;
};

template <typename T>
__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size)
{
    BufferResource<T> wave_buffer_resource;

    // wavewise base address (64 bit)
    wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
    // wavewise range (32 bit)
    wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T);
    // wavewise setting (32 bit)
    wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;

    return wave_buffer_resource.content;
}

template <typename T>
__device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave)
{
    BufferResource<T> wave_buffer_resource;

    // wavewise base address (64 bit)
    wave_buffer_resource.address(Number<0>{}) = const_cast<remove_cv_t<T>*>(p_wave);
    // wavewise range (32 bit)
    wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range
    // wavewise setting (32 bit)
    wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD;

    return wave_buffer_resource.content;
}
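// Usage sketch (illustrative only, not part of the API; `p_global` and `n` are
// assumed to be a wave-uniform global-memory pointer and its element count):
//
//   const int32x4_t rsrc = make_wave_buffer_resource(p_global, n);
//
// Dwords 0-1 of the resulting descriptor hold the 64-bit base address, dword 2
// holds the range in bytes (n * sizeof(T)), and dword 3 holds the
// CK_BUFFER_RESOURCE_3RD_DWORD configuration bits.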
// buffer load i8
__device__ int8_t llvm_amdgcn_raw_buffer_load_i8(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.i8");

__device__ int8x2_t llvm_amdgcn_raw_buffer_load_i8x2(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.v2i8");

__device__ int8x4_t llvm_amdgcn_raw_buffer_load_i8x4(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.v4i8");

// buffer load i16
__device__ bhalf_t llvm_amdgcn_raw_buffer_load_i16(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.i16");

__device__ bhalf2_t llvm_amdgcn_raw_buffer_load_i16x2(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.v2i16");

__device__ bhalf4_t llvm_amdgcn_raw_buffer_load_i16x4(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.v4i16");

// buffer load i32
__device__ int32_t llvm_amdgcn_raw_buffer_load_i32(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.i32");

__device__ int32x2_t llvm_amdgcn_raw_buffer_load_i32x2(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.v2i32");

__device__ int32x4_t llvm_amdgcn_raw_buffer_load_i32x4(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.v4i32");

// buffer load fp16
__device__ half_t llvm_amdgcn_raw_buffer_load_fp16(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.f16");

__device__ half2_t llvm_amdgcn_raw_buffer_load_fp16x2(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.v2f16");

__device__ half4_t llvm_amdgcn_raw_buffer_load_fp16x4(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.v4f16");

// buffer load fp32
__device__ float llvm_amdgcn_raw_buffer_load_fp32(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.f32");

__device__ float2_t llvm_amdgcn_raw_buffer_load_fp32x2(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.v2f32");

__device__ float4_t llvm_amdgcn_raw_buffer_load_fp32x4(
    int32x4_t srsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.load.v4f32");

// buffer store i8
__device__ void llvm_amdgcn_raw_buffer_store_i8(
    int8_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.i8");

__device__ void llvm_amdgcn_raw_buffer_store_i8x2(
    int8x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.v2i8");

__device__ void llvm_amdgcn_raw_buffer_store_i8x4(
    int8x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.v4i8");

// buffer store i16
__device__ void llvm_amdgcn_raw_buffer_store_i16(
    bhalf_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.i16");

__device__ void llvm_amdgcn_raw_buffer_store_i16x2(
    bhalf2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.v2i16");

__device__ void llvm_amdgcn_raw_buffer_store_i16x4(
    bhalf4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.v4i16");

// buffer store i32
__device__ void llvm_amdgcn_raw_buffer_store_i32(
    int32_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.i32");

__device__ void llvm_amdgcn_raw_buffer_store_i32x2(
    int32x2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.v2i32");

__device__ void llvm_amdgcn_raw_buffer_store_i32x4(
    int32x4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.v4i32");

// buffer store fp16
__device__ void llvm_amdgcn_raw_buffer_store_fp16(
    half_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.f16");

__device__ void llvm_amdgcn_raw_buffer_store_fp16x2(
    half2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.v2f16");

__device__ void llvm_amdgcn_raw_buffer_store_fp16x4(
    half4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.v4f16");

// buffer store fp32
__device__ void llvm_amdgcn_raw_buffer_store_fp32(
    float vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.f32");

__device__ void llvm_amdgcn_raw_buffer_store_fp32x2(
    float2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.v2f32");

__device__ void llvm_amdgcn_raw_buffer_store_fp32x4(
    float4_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.store.v4f32");
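// Usage sketch (illustrative only; `rsrc` as built above, `byte_off` a per-lane
// byte offset into the buffer): a single-dword read-modify-write through the raw
// intrinsics declared above. The final argument carries the coherence/aux bits
// (see AmdBufferCoherenceEnum below).
//
//   float v = llvm_amdgcn_raw_buffer_load_fp32(rsrc, byte_off, 0, 0);
//   llvm_amdgcn_raw_buffer_store_fp32(2.f * v, rsrc, byte_off, 0, 0);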
// buffer atomic-add fp16
__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2(
    half2_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16");

// buffer atomic-add i32
__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32(
    int32_t vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.atomic.add.i32");

// buffer atomic-add fp32
__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
    float vdata, int32x4_t rsrc, index_t voffset, index_t soffset, index_t glc_slc)
    __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");

// buffer atomic-max fp64
__device__ double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
                                       int32x4_t rsrc, // dst_wave_buffer_resource
                                       int voffset,    // dst_thread_addr_offset
                                       int soffset,    // dst_wave_addr_offset
                                       int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64");

// memory coherency bits for buffer store/load instructions;
// check the ISA manual for each GFX target, e.g. for MI200 (CDNA2):
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf,
// pages 67~68
enum struct AmdBufferCoherenceEnum
{
    DefaultCoherence = 0, // default value
    GLC              = 1,
    SLC              = 2,
    GLC_SLC          = 3,
    // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
    // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
    // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
    WAVE_NT0   = 0,
    WAVE_NT1   = 2,
    GROUP_NT0  = 1,
    GROUP_NT1  = 3,
    DEVICE_NT0 = 8,
    DEVICE_NT1 = 10,
    SYSTEM_NT0 = 9,
    SYSTEM_NT1 = 11,
};
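// Usage sketch (illustrative only): the enum value is forwarded unchanged into
// the intrinsic's aux operand via static_cast<index_t>(coherence) in the helper
// templates below, e.g. a float4 load that sets the glc bit:
//
//   auto v = amd_buffer_load_impl<float, 4, AmdBufferCoherenceEnum::GLC>(
//       rsrc, byte_off, 0);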
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ typename vector_type<int8_t, N>::type
amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource,
                         index_t src_thread_addr_offset,
                         index_t src_wave_addr_offset)
{
    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
                  "wrong! not implemented");

    if constexpr(N == 1)
    {
        return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset,
                                              static_cast<index_t>(coherence));
    }
    else if constexpr(N == 2)
    {
        int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
                                                      src_thread_addr_offset,
                                                      src_wave_addr_offset,
                                                      static_cast<index_t>(coherence));

        return bit_cast<int8x2_t>(tmp);
    }
    else if constexpr(N == 4)
    {
        int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource,
                                                      src_thread_addr_offset,
                                                      src_wave_addr_offset,
                                                      static_cast<index_t>(coherence));

        return bit_cast<int8x4_t>(tmp);
    }
    else if constexpr(N == 8)
    {
        int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource,
                                                          src_thread_addr_offset,
                                                          src_wave_addr_offset,
                                                          static_cast<index_t>(coherence));

        return bit_cast<int8x8_t>(tmp);
    }
    else if constexpr(N == 16)
    {
        int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                          src_thread_addr_offset,
                                                          src_wave_addr_offset,
                                                          static_cast<index_t>(coherence));

        return bit_cast<int8x16_t>(tmp);
    }
    else if constexpr(N == 32)
    {
        int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                           src_thread_addr_offset,
                                                           src_wave_addr_offset,
                                                           static_cast<index_t>(coherence));
        int32x4_t tmp1 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 4 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));

        vector_type<int32_t, 8> tmp;

        tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
        tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;

        return bit_cast<int8x32_t>(tmp);
    }
    else if constexpr(N == 64)
    {
        int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                                           src_thread_addr_offset,
                                                           src_wave_addr_offset,
                                                           static_cast<index_t>(coherence));
        int32x4_t tmp1 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 4 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));
        int32x4_t tmp2 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 8 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));
        int32x4_t tmp3 =
            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
                                              src_thread_addr_offset,
                                              src_wave_addr_offset + 12 * sizeof(int32_t),
                                              static_cast<index_t>(coherence));

        vector_type<int32_t, 16> tmp;

        tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
        tmp.AsType<int32x4_t>()(Number<1>{}) = tmp1;
        tmp.AsType<int32x4_t>()(Number<2>{}) = tmp2;
        tmp.AsType<int32x4_t>()(Number<3>{}) = tmp3;

        return bit_cast<int8x64_t>(tmp);
    }
}

template <typename T,
          index_t N,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ typename vector_type<T, N>::type
amd_buffer_load_impl(int32x4_t src_wave_buffer_resource,
                     index_t src_thread_addr_offset,
                     index_t src_wave_addr_offset)
{
    static_assert(
        (is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
            (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
        "wrong! not implemented");

    using r_t = typename vector_type<T, N>::type;

    auto raw_data = amd_buffer_load_impl_raw<sizeof(T) * N, coherence>(
        src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);

    return bit_cast<r_t>(raw_data);
}
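// Usage sketch (illustrative only; `rsrc` must cover at least 8 half_t elements
// at `byte_off`): loading 8 halfs maps to sizeof(half_t) * 8 = 16 raw bytes,
// i.e. a single dwordx4 buffer load that is then bit_cast back to the vector:
//
//   auto h8 = amd_buffer_load_impl<half_t, 8>(rsrc, byte_off, 0);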
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void
amd_buffer_store_impl_raw(const typename vector_type<int8_t, N>::type src_thread_data,
                          int32x4_t dst_wave_buffer_resource,
                          index_t dst_thread_addr_offset,
                          index_t dst_wave_addr_offset)
{
    static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64,
                  "wrong! not implemented");

    if constexpr(N == 1)
    {
        llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
                                        dst_wave_buffer_resource,
                                        dst_thread_addr_offset,
                                        dst_wave_addr_offset,
                                        static_cast<index_t>(coherence));
    }
    else if constexpr(N == 2)
    {
        llvm_amdgcn_raw_buffer_store_i16(bit_cast<bhalf_t>(src_thread_data),
                                         dst_wave_buffer_resource,
                                         dst_thread_addr_offset,
                                         dst_wave_addr_offset,
                                         static_cast<index_t>(coherence));
    }
    else if constexpr(N == 4)
    {
        llvm_amdgcn_raw_buffer_store_i32(bit_cast<int32_t>(src_thread_data),
                                         dst_wave_buffer_resource,
                                         dst_thread_addr_offset,
                                         dst_wave_addr_offset,
                                         static_cast<index_t>(coherence));
    }
    else if constexpr(N == 8)
    {
        llvm_amdgcn_raw_buffer_store_i32x2(bit_cast<int32x2_t>(src_thread_data),
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));
    }
    else if constexpr(N == 16)
    {
        llvm_amdgcn_raw_buffer_store_i32x4(bit_cast<int32x4_t>(src_thread_data),
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));
    }
    else if constexpr(N == 32)
    {
        vector_type<int32_t, 8> tmp{bit_cast<int32x8_t>(src_thread_data)};

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<0>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));
        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<1>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 4,
                                           static_cast<index_t>(coherence));
    }
    else if constexpr(N == 64)
    {
        vector_type<int32_t, 16> tmp{bit_cast<int32x16_t>(src_thread_data)};

        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<0>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           static_cast<index_t>(coherence));
        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<1>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 4,
                                           static_cast<index_t>(coherence));
        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<2>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 8,
                                           static_cast<index_t>(coherence));
        llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType<int32x4_t>()[Number<3>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + sizeof(int32_t) * 12,
                                           static_cast<index_t>(coherence));
    }
}

template <typename T,
          index_t N,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src_thread_data,
                                      int32x4_t dst_wave_buffer_resource,
                                      index_t dst_thread_addr_offset,
                                      index_t dst_wave_addr_offset)
{
    static_assert(
        (is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
            (is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
            (is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
        "wrong! not implemented");

    using r_t = typename vector_type<int8_t, sizeof(T) * N>::type;

    amd_buffer_store_impl_raw<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data),
                                                        dst_wave_buffer_resource,
                                                        dst_thread_addr_offset,
                                                        dst_wave_addr_offset);
}

template <typename T, index_t N>
__device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
                                           T* addr)
{
    static_assert((is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)) ||
                      (is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)),
                  "wrong! not implemented");

    if constexpr(is_same<T, half_t>::value)
    {
        vector_type<half_t, N> tmp{src_thread_data};
        static_for<0, N / 2, 1>{}([&](auto i) {
            __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
                                                      tmp.template AsType<half2_t>()[i]);
        });
    }
#if defined(__gfx942__)
    else if constexpr(is_same<T, bhalf_t>::value)
    {
        vector_type<bhalf_t, N> tmp{src_thread_data};
        static_for<0, N / 2, 1>{}([&](auto i) {
            __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
                                                       tmp.template AsType<bhalf2_t>()[i]);
        });
    }
#endif
}
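// Usage sketch (illustrative only; `p_dst` an assumed global half_t pointer,
// `h8` an 8-element half vector): the N halfs are committed as N / 2 packed
// v2f16 atomics on consecutive half2 slots:
//
//   amd_global_atomic_add_impl<half_t, 8>(h8, p_dst); // 4x global_atomic_fadd_v2f16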
not implemented"); if constexpr(is_same::value) { if constexpr(N == 1) { llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); } else if constexpr(N == 2) { vector_type tmp{src_thread_data}; llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + sizeof(float), 0); } else if constexpr(N == 4) { vector_type tmp{src_thread_data}; llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + sizeof(float), 0); llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<2>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + 2 * sizeof(float), 0); llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<3>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + 3 * sizeof(float), 0); } } else if constexpr(is_same::value) { if constexpr(N == 2) { llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); } else if constexpr(N == 4) { vector_type tmp{src_thread_data}; static_for<0, 2, 1>{}([&](auto i) { llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + i * sizeof(half2_t), 0); }); } else if constexpr(N == 8) { vector_type tmp{src_thread_data}; static_for<0, 4, 1>{}([&](auto i) { llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + i * sizeof(half2_t), 0); }); } } else if constexpr(is_same::value) { if constexpr(N == 1) { llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); } else if constexpr(N == 2) { vector_type tmp{src_thread_data}; llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + sizeof(int32_t), 0); } else if constexpr(N == 4) { vector_type tmp{src_thread_data}; llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + sizeof(int32_t), 0); llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<2>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + 2 * sizeof(int32_t), 0); llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<3>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + 3 * sizeof(int32_t), 0); } } } template __device__ void amd_buffer_atomic_max_impl(const typename vector_type::type src_thread_data, int32x4_t dst_wave_buffer_resource, index_t dst_thread_addr_offset, index_t dst_wave_addr_offset) { static_assert((is_same::value && (N == 1 || N == 2 || N == 4)), "wrong! 
not implemented"); if constexpr(is_same::value) { if constexpr(N == 1) { llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); } else if constexpr(N == 2) { vector_type tmp{src_thread_data}; llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + sizeof(double), 0); } else if constexpr(N == 4) { vector_type tmp{src_thread_data}; llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset, 0); llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + sizeof(double), 0); llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<2>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + 2 * sizeof(double), 0); llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<3>{}], dst_wave_buffer_resource, dst_thread_addr_offset, dst_wave_addr_offset + 3 * sizeof(double), 0); } } } // buffer_load requires: // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); using vector_t = typename vector_type_maker::type::type; using scalar_t = typename scalar_type::type; constexpr index_t vector_size = scalar_type::vector_size; #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; return amd_buffer_load_impl( src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); #else vector_t tmp{amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0)}; return src_thread_element_valid ? tmp : vector_t(0); #endif } // buffer_load requires: // 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template __device__ typename vector_type_maker::type::type amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, index_t src_thread_element_offset, bool src_thread_element_valid, index_t src_element_space_size, T customized_value) { const int32x4_t src_wave_buffer_resource = make_wave_buffer_resource(p_src_wave, src_element_space_size); index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); using vector_t = typename vector_type_maker::type::type; using scalar_t = typename scalar_type::type; constexpr index_t vector_size = scalar_type::vector_size; vector_t tmp{amd_buffer_load_impl( src_wave_buffer_resource, src_thread_addr_offset, 0)}; return src_thread_element_valid ? tmp : vector_t(customized_value); } // buffer_store requires: // 1) p_dst_wave must point to global memory // 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. 
// buffer_store requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T,
          index_t N,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ void
amd_buffer_store(const typename vector_type_maker<T, N>::type::type src_thread_data,
                 T* p_dst_wave,
                 const index_t dst_thread_element_offset,
                 const bool dst_thread_element_valid,
                 const index_t dst_element_space_size)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size);

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

    using vector_t = typename vector_type_maker<T, N>::type::type;
    using scalar_t = typename scalar_type<vector_t>::type;

    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

    amd_buffer_store_impl<scalar_t, vector_size, coherence>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if(dst_thread_element_valid)
    {
        amd_buffer_store_impl<scalar_t, vector_size, coherence>(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
}

// buffer_atomic_add requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void
amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
                      T* p_dst_wave,
                      const index_t dst_thread_element_offset,
                      const bool dst_thread_element_valid,
                      const index_t dst_element_space_size)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size);

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

    using vector_t = typename vector_type_maker<T, N>::type::type;
    using scalar_t = typename scalar_type<vector_t>::type;

    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

    if constexpr(is_same<T, bhalf_t>::value)
    {
        if(dst_thread_element_valid)
        {
            amd_global_atomic_add_impl<scalar_t, vector_size>(
                src_thread_data, p_dst_wave + dst_thread_element_offset);
        }
    }
    else
    {
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
        uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

        amd_buffer_atomic_add_impl<scalar_t, vector_size>(
            src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
        if(dst_thread_element_valid)
        {
            amd_buffer_atomic_add_impl<scalar_t, vector_size>(
                src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
        }
#endif
    }
}

// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
// It is the user's responsibility to make sure that is true.
template <typename T, index_t N>
__device__ void
amd_buffer_atomic_max(const typename vector_type_maker<T, N>::type::type src_thread_data,
                      T* p_dst_wave,
                      const index_t dst_thread_element_offset,
                      const bool dst_thread_element_valid,
                      const index_t dst_element_space_size)
{
    const int32x4_t dst_wave_buffer_resource =
        make_wave_buffer_resource(p_dst_wave, dst_element_space_size);

    index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);

    using vector_t = typename vector_type_maker<T, N>::type::type;
    using scalar_t = typename scalar_type<vector_t>::type;

    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

    amd_buffer_atomic_max_impl<scalar_t, vector_size>(
        src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
#else
    if(dst_thread_element_valid)
    {
        amd_buffer_atomic_max_impl<scalar_t, vector_size>(
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
}
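// Usage sketch (illustrative only; `p_out` an assumed wavewise global pointer
// with `n` elements, `v4` a float4_t, `valid` this lane's predicate): a
// predicated vector store; invalid lanes write nothing, either via the OOB
// offset trick or via the explicit branch, depending on the build flag above.
//
//   amd_buffer_store<float, 4>(v4, p_out, elem_off, valid, n);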
// Direct loads from global to LDS.
__device__ void
llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc,
                                __attribute__((address_space(3))) uint32_t* lds_ptr,
                                index_t size,
                                index_t voffset,
                                index_t soffset,
                                index_t offset,
                                index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds");

template <typename T, index_t NumElemsPerThread>
__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
                                              const index_t global_offset,
                                              T* lds_base_ptr,
                                              const index_t lds_offset,
                                              const bool is_valid,
                                              const index_t src_element_space_size)
{
    // Direct loads require that each thread reads and writes exactly a single DWORD.
    constexpr auto dword_bytes      = 4;
    constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
    static_assert(bytes_per_thread == dword_bytes);

    const uint32_t* global_ptr =
        reinterpret_cast<const uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
    const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
    const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;

#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
    T* lds_ptr = lds_base_ptr + lds_offset;
    auto const lds_ptr_sgpr =
        __builtin_amdgcn_readfirstlane((reinterpret_cast<uintptr_t>(lds_ptr)));
    asm volatile("s_mov_b32 m0, %0; \n\t"
                 "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
                 "v"(global_offset_bytes),
                 "s"(src_resource)
                 : "memory");
#else
    // LDS pointer must be attributed with the LDS address space.
    __attribute__((address_space(3))) uint32_t* lds_ptr =
        reinterpret_cast<__attribute__((address_space(3))) uint32_t*>(
            reinterpret_cast<uintptr_t>(lds_base_ptr + lds_offset));

    llvm_amdgcn_raw_buffer_load_lds(
        src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
#endif
}

} // namespace ck
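// Usage sketch (illustrative only; `p_tile` assumed to be in global memory,
// `p_lds` in LDS, each lane of the wave moving exactly one dword):
//
//   ck::amd_direct_load_global_to_lds<float, 1>(
//       p_tile, lane_off, p_lds, lds_off, is_valid, tile_size);
//
// Each lane issues one buffer_load_dword whose result is routed directly to
// LDS (through m0 in the inline-asm path), never passing through VGPRs.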