Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0f41fbe5
Unverified
Commit
0f41fbe5
authored
Oct 17, 2024
by
Luka Govedič
Committed by
GitHub
Oct 17, 2024
Browse files
[torch.compile] Fine-grained CustomOp enabling mechanism (#9300)
parent
7871659a
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
220 additions
and
21 deletions
+220
-21
tests/model_executor/test_enabled_custom_ops.py
tests/model_executor/test_enabled_custom_ops.py
+92
-0
vllm/envs.py
vllm/envs.py
+12
-1
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+64
-4
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+25
-11
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-4
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+2
-0
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+2
-1
vllm/utils.py
vllm/utils.py
+22
-0
No files found.
tests/model_executor/test_enabled_custom_ops.py
0 → 100644
View file @
0f41fbe5
import
os
from
typing
import
List
import
pytest
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.activation
import
(
GeluAndMul
,
ReLUSquaredActivation
,
SiluAndMul
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
# Subclass that is registered under its own op name, used by the tests below.
@CustomOp.register("relu3")
class Relu3(ReLUSquaredActivation):
    """ReLUSquaredActivation subclass registered separately as "relu3"."""
@pytest.mark.parametrize(
    "env, torch_level, ops_enabled, default_on",
    [
        # Default values based on compile level
        ("", 0, [True] * 4, True),
        ("", 1, [True] * 4, True),
        ("", 2, [True] * 4, True),  # All by default
        ("", 3, [False] * 4, False),
        ("", 4, [False] * 4, False),  # None by default
        # Explicitly enabling/disabling
        #
        # Default: all
        #
        # All but SiluAndMul
        ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
        # Only ReLU3
        ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
        # All but SiluAndMul
        ("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
        # All but ReLU3 (even if ReLU2 is on)
        ("-relu3,relu2", 1, [1, 1, 1, 0], True),
        # GeluAndMul and SiluAndMul
        ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
        # All but RMSNorm
        ("-rms_norm", 2, [0, 1, 1, 1], True),
        #
        # Default: none
        #
        # Only ReLU3
        ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
        # All but RMSNorm
        ("all,-rms_norm", 4, [0, 1, 1, 1], True),
    ])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
                     default_on: bool):
    """Check per-op enablement for a given VLLM_CUSTOM_OPS / compile level.

    Args:
        env: Value assigned to the VLLM_CUSTOM_OPS environment variable.
        torch_level: Value assigned to VLLM_TORCH_COMPILE_LEVEL.
        ops_enabled: Expected enabled state, in order, for RMSNorm,
            SiluAndMul, GeluAndMul and Relu3 (truthy means enabled).
        default_on: Expected result of CustomOp.default_on().
    """
    # Local import keeps this fix self-contained within the test.
    from unittest.mock import patch

    overrides = {
        "VLLM_CUSTOM_OPS": env,
        "VLLM_TORCH_COMPILE_LEVEL": str(torch_level),
    }

    # patch.dict restores os.environ on exit, so the settings used here do
    # not leak into tests that run later in the same process.
    with patch.dict(os.environ, overrides):
        # Reset default_on (computed once):
        CustomOp.default_on.cache_clear()
        try:
            assert CustomOp.default_on() == default_on

            ops_enabled = [bool(x) for x in ops_enabled]

            assert RMSNorm(1024).enabled() == ops_enabled[0]
            assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0]

            assert SiluAndMul().enabled() == ops_enabled[1]
            assert (CustomOp.op_registry["silu_and_mul"].enabled() ==
                    ops_enabled[1])

            assert GeluAndMul().enabled() == ops_enabled[2]
            assert (CustomOp.op_registry["gelu_and_mul"].enabled() ==
                    ops_enabled[2])

            # If registered, subclasses should follow their own name
            assert Relu3().enabled() == ops_enabled[3]
            assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3]

            # Unregistered subclass
            class SiluAndMul2(SiluAndMul):
                pass

            # Subclasses should not require registration
            assert SiluAndMul2().enabled() == SiluAndMul().enabled()
        finally:
            # Drop the value cached under the patched environment so later
            # callers recompute it from the restored environment.
            CustomOp.default_on.cache_clear()
@pytest.mark.parametrize(
    "env", ["all,none", "all,+rms_norm,all", "+rms_norm,-rms_norm"])
def test_enabled_ops_invalid(env: str):
    """Contradictory VLLM_CUSTOM_OPS values must raise AssertionError.

    Args:
        env: An invalid value for the VLLM_CUSTOM_OPS environment variable.
    """
    # Local import keeps this fix self-contained within the test.
    from unittest.mock import patch

    # patch.dict restores os.environ on exit, so the invalid setting does not
    # leak into tests that run later in the same process.
    with patch.dict(os.environ, {"VLLM_CUSTOM_OPS": env}):
        CustomOp.default_on.cache_clear()
        try:
            with pytest.raises(AssertionError):
                RMSNorm(1024).enabled()
        finally:
            # default_on may have cached a result computed from the invalid
            # setting; discard it before the environment is restored.
            CustomOp.default_on.cache_clear()
vllm/envs.py
View file @
0f41fbe5
...
...
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_TORCH_COMPILE_LEVEL
:
int
=
0
VLLM_CUSTOM_OPS
:
List
[
str
]
=
[]
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
...
...
@@ -205,7 +206,17 @@ environment_variables: Dict[str, Callable[[], Any]] = {
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
!=
"0"
),
"VLLM_TORCH_COMPILE_LEVEL"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TORCH_COMPILE_LEVEL"
,
"0"
)),
# Fine-grained control over which custom ops to enable/disable.
# Use 'all' to enable all, 'none' to disable all.
# Also specify a list of custom op names to enable (prefixed with a '+'),
# or disable (prefixed with a '-').
# Examples:
# - 'all,-op1' to enable all except op1
# - 'none,+op1,+op2' to enable only op1 and op2
# By default, all custom ops are enabled when running without Inductor
# and disabled when running with Inductor (compile_level >= Inductor).
"VLLM_CUSTOM_OPS"
:
lambda
:
os
.
environ
.
get
(
"VLLM_CUSTOM_OPS"
,
""
).
replace
(
" "
,
""
).
split
(
","
),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK"
:
...
...
vllm/model_executor/custom_op.py
View file @
0f41fbe5
from
functools
import
lru_cache
from
typing
import
Dict
,
Type
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
,
is_xpu
from
vllm.utils
import
is_cpu
,
is_hip
,
is_xpu
,
print_warning_once
logger
=
init_logger
(
__name__
)
class
CustomOp
(
nn
.
Module
):
"""
Base class for custom ops.
Dispatches the forward method to the appropriate backend.
"""
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
):
super
().
__init__
()
self
.
_forward_method
=
self
.
dispatch_forward
()
...
...
@@ -17,7 +27,6 @@ class CustomOp(nn.Module):
def
forward_native
(
self
,
*
args
,
**
kwargs
):
"""PyTorch-native implementation of the forward method.
This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
...
...
@@ -56,7 +65,11 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if
envs
.
VLLM_TORCH_COMPILE_LEVEL
>=
CompilationLevel
.
INDUCTOR
:
enabled
=
self
.
enabled
()
logger
.
debug
(
"custom op %s %s"
,
self
.
__class__
.
name
,
"enabled"
if
enabled
else
"disabled"
)
if
not
enabled
:
return
self
.
forward_native
if
is_hip
():
...
...
@@ -69,3 +82,50 @@ class CustomOp(nn.Module):
return
self
.
forward_xpu
else
:
return
self
.
forward_cuda
@classmethod
def enabled(cls) -> bool:
    """Return whether this custom op class is enabled.

    A class with a ``name`` attribute is looked up in VLLM_CUSTOM_OPS:
    ``+<name>`` enables it, ``-<name>`` disables it, and otherwise the
    result of ``CustomOp.default_on()`` applies. A class without a ``name``
    attribute emits a warning and follows ``CustomOp.default_on()``.

    Raises:
        AssertionError: If the op is both enabled and disabled.
    """
    if hasattr(cls, "name"):
        requested = envs.VLLM_CUSTOM_OPS
        turned_on = f"+{cls.name}" in requested
        turned_off = f"-{cls.name}" in requested
        assert not (turned_on and turned_off), \
            f"Cannot enable and disable {cls.name}"

        # Evaluate the global default unconditionally so that its own
        # validation still runs even when the op is explicitly disabled.
        globally_on = CustomOp.default_on()
        return not turned_off and (globally_on or turned_on)

    # if no name, then it was not registered
    print_warning_once(
        f"Custom op {cls.__name__} was not registered, "
        "which means it won't appear in the op registry. "
        "It will be enabled/disabled based on the global settings.")
    return CustomOp.default_on()
# On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR
# Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
@staticmethod
@lru_cache()
def default_on() -> bool:
    """Return the global default enablement for custom ops.

    The result is memoized by ``lru_cache``; call
    ``default_on.cache_clear()`` to force recomputation.

    Raises:
        AssertionError: If 'none' and 'all' appear more than once in total
            in VLLM_CUSTOM_OPS.
    """
    ops_spec = envs.VLLM_CUSTOM_OPS
    none_hits = ops_spec.count("none")
    all_hits = ops_spec.count("all")
    assert none_hits + all_hits <= 1, "Can only specify 'none' or 'all'"

    # The compile level is read before 'all' is consulted, matching the
    # left-to-right evaluation of the boolean expression this replaces.
    below_inductor = (envs.VLLM_TORCH_COMPILE_LEVEL <
                      CompilationLevel.INDUCTOR)
    if below_inductor and none_hits == 0:
        return True
    return all_hits > 0
# Registry of every custom op class, keyed by the name it was registered
# under. Call .enabled() on a class to check whether that op is enabled.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: Dict[str, Type['CustomOp']] = {}

# Class decorator factory used to register custom ops.
@classmethod
def register(cls, name: str):
    """Return a class decorator that registers an op class under ``name``.

    The decorator sets ``name`` as the ``name`` attribute of the decorated
    class, stores the class in ``op_registry`` and returns it unchanged.

    Raises:
        AssertionError: At decoration time, if ``name`` is already
            registered.
    """

    def _bind(op_cls):
        registry = cls.op_registry
        assert name not in registry, f"Duplicate op name: {name}"
        op_cls.name = name
        registry[name] = op_cls
        return op_cls

    return _bind
vllm/model_executor/layers/activation.py
View file @
0f41fbe5
...
...
@@ -11,8 +11,10 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.utils
import
LazyDict
@
CustomOp
.
register
(
"fatrelu_and_mul"
)
class
FatreluAndMul
(
CustomOp
):
"""An activation function for FATReLU.
...
...
@@ -40,6 +42,7 @@ class FatreluAndMul(CustomOp):
return
self
.
forward_native
(
x
)
@
CustomOp
.
register
(
"silu_and_mul"
)
class
SiluAndMul
(
CustomOp
):
"""An activation function for SwiGLU.
...
...
@@ -74,6 +77,7 @@ class SiluAndMul(CustomOp):
return
out
@
CustomOp
.
register
(
"gelu_and_mul"
)
class
GeluAndMul
(
CustomOp
):
"""An activation function for GeGLU.
...
...
@@ -123,6 +127,7 @@ class GeluAndMul(CustomOp):
return
f
'approximate=
{
repr
(
self
.
approximate
)
}
'
@
CustomOp
.
register
(
"gelu_new"
)
class
NewGELU
(
CustomOp
):
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -144,6 +149,7 @@ class NewGELU(CustomOp):
return
ops
.
gelu_new
(
x
)
@
CustomOp
.
register
(
"gelu_fast"
)
class
FastGELU
(
CustomOp
):
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -164,8 +170,8 @@ class FastGELU(CustomOp):
return
ops
.
gelu_fast
(
x
)
@
CustomOp
.
register
(
"quick_gelu"
)
class
QuickGELU
(
CustomOp
):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def
forward_native
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""PyTorch-native implementation equivalent to forward()."""
...
...
@@ -189,6 +195,7 @@ class QuickGELU(CustomOp):
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
@
CustomOp
.
register
(
"relu2"
)
class
ReLUSquaredActivation
(
CustomOp
):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
...
...
@@ -244,15 +251,22 @@ class ScaledActivation(nn.Module):
param_data
.
copy_
(
loaded_weight
)
_ACTIVATION_REGISTRY
=
{
"gelu"
:
nn
.
GELU
(),
"gelu_fast"
:
FastGELU
(),
"gelu_new"
:
NewGELU
(),
"gelu_pytorch_tanh"
:
nn
.
GELU
(
approximate
=
"tanh"
),
"relu"
:
nn
.
ReLU
(),
"relu2"
:
ReLUSquaredActivation
(),
"quick_gelu"
:
QuickGELU
(),
}
_ACTIVATION_REGISTRY
=
LazyDict
({
"gelu"
:
lambda
:
nn
.
GELU
(),
"gelu_fast"
:
lambda
:
FastGELU
(),
"gelu_new"
:
lambda
:
NewGELU
(),
"gelu_pytorch_tanh"
:
lambda
:
nn
.
GELU
(
approximate
=
"tanh"
),
"relu"
:
lambda
:
nn
.
ReLU
(),
"relu2"
:
lambda
:
ReLUSquaredActivation
(),
"quick_gelu"
:
lambda
:
QuickGELU
(),
})
def
get_act_fn
(
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
0f41fbe5
...
...
@@ -37,13 +37,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
@
CustomOp
.
register
(
"unquantized_fused_moe"
)
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
"""MoE method without quantization."""
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
# Fused gate_up_proj (column parallel)
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
...
...
@@ -74,7 +74,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
=
x
,
layer
=
layer
,
router_logits
=
router_logits
,
...
...
@@ -97,7 +96,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
)
...
...
@@ -134,7 +132,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group
:
Optional
[
int
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.moe_pallas
import
fused_moe
assert
not
use_grouped_topk
assert
num_expert_group
is
None
...
...
vllm/model_executor/layers/layernorm.py
View file @
0f41fbe5
...
...
@@ -7,6 +7,7 @@ import torch.nn as nn
from
vllm.model_executor.custom_op
import
CustomOp
@
CustomOp
.
register
(
"rms_norm"
)
class
RMSNorm
(
CustomOp
):
"""Root mean square normalization.
...
...
@@ -122,6 +123,7 @@ class RMSNorm(CustomOp):
return
s
@
CustomOp
.
register
(
"gemma_rms_norm"
)
class
GemmaRMSNorm
(
CustomOp
):
"""RMS normalization for Gemma.
...
...
vllm/model_executor/layers/rotary_embedding.py
View file @
0f41fbe5
...
...
@@ -72,6 +72,7 @@ def _apply_rotary_emb(
return
torch
.
stack
((
o1
,
o2
),
dim
=-
1
).
flatten
(
-
2
)
@
CustomOp
.
register
(
"rotary_embedding"
)
class
RotaryEmbedding
(
CustomOp
):
"""Original rotary positional embedding."""
...
...
vllm/utils.py
View file @
0f41fbe5
...
...
@@ -17,6 +17,7 @@ import uuid
import
warnings
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
ensure_future
from
collections.abc
import
Mapping
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
from
typing
import
(
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generic
,
...
...
@@ -1442,3 +1443,24 @@ class AtomicCounter:
@
property
def
value
(
self
):
return
self
.
_value
# Adapted from: https://stackoverflow.com/a/47212782/5082708
class LazyDict(Mapping, Generic[T]):
    """Read-only mapping whose values are built on first access.

    Each key maps to a zero-argument factory. The factory is called the
    first time the key is looked up and its result is reused afterwards.
    Iteration and length reflect the factory keys, whether or not their
    values have been built yet.
    """

    def __init__(self, factory: Dict[str, Callable[[], T]]):
        self._factory = factory
        # Values that have already been built, keyed like the factories.
        self._dict: Dict[str, T] = {}

    def __getitem__(self, key) -> T:
        built = self._dict
        if key in built:
            return built[key]
        if key not in self._factory:
            raise KeyError(key)
        value = self._factory[key]()
        built[key] = value
        return built[key]

    def __iter__(self):
        return iter(self._factory)

    def __len__(self):
        return len(self._factory)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment