Commit c3ec9351 authored by Jeff Daily's avatar Jeff Daily
Browse files

apex definition of macro conflicts with pytorch macro WARP_SHFL_XOR

parent 88eee5fe
......@@ -12,9 +12,9 @@
#include <cmath>
#ifdef __HIP_PLATFORM_HCC__
#define WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
#define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
#else
#define WARP_SHFL_XOR __shfl_xor_sync
#define APEX_WARP_SHFL_XOR __shfl_xor_sync
#endif
namespace {
......@@ -134,7 +134,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -159,7 +159,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -358,7 +358,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -382,7 +382,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
auto seeds = at::cuda::philox::unpack(philox_args);
......@@ -512,7 +512,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -536,7 +536,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
curandStatePhilox4_32_10_t state;
......@@ -772,7 +772,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -797,7 +797,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -1027,7 +1027,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -1052,7 +1052,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -1250,7 +1250,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -1275,7 +1275,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -1842,7 +1842,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
float val[WARP_BATCH];
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
}
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
......@@ -1867,7 +1867,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -2312,7 +2312,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, con
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......@@ -2523,7 +2523,7 @@ __global__ void masked_softmax_warp_backward(__half *gradInput, const __half *gr
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
#pragma unroll
for (int i = 0;i < WARP_BATCH;++i) {
sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment