Unverified Commit 8977ffb5 authored by Sage Moore's avatar Sage Moore Committed by GitHub
Browse files

[ROCm][Bugfix] Fix compilation errors with fused_qknorm_rope_kernel.cu (#28682)


Signed-off-by: default avatarSage Moore <sage@neuralmagic.com>
parent fd455508
...@@ -37,6 +37,16 @@ ...@@ -37,6 +37,16 @@
#ifdef USE_ROCM #ifdef USE_ROCM
#define FINAL_MASK 0xffffffffffffffffULL #define FINAL_MASK 0xffffffffffffffffULL
#if defined(HIP_VERSION) && HIP_VERSION < 70000000
// On ROCm versions before 7.0, __syncwarp isn't defined. The below
// implementation is copy/pasted from the implementation in ROCm 7.0
__device__ inline void __syncwarp() {
__builtin_amdgcn_fence(__ATOMIC_RELEASE, "wavefront");
__builtin_amdgcn_wave_barrier();
__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "wavefront");
}
#endif
#else #else
#define FINAL_MASK 0xffffffff #define FINAL_MASK 0xffffffff
#endif #endif
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment