Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e4a28e53
Unverified
Commit
e4a28e53
authored
Mar 10, 2024
by
Douglas Lehr
Committed by
GitHub
Mar 10, 2024
Browse files
[ROCM] Fix blockReduceSum to use correct warp counts for ROCm and CUDA (#3262)
parent
0bba88df
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
13 additions
and
11 deletions
+13
-11
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+0
-8
csrc/cuda_compat.h
csrc/cuda_compat.h
+10
-0
csrc/reduction_utils.cuh
csrc/reduction_utils.cuh
+3
-3
No files found.
csrc/attention/attention_kernels.cu
View file @
e4a28e53
...
...
@@ -15,9 +15,6 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
...
...
@@ -31,11 +28,6 @@
#include <algorithm>
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
...
...
csrc/cuda_compat.h
View file @
e4a28e53
#pragma once
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else
...
...
csrc/reduction_utils.cuh
View file @
e4a28e53
...
...
@@ -24,7 +24,7 @@ namespace vllm {
template
<
typename
T
>
__inline__
__device__
T
warpReduceSum
(
T
val
)
{
#pragma unroll
for
(
int
mask
=
16
;
mask
>
0
;
mask
>>=
1
)
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>
0
;
mask
>>=
1
)
val
+=
VLLM_SHFL_XOR_SYNC
(
val
,
mask
);
return
val
;
}
...
...
@@ -32,7 +32,7 @@ __inline__ __device__ T warpReduceSum(T val) {
/* Calculate the sum of all elements in a block */
template
<
typename
T
>
__inline__
__device__
T
blockReduceSum
(
T
val
)
{
static
__shared__
T
shared
[
32
];
static
__shared__
T
shared
[
WARP_SIZE
];
int
lane
=
threadIdx
.
x
&
0x1f
;
int
wid
=
threadIdx
.
x
>>
5
;
...
...
@@ -45,7 +45,7 @@ __inline__ __device__ T blockReduceSum(T val) {
// Modified from blockDim.x << 5 to blockDim.x / 32. so that the check
// still works when blockDim.x is not divisible by 32
val
=
(
threadIdx
.
x
<
(
blockDim
.
x
/
32.
f
))
?
shared
[
lane
]
:
(
T
)(
0.0
f
);
val
=
(
threadIdx
.
x
<
(
blockDim
.
x
/
(
WARP_SIZE
*
1.0
f
)
))
?
shared
[
lane
]
:
(
T
)(
0.0
f
);
val
=
warpReduceSum
<
T
>
(
val
);
return
val
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment