Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
869579a7
Unverified
Commit
869579a7
authored
Jan 08, 2025
by
youkaichao
Committed by
GitHub
Jan 07, 2025
Browse files
[optimization] remove python function call for custom op (#11750)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
c0efe92d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
15 additions
and
13 deletions
+15
-13
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-4
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+11
-6
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+2
-2
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+2
-1
No files found.
vllm/_custom_ops.py
View file @
869579a7
...
...
@@ -35,10 +35,6 @@ else:
# activation ops
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
silu_and_mul
(
out
,
x
)
def
gelu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
gelu_and_mul
(
out
,
x
)
...
...
vllm/model_executor/layers/activation.py
View file @
869579a7
...
...
@@ -10,6 +10,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
LazyDict
...
...
@@ -58,27 +59,31 @@ class SiluAndMul(CustomOp):
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def
__init__
(
self
):
super
().
__init__
()
if
current_platform
.
is_cuda_alike
():
self
.
op
=
torch
.
ops
.
_C
.
silu_and_mul
elif
current_platform
.
is_xpu
():
import
intel_extension_for_pytorch
as
ipex
self
.
op
=
ipex
.
llm
.
functional
.
silu_and_mul
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
d
=
x
.
shape
[
-
1
]
//
2
return
F
.
silu
(
x
[...,
:
d
])
*
x
[...,
d
:]
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm
import
_custom_ops
as
ops
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
ops
.
silu_and_mul
(
out
,
x
)
self
.
op
(
out
,
x
)
return
out
def
forward_xpu
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
from
vllm._ipex_ops
import
ipex_ops
as
ops
d
=
x
.
shape
[
-
1
]
//
2
output_shape
=
(
x
.
shape
[:
-
1
]
+
(
d
,
))
out
=
torch
.
empty
(
output_shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
ops
.
silu_and_mul
(
out
,
x
)
self
.
op
(
out
,
x
)
return
out
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
869579a7
...
...
@@ -4,7 +4,6 @@ from typing import Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
from
vllm.scalar_type
import
scalar_types
...
...
@@ -301,7 +300,8 @@ def fused_marlin_moe(
False
,
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
2
*
N
))
intermediate_cache3
=
torch
.
ops
.
_moe_C
.
marlin_gemm_moe
(
intermediate_cache2
,
...
...
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
869579a7
...
...
@@ -753,7 +753,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int8_w8a16
=
use_int8_w8a16
,
block_shape
=
block_shape
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
torch
.
ops
.
_C
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment