OpenDAS / vllm_cscc
Commit 1d754726
Authored Nov 16, 2024 by rasmith
Committed by GitHub, Nov 16, 2024
[BugFix] [Kernel] Fix GPU SEGV occuring in fused_moe kernel (#10385)
Signed-off-by: Randall Smith <Randall.Smith@amd.com>
Parent: 2f427c2d
Showing 1 changed file with 5 additions and 3 deletions

vllm/model_executor/layers/fused_moe/fused_moe.py (+5 -3)
@@ -105,16 +105,18 @@ def fused_moe_kernel(
     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(
+        tl.int64)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
     token_mask = offs_token < num_valid_tokens
 
-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+    offs_bn = (pid_n * BLOCK_SIZE_N +
+               tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
     a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
                       offs_k[None, :] * stride_ak)
 
-    off_experts = tl.load(expert_ids_ptr + pid_m)
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
     b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
                                                 offs_bn[None, :] * stride_bn)
     if use_int8_w8a16:
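
The commit message does not state the mechanism behind the SEGV. One plausible reading of the change is that offsets such as off_experts * stride_be were computed in 32-bit integers and could wrap around for large weight tensors, yielding an out-of-range pointer. The sketch below emulates signed 32-bit arithmetic in plain Python to illustrate that kind of wraparound. The helper names (wrap_int32, expert_offset) and the tensor sizes are hypothetical and are not taken from the commit or from vLLM.

```python
# Illustrative only: emulates 32-bit signed arithmetic in plain Python to show
# how an element offset can wrap around when it is not widened to 64 bits.
# All names and sizes here are hypothetical, not taken from the commit.

INT32_MAX = 2**31 - 1


def wrap_int32(value: int) -> int:
    """Reduce a Python int to the value a signed 32-bit register would hold."""
    value &= 0xFFFFFFFF
    return value - 2**32 if value > INT32_MAX else value


def expert_offset(off_experts: int, stride_be: int, use_int64: bool) -> int:
    """Element offset of one expert's weight block (off_experts * stride_be)."""
    product = off_experts * stride_be
    # With 64-bit operands the product below fits; with 32-bit it wraps.
    return product if use_int64 else wrap_int32(product)


# Hypothetical per-expert weight block of 28672 x 4096 elements.
stride_be = 28672 * 4096
off_experts = 20

offset_32 = expert_offset(off_experts, stride_be, use_int64=False)
offset_64 = expert_offset(off_experts, stride_be, use_int64=True)

print(offset_32)  # negative: the 32-bit product wrapped around
print(offset_64)  # the intended offset
```

Under these assumed sizes the 32-bit result is negative, so adding it to a base pointer would address memory before the tensor, whereas the widened value is the intended offset. Casting the index operands with .to(tl.int64), as the diff does, makes the later multiplications and additions promote to 64 bits.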