Unverified Commit 1861ae8a authored by whx's avatar whx Committed by GitHub
Browse files

[PluggableLayer][1/N] Define PluggableLayer (Fix ci) (#32744)


Signed-off-by: default avatarwhx-sjtu <2952154980@qq.com>
parent 4e31b7f2
...@@ -7,7 +7,7 @@ import itertools ...@@ -7,7 +7,7 @@ import itertools
import torch import torch
import vllm.model_executor.layers.activation # noqa F401 import vllm.model_executor.layers.activation # noqa F401
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import op_registry
from vllm.triton_utils import triton from vllm.triton_utils import triton
from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE, set_random_seed
...@@ -33,14 +33,14 @@ def benchmark_activation( ...@@ -33,14 +33,14 @@ def benchmark_activation(
torch.set_default_device(device) torch.set_default_device(device)
if func_name == "gelu_and_mul": if func_name == "gelu_and_mul":
layer = CustomOp.op_registry[func_name](approximate="none") layer = op_registry[func_name](approximate="none")
elif func_name == "gelu_and_mul_tanh": elif func_name == "gelu_and_mul_tanh":
layer = CustomOp.op_registry["gelu_and_mul"](approximate="tanh") layer = op_registry["gelu_and_mul"](approximate="tanh")
elif func_name == "fatrelu_and_mul": elif func_name == "fatrelu_and_mul":
threshold = 0.5 threshold = 0.5
layer = CustomOp.op_registry[func_name](threshold) layer = op_registry[func_name](threshold)
else: else:
layer = CustomOp.op_registry[func_name]() layer = op_registry[func_name]()
x = torch.randn(num_tokens, dim, dtype=dtype, device=device) x = torch.randn(num_tokens, dim, dtype=dtype, device=device)
compiled_layer = torch.compile(layer.forward_native) compiled_layer = torch.compile(layer.forward_native)
......
...@@ -8,15 +8,6 @@ This document will introduce how CustomOp works in vLLM and how to implement a n ...@@ -8,15 +8,6 @@ This document will introduce how CustomOp works in vLLM and how to implement a n
`CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively. `CustomOp` manages two dictionaries of all custom ops (i.e., op classes, indexed by registered name) in its class, for vLLM and OOT plugins respectively.
??? code
```python
class CustomOp(nn.Module):
op_registry: dict[str, type["CustomOp"]] = {}
op_registry_oot: dict[str, type["CustomOp"]] = {}
```
We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, we can also register an OOT op by `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later. We can use `@CustomOp.register("op_name")` to register an op class to the `CustomOp` system. After this, the `op_name` and its class will be added into the `op_registry` dictionary. In addition, we can also register an OOT op by `@CustomOp.register_oot("op_name")`. We will introduce this mechanism in detail later.
When a `CustomOp` is called (i.e., call its `forward()` method), if it is enabled (i.e., with `--compilation_config.custom_ops '["+op_name"]'`), it will automatically dispatch the forward method to the appropriate backend according to `current_platform`. Otherwise (i.e., it is disabled), it will only call the `forward_native()` method to use PyTorch-native implementation of this forward method. When a `CustomOp` is called (i.e., call its `forward()` method), if it is enabled (i.e., with `--compilation_config.custom_ops '["+op_name"]'`), it will automatically dispatch the forward method to the appropriate backend according to `current_platform`. Otherwise (i.e., it is disabled), it will only call the `forward_native()` method to use PyTorch-native implementation of this forward method.
......
...@@ -13,7 +13,7 @@ import torch ...@@ -13,7 +13,7 @@ import torch
from torch._prims_common import TensorLikeType from torch._prims_common import TensorLikeType
from tests.kernels.quant_utils import native_w8a8_block_matmul from tests.kernels.quant_utils import native_w8a8_block_matmul
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import op_registry
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input
from vllm.utils.torch_utils import make_tensor_with_pad from vllm.utils.torch_utils import make_tensor_with_pad
...@@ -883,7 +883,7 @@ def torch_experts( ...@@ -883,7 +883,7 @@ def torch_experts(
f32 = torch.float32 f32 = torch.float32
act = CustomOp.op_registry[activation] act = op_registry[activation]
for i in range(num_experts): for i in range(num_experts):
mask = topk_ids == i mask = topk_ids == i
......
...@@ -11,7 +11,7 @@ from vllm.config import ( ...@@ -11,7 +11,7 @@ from vllm.config import (
get_cached_compilation_config, get_cached_compilation_config,
set_current_vllm_config, set_current_vllm_config,
) )
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp, op_registry
from vllm.model_executor.layers.activation import ( from vllm.model_executor.layers.activation import (
GeluAndMul, GeluAndMul,
ReLUSquaredActivation, ReLUSquaredActivation,
...@@ -98,17 +98,17 @@ def test_enabled_ops( ...@@ -98,17 +98,17 @@ def test_enabled_ops(
ops_enabled = [bool(x) for x in ops_enabled] ops_enabled = [bool(x) for x in ops_enabled]
assert RMSNorm(1024).enabled() == ops_enabled[0] assert RMSNorm(1024).enabled() == ops_enabled[0]
assert CustomOp.op_registry["rms_norm"].enabled() == ops_enabled[0] assert op_registry["rms_norm"].enabled() == ops_enabled[0]
assert SiluAndMul().enabled() == ops_enabled[1] assert SiluAndMul().enabled() == ops_enabled[1]
assert CustomOp.op_registry["silu_and_mul"].enabled() == ops_enabled[1] assert op_registry["silu_and_mul"].enabled() == ops_enabled[1]
assert GeluAndMul().enabled() == ops_enabled[2] assert GeluAndMul().enabled() == ops_enabled[2]
assert CustomOp.op_registry["gelu_and_mul"].enabled() == ops_enabled[2] assert op_registry["gelu_and_mul"].enabled() == ops_enabled[2]
# If registered, subclasses should follow their own name # If registered, subclasses should follow their own name
assert Relu3().enabled() == ops_enabled[3] assert Relu3().enabled() == ops_enabled[3]
assert CustomOp.op_registry["relu3"].enabled() == ops_enabled[3] assert op_registry["relu3"].enabled() == ops_enabled[3]
# Unregistered subclass # Unregistered subclass
class SiluAndMul2(SiluAndMul): class SiluAndMul2(SiluAndMul):
......
...@@ -1033,13 +1033,13 @@ class CompilationConfig: ...@@ -1033,13 +1033,13 @@ class CompilationConfig:
# check if op name exists in model # check if op name exists in model
op_name = op[1:] op_name = op[1:]
if op_name not in all_ops_in_model: if op_name not in all_ops_in_model:
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import op_registry
# Does op exist at all or is it just not present in this model? # Does op exist at all or is it just not present in this model?
# Note: Only imported op classes appear in the registry. # Note: Only imported op classes appear in the registry.
missing_str = ( missing_str = (
"doesn't exist (or wasn't imported/registered)" "doesn't exist (or wasn't imported/registered)"
if op_name not in CustomOp.op_registry if op_name not in op_registry
else "not present in model" else "not present in model"
) )
......
...@@ -11,6 +11,86 @@ from vllm.platforms import current_platform ...@@ -11,6 +11,86 @@ from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
# Module-level registries shared by CustomOp and PluggableLayer.
#
# op_registry:     in-tree classes, keyed by the name passed to
#                  ``register(name)``.
# op_registry_oot: out-of-tree (OOT) replacement classes, keyed by the name
#                  given to ``register_oot`` (or, when omitted, the
#                  ``__name__`` of the class ``register_oot`` was called on).
#
# To find out whether a registered custom op is enabled, call .enabled()
# on its class, e.g.:
#   - MyOp.enabled()
#   - op_registry["my_op"].enabled()
op_registry: dict[str, type["CustomOp | PluggableLayer"]] = {}
op_registry_oot: dict[str, type["CustomOp | PluggableLayer"]] = {}
class PluggableLayer(nn.Module):
    """
    Base class for pluggable layers.

    A PluggableLayer is a *module-composing* abstraction: it may instantiate other
    ``torch.nn.Module`` objects as sub-layers, and its functionality depends on
    these sub-layers following a generalized invocation sequence. Also, it is stateful
    and may hold parameters or buffers.

    Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform
    ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement
    of the entire layer class at instantiation time, allowing customized
    initialization and submodule composition.
    """

    def __new__(cls, *args, **kwargs):
        """Allocate the instance, substituting an OOT class when one is registered.

        The OOT registry is looked up by the ``__name__`` of the class being
        instantiated. If an entry exists, the object is allocated as that
        class; otherwise it is allocated as ``cls`` itself.
        """
        # Every class defines ``__name__``, so this lookup cannot raise
        # AttributeError and needs no guard.
        layer_class_name = cls.__name__
        layer_cls_to_instantiate = op_registry_oot.get(layer_class_name, cls)
        # Pass the class itself so that string conversion only happens when
        # debug logging is actually enabled.
        logger.debug(
            "Instantiating pluggable layer: %s using %s",
            layer_class_name,
            layer_cls_to_instantiate,
        )
        # NOTE(review): Python only runs ``__init__`` automatically when the
        # object returned here is an instance of ``cls``, so OOT replacements
        # are presumably expected to subclass the in-tree layer - confirm.
        return super().__new__(layer_cls_to_instantiate)

    # Decorator to register pluggable layers.
    @classmethod
    def register(cls, name: str):
        """Return a class decorator that registers a layer under ``name``.

        The decorated class gets its ``name`` attribute set to ``name`` and is
        stored in ``op_registry``. Registering the same name twice trips an
        assertion.
        """

        def decorator(op_cls):
            assert name not in op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            op_registry[name] = op_cls
            return op_cls

        return decorator

    # Decorator to register out-of-tree (OOT) pluggable layers.
    # For OOT pluggable layers:
    #   if in-tree layer class is registered with an oot_custom_layer,
    #   the oot_custom_layer will be used instead.
    @classmethod
    def register_oot(cls, _decorated_layer_cls=None, name: str | None = None):
        """Register an out-of-tree replacement layer class.

        Usable both bare (``@SomeLayer.register_oot``) and called
        (``@SomeLayer.register_oot()`` / ``@SomeLayer.register_oot(name="...")``).
        The registry key is ``name`` when given, otherwise the ``__name__`` of
        the class this classmethod is invoked on. Registering the same key
        twice trips an assertion.

        Raises:
            TypeError: if the bare form is applied to something that is not a
                class.
        """

        def decorator(layer_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}"
            layer_cls.name = reg_name
            op_registry_oot[reg_name] = layer_cls
            return layer_cls

        if _decorated_layer_cls is None:
            # Called with parentheses: @PluggableLayer.register_oot()
            # or @PluggableLayer.register_oot(name="...")
            return decorator
        elif isinstance(_decorated_layer_cls, type):  # Check if it's a class
            # Called without parentheses: @PluggableLayer.register_oot
            return decorator(_decorated_layer_cls)
        else:
            raise TypeError("Decorator can only be applied to classes.")
class CustomOp(nn.Module): class CustomOp(nn.Module):
""" """
Base class for custom ops. Base class for custom ops.
...@@ -27,10 +107,10 @@ class CustomOp(nn.Module): ...@@ -27,10 +107,10 @@ class CustomOp(nn.Module):
f"@CustomOp.register, or it's the CustomOp base class itself." f"@CustomOp.register, or it's the CustomOp base class itself."
) from None ) from None
if op_name not in cls.op_registry_oot: if op_name not in op_registry_oot:
op_cls_to_instantiate = cls op_cls_to_instantiate = cls
else: else:
op_cls_to_instantiate = cls.op_registry_oot[op_name] op_cls_to_instantiate = op_registry_oot[op_name]
logger.debug( logger.debug(
"Instantiating custom op: %s using %s", "Instantiating custom op: %s using %s",
op_name, op_name,
...@@ -150,21 +230,13 @@ class CustomOp(nn.Module): ...@@ -150,21 +230,13 @@ class CustomOp(nn.Module):
return not count_none > 0 or count_all > 0 return not count_none > 0 or count_all > 0
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: dict[str, type["CustomOp"]] = {}
op_registry_oot: dict[str, type["CustomOp"]] = {}
# Decorator to register custom ops. # Decorator to register custom ops.
@classmethod @classmethod
def register(cls, name: str): def register(cls, name: str):
def decorator(op_cls): def decorator(op_cls):
assert name not in cls.op_registry, f"Duplicate op name: {name}" assert name not in op_registry, f"Duplicate op name: {name}"
op_cls.name = name op_cls.name = name
cls.op_registry[name] = op_cls op_registry[name] = op_cls
return op_cls return op_cls
return decorator return decorator
...@@ -182,9 +254,9 @@ class CustomOp(nn.Module): ...@@ -182,9 +254,9 @@ class CustomOp(nn.Module):
def register_oot(cls, _decorated_op_cls=None, name: str | None = None): def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
def decorator(op_cls): def decorator(op_cls):
reg_name = name if name is not None else cls.__name__ reg_name = name if name is not None else cls.__name__
assert reg_name not in cls.op_registry_oot, f"Duplicate op name: {reg_name}" assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}"
op_cls.name = reg_name op_cls.name = reg_name
cls.op_registry_oot[reg_name] = op_cls op_registry_oot[reg_name] = op_cls
return op_cls return op_cls
if _decorated_op_cls is None: if _decorated_op_cls is None:
......
...@@ -6,7 +6,7 @@ import torch ...@@ -6,7 +6,7 @@ import torch
from vllm.attention.layer import MLAAttention from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -30,13 +30,13 @@ class MLAModules: ...@@ -30,13 +30,13 @@ class MLAModules:
# --8<-- [start:multi_head_latent_attention] # --8<-- [start:multi_head_latent_attention]
@CustomOp.register("multi_head_latent_attention") @PluggableLayer.register("multi_head_latent_attention")
class MultiHeadLatentAttentionWrapper(CustomOp): class MultiHeadLatentAttentionWrapper(PluggableLayer):
"""MLA layer registered as CustomOp to allow OOT backends to add """Pluggable MLA layer which allows OOT backends to add
custom implementations of the outer MLA layer (including rope & o_proj). custom implementations of the outer MLA layer (including rope & o_proj).
Note that currently MLA ignores the enable/disable mechanism of CustomOp Note that currently oot platforms can still use CustomOp.register_oot to
because there is only one in-tree implementation in forward_native. replace MLA layer entirely, although we use PluggableLayer to register
TODO: implement this with a new PluggableLayer mechanism. this layer now.
This class takes positions and hidden_states as input. This class takes positions and hidden_states as input.
The input tensors can either contain prefill tokens or decode tokens. The input tensors can either contain prefill tokens or decode tokens.
...@@ -110,7 +110,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp): ...@@ -110,7 +110,7 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
self.prefix = prefix self.prefix = prefix
def forward_native( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
...@@ -174,6 +174,3 @@ class MultiHeadLatentAttentionWrapper(CustomOp): ...@@ -174,6 +174,3 @@ class MultiHeadLatentAttentionWrapper(CustomOp):
) )
return self.o_proj(attn_out)[0] return self.o_proj(attn_out)[0]
def forward_cuda(self, *args, **kwargs):
return self.forward_native(*args, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment