Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c9415c19
Unverified
Commit
c9415c19
authored
Mar 12, 2024
by
kliuae
Committed by
GitHub
Mar 11, 2024
Browse files
[ROCm] Fix warp and lane calculation in blockReduceSum (#3321)
parent
4c922709
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
2 deletions
+12
-2
csrc/reduction_utils.cuh
csrc/reduction_utils.cuh
+12
-2
No files found.
csrc/reduction_utils.cuh
View file @
c9415c19
...
...
@@ -29,12 +29,22 @@ __inline__ __device__ T warpReduceSum(T val) {
return
val
;
}
/*
 * Compute the value used to mask a thread index down to its lane index.
 *
 * Returns warp_size - 1. When warp_size is a power of two this value has
 * exactly the low log2(warp_size) bits set, so `threadIdx.x & mask` gives the
 * position of the thread inside its warp (32 -> 0x1f, 64 -> 0x3f).
 *
 * NOTE(review): assumes warp_size is a power of two; no check is performed.
 */
__inline__ __device__ constexpr int _calculateLaneMask(int warp_size) {
  const int lane_mask = warp_size - 1;
  return lane_mask;
}
/*
 * Compute the right-shift amount that turns a thread index into a warp index.
 *
 * Returns 5 + (warp_size >> 6), which evaluates to 5 for warp_size == 32 and
 * 6 for warp_size == 64, i.e. log2(warp_size) for those two sizes, so
 * `threadIdx.x >> shift` is the index of the warp the thread belongs to.
 *
 * NOTE(review): this is not a general log2. It also happens to be right for
 * 128 (-> 7) but not for larger sizes (256 -> 9) or for sizes below 32
 * (always 5). Presumably only 32 and 64 are ever passed; confirm at callers.
 */
__inline__ __device__ constexpr int _calculateWidShift(int warp_size) {
  // Zero for sizes below 64, one for 64..127, and so on.
  const int extra_shift = warp_size >> 6;
  return extra_shift + 5;
}
/* Calculate the sum of all elements in a block */
template
<
typename
T
>
__inline__
__device__
T
blockReduceSum
(
T
val
)
{
static
__shared__
T
shared
[
WARP_SIZE
];
int
lane
=
threadIdx
.
x
&
0x1f
;
int
wid
=
threadIdx
.
x
>>
5
;
constexpr
auto
LANE_MASK
=
_calculateLaneMask
(
WARP_SIZE
);
constexpr
auto
WID_SHIFT
=
_calculateWidShift
(
WARP_SIZE
);
int
lane
=
threadIdx
.
x
&
LANE_MASK
;
int
wid
=
threadIdx
.
x
>>
WID_SHIFT
;
val
=
warpReduceSum
<
T
>
(
val
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment