Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
64c87135
Unverified
Commit
64c87135
authored
Feb 10, 2025
by
Yineng Zhang
Committed by
GitHub
Feb 10, 2025
Browse files
remove activation dependency in fused_moe (#3433)
parent
1646149a
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
3 deletions
+18
-3
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+18
-3
No files found.
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
View file @
64c87135
...
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
...
@@ -18,7 +18,7 @@ from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from
sglang.srt.utils
import
direct_register_custom_op
,
get_device_name
,
is_hip
from
sglang.srt.utils
import
direct_register_custom_op
,
get_device_name
,
is_hip
is_hip_flag
=
is_hip
()
is_hip_flag
=
is_hip
()
from
sgl_kernel
import
moe_align_block_size
as
sgl_moe_align_block_size
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
padding_size
=
128
if
bool
(
int
(
os
.
getenv
(
"MOE_PADDING"
,
"0"
)))
else
0
padding_size
=
128
if
bool
(
int
(
os
.
getenv
(
"MOE_PADDING"
,
"0"
)))
else
0
...
@@ -27,6 +27,15 @@ enable_moe_align_block_size_triton = bool(
...
@@ -27,6 +27,15 @@ enable_moe_align_block_size_triton = bool(
int
(
os
.
getenv
(
"ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"
,
"0"
))
int
(
os
.
getenv
(
"ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"
,
"0"
))
)
)
_is_cuda
=
torch
.
cuda
.
is_available
()
and
torch
.
version
.
cuda
_is_rocm
=
torch
.
cuda
.
is_available
()
and
torch
.
version
.
hip
if
_is_cuda
:
from
sgl_kernel
import
gelu_and_mul
,
silu_and_mul
if
_is_cuda
or
_is_rocm
:
from
sgl_kernel
import
moe_align_block_size
as
sgl_moe_align_block_size
@
triton
.
jit
@
triton
.
jit
def
fused_moe_kernel
(
def
fused_moe_kernel
(
...
@@ -989,8 +998,14 @@ def fused_experts_impl(
...
@@ -989,8 +998,14 @@ def fused_experts_impl(
)
)
if
activation
==
"silu"
:
if
activation
==
"silu"
:
if
_is_cuda
:
silu_and_mul
(
intermediate_cache1
.
view
(
-
1
,
N
),
intermediate_cache2
)
else
:
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
elif
activation
==
"gelu"
:
elif
activation
==
"gelu"
:
if
_is_cuda
:
gelu_and_mul
(
intermediate_cache1
.
view
(
-
1
,
N
),
intermediate_cache2
)
else
:
ops
.
gelu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
ops
.
gelu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
else
:
else
:
raise
ValueError
(
f
"Unsupported activation:
{
activation
=
}
"
)
raise
ValueError
(
f
"Unsupported activation:
{
activation
=
}
"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment