OpenDAS / apex

Commit c3ec9351 authored Oct 21, 2021 by Jeff Daily

apex definition of macro conflicts with pytorch macro WARP_SHFL_XOR

parent 88eee5fe
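Why the rename is needed: PyTorch's CUDA headers also provide a symbol named WARP_SHFL_XOR, so when apex's softmax.h is compiled in the same translation unit, the apex macro rewrites every later occurrence of that name, including PyTorch's own declaration. Below is a minimal sketch of the collision and of the fix applied in this commit, using a simplified stand-in for the PyTorch helper (its exact header and signature are assumptions, not taken from this commit).

// collision_sketch.cu -- illustrates the name clash this commit avoids.
// The helper below is a simplified stand-in for what a PyTorch header provides.
#include <cuda_runtime.h>

// Stand-in for a PyTorch-style device helper named WARP_SHFL_XOR.
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(unsigned int mask, T value,
                                            int lane_mask, int width = 32) {
  return __shfl_xor_sync(mask, value, lane_mask, width);
}

// Old apex definition: a macro with the same name as the helper above.
// Once the macro is seen, the preprocessor rewrites WARP_SHFL_XOR everywhere,
// including the declaration above, and the build breaks.
// #define WARP_SHFL_XOR __shfl_xor_sync

// New apex definition from this commit: a prefixed name that cannot collide.
#define APEX_WARP_SHFL_XOR __shfl_xor_sync

__global__ void demo(float *out) {
  float v = static_cast<float>(threadIdx.x);
  // Call sites in softmax.h now use the prefixed macro.
  v += APEX_WARP_SHFL_XOR(0xffffffffu, v, 16, 32);
  if (threadIdx.x == 0) { *out = v; }
}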
Showing 1 changed file with 18 additions and 18 deletions

apex/contrib/csrc/multihead_attn/softmax.h (+18 -18)
@@ -12,9 +12,9 @@
 #include <cmath>
 #ifdef __HIP_PLATFORM_HCC__
-#define WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
+#define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
 #else
-#define WARP_SHFL_XOR __shfl_xor_sync
+#define APEX_WARP_SHFL_XOR __shfl_xor_sync
 #endif
 namespace {
@@ -134,7 +134,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -159,7 +159,7 @@ __global__ void softmax_warp_forward(input_t *dst, const output_t *src, int batc
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -358,7 +358,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -382,7 +382,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward_vec4(output_t *dst,
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
   auto seeds = at::cuda::philox::unpack(philox_args);
@@ -512,7 +512,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -536,7 +536,7 @@ __global__ void additive_masked_softmax_dropout_warp_forward(output_t *dst, uint
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
   curandStatePhilox4_32_10_t state;
@@ -772,7 +772,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -797,7 +797,7 @@ __global__ void additive_masked_softmax_warp_forward(input_t *dst, const output_
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -1027,7 +1027,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1052,7 +1052,7 @@ __global__ void masked_softmax_warp_forward(input_t *dst, const output_t *src, c
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -1250,7 +1250,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1275,7 +1275,7 @@ __global__ void time_masked_softmax_warp_forward(input_t *dst, const output_t *s
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -1842,7 +1842,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
     float val[WARP_BATCH];
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      val[i] = WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
+      val[i] = APEX_WARP_SHFL_XOR(FULL_MASK, max_value[i], offset, WARP_SIZE);
     }
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
@@ -1867,7 +1867,7 @@ __global__ void masked_scale_softmax_warp_backward_recompute(output_t *gradInput
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -2312,7 +2312,7 @@ __global__ void softmax_warp_backward(__half *gradInput, const __half *grad, con
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
@@ -2523,7 +2523,7 @@ __global__ void masked_softmax_warp_backward(__half *gradInput, const __half *gr
   for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
 #pragma unroll
     for (int i = 0; i < WARP_BATCH; ++i) {
-      sum[i] += WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
+      sum[i] += APEX_WARP_SHFL_XOR(FULL_MASK, sum[i], offset, WARP_SIZE);
     }
   }
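The renamed call sites all follow the same warp-level butterfly reduction: at each step a lane exchanges its partial value with the lane whose index differs by offset, so after log2(WARP_SIZE) steps every lane holds the full max or sum. A self-contained sketch of that pattern with the renamed macro is below; the file name, kernel name, and launch shape are illustrative and not part of softmax.h.

// warp_sum_sketch.cu -- the shuffle-XOR reduction used throughout softmax.h,
// reduced to a single warp summing its lane indices.
#include <cstdio>
#include <cuda_runtime.h>

#define FULL_MASK 0xffffffffu
#define WARP_SIZE 32
#define APEX_WARP_SHFL_XOR __shfl_xor_sync

__global__ void warp_sum(float *out) {
  float sum = static_cast<float>(threadIdx.x);    // each lane contributes its id
  // Butterfly reduction: after log2(WARP_SIZE) steps every lane holds the total.
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
    sum += APEX_WARP_SHFL_XOR(FULL_MASK, sum, offset, WARP_SIZE);
  }
  if (threadIdx.x == 0) { *out = sum; }           // 0 + 1 + ... + 31 = 496
}

int main() {
  float *d_out = nullptr, h_out = 0.0f;
  cudaMalloc(&d_out, sizeof(float));
  warp_sum<<<1, WARP_SIZE>>>(d_out);
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f\n", h_out);               // expected: 496.0
  cudaFree(d_out);
  return 0;
}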