Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96e0c9cb
Unverified
Commit
96e0c9cb
authored
Oct 31, 2024
by
youkaichao
Committed by
GitHub
Oct 31, 2024
Browse files
[torch.compile] directly register custom op (#9896)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
031a7995
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
192 additions
and
67 deletions
+192
-67
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+16
-4
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+16
-4
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+11
-5
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+11
-6
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+23
-11
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
+19
-6
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+41
-27
vllm/utils.py
vllm/utils.py
+45
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+10
-4
No files found.
tests/compile/piecewise/test_simple.py
View file @
96e0c9cb
...
@@ -6,18 +6,22 @@ import os
...
@@ -6,18 +6,22 @@ import os
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.library
import
Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.utils
import
direct_register_custom_op
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
global_counter
=
0
global_counter
=
0
# create a library to hold the custom op
silly_lib
=
Library
(
"silly"
,
"FRAGMENT"
)
# noqa
@
torch
.
library
.
custom_op
(
"silly::attention"
,
mutates_args
=
[
"out"
])
def
silly_attention
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
def
silly_attention
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
out
:
torch
.
Tensor
)
->
None
:
global
global_counter
global
global_counter
...
@@ -27,12 +31,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
...
@@ -27,12 +31,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out
[
0
]
+=
1
out
[
0
]
+=
1
@
silly_attention
.
register_fake
def
silly_attention_fake
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
def
_
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
out
:
torch
.
Tensor
)
->
None
:
return
return
direct_register_custom_op
(
op_name
=
"attention"
,
op_func
=
silly_attention
,
mutates_args
=
[
"out"
],
fake_impl
=
silly_attention_fake
,
target_lib
=
silly_lib
,
)
@
support_torch_compile
@
support_torch_compile
class
SillyModel
(
nn
.
Module
):
class
SillyModel
(
nn
.
Module
):
...
...
tests/compile/piecewise/test_toy_llama.py
View file @
96e0c9cb
...
@@ -8,6 +8,7 @@ from typing import Optional, Tuple
...
@@ -8,6 +8,7 @@ from typing import Optional, Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch.library
import
Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.config
import
CompilationConfig
from
vllm.compilation.config
import
CompilationConfig
...
@@ -15,9 +16,12 @@ from vllm.compilation.counter import compilation_counter
...
@@ -15,9 +16,12 @@ from vllm.compilation.counter import compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.plugins
import
set_compilation_config
from
vllm.plugins
import
set_compilation_config
from
vllm.utils
import
direct_register_custom_op
# create a library to hold the custom op
silly_lib
=
Library
(
"silly"
,
"FRAGMENT"
)
# noqa
@
torch
.
library
.
custom_op
(
"silly::attention"
,
mutates_args
=
[
"out"
])
def
silly_attention
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
def
silly_attention
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
out
:
torch
.
Tensor
)
->
None
:
out
.
copy_
(
q
)
out
.
copy_
(
q
)
...
@@ -25,12 +29,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
...
@@ -25,12 +29,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
out
+=
v
out
+=
v
@
silly_attention
.
register_fake
def
silly_attention_fake
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
def
_
(
q
:
torch
.
Tensor
,
k
:
torch
.
Tensor
,
v
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
out
:
torch
.
Tensor
)
->
None
:
return
return
direct_register_custom_op
(
op_name
=
"attention"
,
op_func
=
silly_attention
,
mutates_args
=
[
"out"
],
fake_impl
=
silly_attention_fake
,
target_lib
=
silly_lib
,
)
@
dataclass
@
dataclass
class
LlamaConfig
:
class
LlamaConfig
:
hidden_size
:
int
=
128
hidden_size
:
int
=
128
...
...
vllm/attention/backends/flash_attn.py
View file @
96e0c9cb
...
@@ -14,7 +14,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
...
@@ -14,7 +14,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, CommonAttentionState,
compute_slot_mapping_start_idx
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
from
vllm.utils
import
async_tensor_h2d
,
make_tensor_with_pad
from
vllm.utils
import
(
async_tensor_h2d
,
direct_register_custom_op
,
make_tensor_with_pad
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
@@ -595,8 +596,6 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -595,8 +596,6 @@ class FlashAttentionImpl(AttentionImpl):
return
output
return
output
@
torch
.
library
.
custom_op
(
"vllm::unified_flash_attention"
,
mutates_args
=
[
"kv_cache"
])
def
unified_flash_attention
(
def
unified_flash_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
@@ -755,8 +754,7 @@ def unified_flash_attention(
...
@@ -755,8 +754,7 @@ def unified_flash_attention(
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
@
unified_flash_attention
.
register_fake
def
unified_flash_attention_fake
(
def
_
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
@@ -773,3 +771,11 @@ def _(
...
@@ -773,3 +771,11 @@ def _(
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
)
return
torch
.
empty_like
(
query
)
direct_register_custom_op
(
op_name
=
"unified_flash_attention"
,
op_func
=
unified_flash_attention
,
mutates_args
=
[
"kv_cache"
],
fake_impl
=
unified_flash_attention_fake
,
)
vllm/attention/backends/flashinfer.py
View file @
96e0c9cb
...
@@ -28,8 +28,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
...
@@ -28,8 +28,8 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
is_block_tables_empty
)
is_block_tables_empty
)
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
from
vllm.utils
import
(
async_tensor_h2d
,
get_kv_cache_torch_dtype
,
from
vllm.utils
import
(
async_tensor_h2d
,
direct_register_custom_op
,
make_tensor_with_pad
)
get_kv_cache_torch_dtype
,
make_tensor_with_pad
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
@@ -785,8 +785,6 @@ class FlashInferImpl(AttentionImpl):
...
@@ -785,8 +785,6 @@ class FlashInferImpl(AttentionImpl):
)
)
@
torch
.
library
.
custom_op
(
"vllm::unified_flash_infer"
,
mutates_args
=
[
"kv_cache"
])
def
unified_flash_infer
(
def
unified_flash_infer
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
@@ -906,8 +904,7 @@ def unified_flash_infer(
...
@@ -906,8 +904,7 @@ def unified_flash_infer(
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
@
unified_flash_infer
.
register_fake
def
unified_flash_infer_fake
(
def
_
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
@@ -924,3 +921,11 @@ def _(
...
@@ -924,3 +921,11 @@ def _(
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
).
contiguous
()
return
torch
.
empty_like
(
query
).
contiguous
()
direct_register_custom_op
(
op_name
=
"unified_flash_infer"
,
op_func
=
unified_flash_infer
,
mutates_args
=
[
"kv_cache"
],
fake_impl
=
unified_flash_infer_fake
,
)
vllm/distributed/parallel_state.py
View file @
96e0c9cb
...
@@ -37,7 +37,7 @@ from torch.distributed import Backend, ProcessGroup
...
@@ -37,7 +37,7 @@ from torch.distributed import Backend, ProcessGroup
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
supports_custom_op
from
vllm.utils
import
direct_register_custom_op
,
supports_custom_op
@
dataclass
@
dataclass
...
@@ -99,8 +99,6 @@ def _register_group(group: "GroupCoordinator") -> None:
...
@@ -99,8 +99,6 @@ def _register_group(group: "GroupCoordinator") -> None:
if
supports_custom_op
():
if
supports_custom_op
():
@
torch
.
library
.
custom_op
(
"vllm::inplace_all_reduce"
,
mutates_args
=
[
"tensor"
])
def
inplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
def
inplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
group
=
_groups
[
group_name
]()
...
@@ -108,11 +106,16 @@ if supports_custom_op():
...
@@ -108,11 +106,16 @@ if supports_custom_op():
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
group
.
_all_reduce_in_place
(
tensor
)
group
.
_all_reduce_in_place
(
tensor
)
@
inplace_all_reduce
.
register_fake
def
inplace_all_reduce_fake
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
return
return
@
torch
.
library
.
custom_op
(
"vllm::outplace_all_reduce"
,
mutates_args
=
[])
direct_register_custom_op
(
op_name
=
"inplace_all_reduce"
,
op_func
=
inplace_all_reduce
,
mutates_args
=
[
"tensor"
],
fake_impl
=
inplace_all_reduce_fake
,
)
def
outplace_all_reduce
(
tensor
:
torch
.
Tensor
,
def
outplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
group_name
:
str
)
->
torch
.
Tensor
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
...
@@ -121,10 +124,17 @@ if supports_custom_op():
...
@@ -121,10 +124,17 @@ if supports_custom_op():
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
_all_reduce_out_place
(
tensor
)
return
group
.
_all_reduce_out_place
(
tensor
)
@
outplace_all_reduce
.
register_fake
def
outplace_all_reduce
_fake
(
tensor
:
torch
.
Tensor
,
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
group_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
tensor
)
return
torch
.
empty_like
(
tensor
)
direct_register_custom_op
(
op_name
=
"outplace_all_reduce"
,
op_func
=
outplace_all_reduce
,
mutates_args
=
[],
fake_impl
=
outplace_all_reduce_fake
,
)
class
GroupCoordinator
:
class
GroupCoordinator
:
"""
"""
...
@@ -338,6 +348,11 @@ class GroupCoordinator:
...
@@ -338,6 +348,11 @@ class GroupCoordinator:
if
self
.
world_size
==
1
:
if
self
.
world_size
==
1
:
return
input_
return
input_
if
input_
.
is_cpu
:
import
intel_extension_for_pytorch
as
ipex
ipex
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
return
input_
if
not
supports_custom_op
():
if
not
supports_custom_op
():
self
.
_all_reduce_in_place
(
input_
)
self
.
_all_reduce_in_place
(
input_
)
return
input_
return
input_
...
@@ -369,9 +384,6 @@ class GroupCoordinator:
...
@@ -369,9 +384,6 @@ class GroupCoordinator:
pynccl_comm
=
self
.
pynccl_comm
pynccl_comm
=
self
.
pynccl_comm
if
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
):
if
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
):
pynccl_comm
.
all_reduce
(
input_
)
pynccl_comm
.
all_reduce
(
input_
)
elif
input_
.
is_cpu
:
import
intel_extension_for_pytorch
as
ipex
ipex
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
else
:
else
:
torch
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
torch
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
...
...
vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
View file @
96e0c9cb
...
@@ -8,6 +8,7 @@ from vllm import _custom_ops as ops
...
@@ -8,6 +8,7 @@ from vllm import _custom_ops as ops
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
fused_topk
,
moe_align_block_size
,
try_get_optimal_moe_config
)
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
direct_register_custom_op
def
get_scalar_type
(
num_bits
:
int
,
has_zp
:
bool
):
def
get_scalar_type
(
num_bits
:
int
,
has_zp
:
bool
):
...
@@ -18,7 +19,6 @@ def get_scalar_type(num_bits: int, has_zp: bool):
...
@@ -18,7 +19,6 @@ def get_scalar_type(num_bits: int, has_zp: bool):
return
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
return
scalar_types
.
uint4b8
if
num_bits
==
4
else
scalar_types
.
uint8b128
@
torch
.
library
.
custom_op
(
"vllm::single_marlin_moe"
,
mutates_args
=
[])
def
single_marlin_moe
(
def
single_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
...
@@ -119,8 +119,7 @@ def single_marlin_moe(
...
@@ -119,8 +119,7 @@ def single_marlin_moe(
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
return
torch
.
sum
(
intermediate_cache
.
view
(
*
intermediate_cache
.
shape
),
dim
=
1
)
@
single_marlin_moe
.
register_fake
def
single_marlin_moe_fake
(
def
_
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
w
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
...
@@ -136,7 +135,14 @@ def _(
...
@@ -136,7 +135,14 @@ def _(
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
@
torch
.
library
.
custom_op
(
"vllm::fused_marlin_moe"
,
mutates_args
=
[])
direct_register_custom_op
(
op_name
=
"single_marlin_moe"
,
op_func
=
single_marlin_moe
,
mutates_args
=
[],
fake_impl
=
single_marlin_moe_fake
,
)
def
fused_marlin_moe
(
def
fused_marlin_moe
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
@@ -324,8 +330,7 @@ def fused_marlin_moe(
...
@@ -324,8 +330,7 @@ def fused_marlin_moe(
dim
=
1
)
dim
=
1
)
@
fused_marlin_moe
.
register_fake
def
fused_marlin_moe_fake
(
def
_
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
@@ -344,3 +349,11 @@ def _(
...
@@ -344,3 +349,11 @@ def _(
is_k_full
:
bool
=
True
,
is_k_full
:
bool
=
True
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
direct_register_custom_op
(
op_name
=
"fused_marlin_moe"
,
op_func
=
fused_marlin_moe
,
mutates_args
=
[],
fake_impl
=
fused_marlin_moe_fake
,
)
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
96e0c9cb
...
@@ -12,6 +12,7 @@ import vllm.envs as envs
...
@@ -12,6 +12,7 @@ import vllm.envs as envs
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
direct_register_custom_op
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -466,8 +467,6 @@ def get_config_dtype_str(dtype: torch.dtype,
...
@@ -466,8 +467,6 @@ def get_config_dtype_str(dtype: torch.dtype,
return
None
return
None
@
torch
.
library
.
custom_op
(
"vllm::inplace_fused_experts"
,
mutates_args
=
[
"hidden_states"
])
def
inplace_fused_experts
(
hidden_states
:
torch
.
Tensor
,
def
inplace_fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
@@ -484,22 +483,29 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
...
@@ -484,22 +483,29 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a1_scale
,
a2_scale
)
a1_scale
,
a2_scale
)
@
inplace_fused_experts
.
register
_fake
def
inplace_fused_experts_fake
(
def
_
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
None
:
pass
pass
@
torch
.
library
.
custom_op
(
"vllm::outplace_fused_experts"
,
mutates_args
=
[])
direct_register_custom_op
(
op_name
=
"inplace_fused_experts"
,
op_func
=
inplace_fused_experts
,
mutates_args
=
[
"hidden_states"
],
fake_impl
=
inplace_fused_experts_fake
,
)
def
outplace_fused_experts
(
def
outplace_fused_experts
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
@@ -517,21 +523,29 @@ def outplace_fused_experts(
...
@@ -517,21 +523,29 @@ def outplace_fused_experts(
w2_scale
,
a1_scale
,
a2_scale
)
w2_scale
,
a1_scale
,
a2_scale
)
@
outplace_fused_experts
.
register
_fake
def
outplace_fused_experts_fake
(
def
_
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_fp8_w8a8
:
bool
=
False
,
use_fp8_w8a8
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
use_int8_w8a16
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
hidden_states
)
return
torch
.
empty_like
(
hidden_states
)
direct_register_custom_op
(
op_name
=
"outplace_fused_experts"
,
op_func
=
outplace_fused_experts
,
mutates_args
=
[],
fake_impl
=
outplace_fused_experts_fake
,
)
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
def
fused_experts
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
...
...
vllm/utils.py
View file @
96e0c9cb
...
@@ -32,6 +32,7 @@ import torch
...
@@ -32,6 +32,7 @@ import torch
import
torch.types
import
torch.types
import
yaml
import
yaml
from
packaging.version
import
Version
from
packaging.version
import
Version
from
torch.library
import
Library
from
typing_extensions
import
ParamSpec
,
TypeIs
,
assert_never
from
typing_extensions
import
ParamSpec
,
TypeIs
,
assert_never
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -1512,3 +1513,47 @@ def weak_ref_tensors(
...
@@ -1512,3 +1513,47 @@ def weak_ref_tensors(
if
isinstance
(
tensors
,
tuple
):
if
isinstance
(
tensors
,
tuple
):
return
tuple
(
weak_ref_tensor
(
t
)
for
t
in
tensors
)
return
tuple
(
weak_ref_tensor
(
t
)
for
t
in
tensors
)
raise
ValueError
(
"Invalid type for tensors"
)
raise
ValueError
(
"Invalid type for tensors"
)
def
is_in_doc_build
()
->
bool
:
try
:
from
sphinx.ext.autodoc.mock
import
_MockModule
return
isinstance
(
torch
,
_MockModule
)
except
ModuleNotFoundError
:
return
False
# create a library to hold the custom op
vllm_lib
=
Library
(
"vllm"
,
"FRAGMENT"
)
# noqa
def
direct_register_custom_op
(
op_name
:
str
,
op_func
:
Callable
,
mutates_args
:
List
[
str
],
fake_impl
:
Optional
[
Callable
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
):
"""
`torch.library.custom_op` can have significant overhead because it
needs to consider complicated dispatching logic. This function
directly registers a custom op and dispatches it to the CUDA backend.
See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
for more details.
By default, the custom op is registered to the vLLM library. If you
want to register it to a different library, you can pass the library
object to the `target_lib` argument.
IMPORTANT: the lifetime of the operator is tied to the lifetime of the
library object. If you want to bind the operator to a different library,
make sure the library object is alive when the operator is used.
"""
if
is_in_doc_build
():
return
schema_str
=
torch
.
library
.
infer_schema
(
op_func
,
mutates_args
=
mutates_args
)
my_lib
=
target_lib
or
vllm_lib
my_lib
.
define
(
op_name
+
schema_str
)
my_lib
.
impl
(
op_name
,
op_func
,
"CUDA"
)
if
fake_impl
is
not
None
:
my_lib
.
_register_fake
(
op_name
,
fake_impl
)
vllm/v1/attention/backends/flash_attn.py
View file @
96e0c9cb
...
@@ -7,6 +7,7 @@ import torch
...
@@ -7,6 +7,7 @@ import torch
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
)
AttentionMetadata
,
AttentionType
)
from
vllm.forward_context
import
get_forward_context
from
vllm.forward_context
import
get_forward_context
from
vllm.utils
import
direct_register_custom_op
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
...
@@ -152,8 +153,6 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -152,8 +153,6 @@ class FlashAttentionImpl(AttentionImpl):
return
output
return
output
@
torch
.
library
.
custom_op
(
"vllm::unified_flash_attention"
,
mutates_args
=
[
"kv_cache"
])
def
unified_flash_attention
(
def
unified_flash_attention
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
...
@@ -217,8 +216,7 @@ def unified_flash_attention(
...
@@ -217,8 +216,7 @@ def unified_flash_attention(
return
output
.
view
(
num_tokens
,
hidden_size
)
return
output
.
view
(
num_tokens
,
hidden_size
)
@
unified_flash_attention
.
register_fake
def
unified_flash_attention_fake
(
def
_
(
query
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
...
@@ -235,3 +233,11 @@ def _(
...
@@ -235,3 +233,11 @@ def _(
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
query
)
return
torch
.
empty_like
(
query
)
direct_register_custom_op
(
op_name
=
"unified_flash_attention"
,
op_func
=
unified_flash_attention
,
mutates_args
=
[
"kv_cache"
],
fake_impl
=
unified_flash_attention_fake
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment