Unverified commit c86d777a authored by miaoneng, committed by GitHub

Use macro for `__shfl_*` functions for ROCm (#296)



* Use macro for __shfl_* functions

* Update test_matmul.py
Co-authored-by: jytang <striges@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
parent 1bf12762
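
For orientation: the ROCm/HIP toolchains targeted here do not expose the CUDA 9+ `__shfl_*_sync` warp intrinsics, so the shuffle calls are routed through macros that drop the participation mask on ROCm. A minimal sketch of that dispatch, simplified from the header changed below (the real header also provides `at::Half` overloads and an `__ldg` shim):

```cuda
// Minimal sketch of the dispatch; on ROCm the mask argument is discarded,
// on CUDA the macros forward directly to the *_sync intrinsics.
#ifdef USE_ROCM
#define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
#define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
#define SHFL_SYNC(mask, var, delta) __shfl(var, delta)
#else
#define SHFL_UP_SYNC __shfl_up_sync
#define SHFL_DOWN_SYNC __shfl_down_sync
#define SHFL_SYNC __shfl_sync
#endif
```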
@@ -63,9 +63,9 @@ __global__ void spmm_kernel(const int64_t *rowptr_data, const int64_t *col_data,
 #pragma unroll
   for (int i = 0; i < 32; i++) {
     // Communication between all threads in a warp.
-    mat_rows[i] = __shfl_sync(FULL_MASK, mat_row, i);
+    mat_rows[i] = SHFL_SYNC(FULL_MASK, mat_row, i);
     if (HAS_VALUE)
-      vals[i] = __shfl_sync(FULL_MASK, val, i);
+      vals[i] = SHFL_SYNC(FULL_MASK, val, i);
   }
 #pragma unroll
@@ -179,7 +179,7 @@ spmm_value_bw_kernel(const int64_t *row_data, const int64_t *rowptr_data,
 #pragma unroll
   for (int i = 32 / 2; i > 0; i /= 2) { // Parallel reduction inside a warp.
-    val += __shfl_down_sync(FULL_MASK, val, i);
+    val += SHFL_DOWN_SYNC(FULL_MASK, val, i);
   }
   if (lane_idx == 0) {
...
@@ -17,13 +17,15 @@ __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask,
   return __shfl_down_sync(mask, var.operator __half(), delta);
 }

 #ifdef USE_ROCM
 __device__ __inline__ at::Half __ldg(const at::Half* ptr) {
   return __ldg(reinterpret_cast<const __half*>(ptr));
 }
 #define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta)
 #define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta)
+#define SHFL_SYNC(mask, var, delta) __shfl(var, delta)
 #else
 #define SHFL_UP_SYNC __shfl_up_sync
 #define SHFL_DOWN_SYNC __shfl_down_sync
+#define SHFL_SYNC __shfl_sync
 #endif
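
A usage sketch of these macros, mirroring the warp reduction in `spmm_value_bw_kernel`: the kernel name and indexing are illustrative assumptions, not code from this commit, and it assumes the `SHFL_*` macros above are in scope, `FULL_MASK` is the usual all-ones mask, and `blockDim.x` is a multiple of 32.

```cuda
// Illustrative only: with the SHFL_* macros above visible, the same
// warp-level sum compiles under nvcc and under hipcc for ROCm.
#define FULL_MASK 0xffffffff

__global__ void warp_sum_kernel(const float *in, float *out) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int lane_idx = threadIdx.x & 31;  // index within a 32-thread group, as in the kernels above
  float val = in[idx];

#pragma unroll
  for (int i = 32 / 2; i > 0; i /= 2)          // parallel reduction inside a warp
    val += SHFL_DOWN_SYNC(FULL_MASK, val, i);  // mask argument is dropped on ROCm

  if (lane_idx == 0)      // lane 0 now holds the sum of its 32-thread group
    out[idx / 32] = val;
}
```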
@@ -43,14 +43,13 @@ def test_spmm(dtype, device, reduce):
     out = matmul(src, other, reduce)
     out.backward(grad_out)

+    atol = 1e-7
     if dtype == torch.float16 or dtype == torch.bfloat16:
-        assert torch.allclose(expected, out, atol=1e-1)
-        assert torch.allclose(expected_grad_value, value.grad, atol=1e-1)
-        assert torch.allclose(expected_grad_other, other.grad, atol=1e-1)
-    else:
-        assert torch.allclose(expected, out)
-        assert torch.allclose(expected_grad_value, value.grad)
-        assert torch.allclose(expected_grad_other, other.grad)
+        atol = 1e-1
+
+    assert torch.allclose(expected, out, atol=atol)
+    assert torch.allclose(expected_grad_value, value.grad, atol=atol)
+    assert torch.allclose(expected_grad_other, other.grad, atol=atol)

 @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
...