Unverified Commit 08d954f0 authored by Shanshan Shen's avatar Shanshan Shen Committed by GitHub
Browse files

[Doc] Add developer guide for CustomOp (#30886)


Signed-off-by: default avatarshen-shanshan <467638484@qq.com>
parent ac9f9330
......@@ -9,10 +9,13 @@ from vllm.model_executor.custom_op import CustomOp
from .common import rotate_gptj, rotate_neox
# --8<-- [start:dual_chunk_rotary_embedding]
@CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention."""
# --8<-- [end:dual_chunk_rotary_embedding]
def __init__(
self,
head_size: int,
......
......@@ -181,6 +181,7 @@ def get_masked_input_and_mask(
return input_, ~vocab_mask
# --8<-- [start:vocab_parallel_embedding]
@CustomOp.register("vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp):
"""Embedding parallelized in the vocabulary dimension.
......@@ -221,6 +222,8 @@ class VocabParallelEmbedding(CustomOp):
prefix: full name of the layer in the state dict
""" # noqa: E501
# --8<-- [end:vocab_parallel_embedding]
def __init__(
self,
num_embeddings: int,
......@@ -492,6 +495,7 @@ class VocabParallelEmbedding(CustomOp):
return s
# --8<-- [start:parallel_lm_head]
@CustomOp.register("parallel_lm_head")
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
......@@ -509,6 +513,8 @@ class ParallelLMHead(VocabParallelEmbedding):
padding_size: padding size for the vocabulary.
"""
# --8<-- [end:parallel_lm_head]
def __init__(
self,
num_embeddings: int,
......
......@@ -103,8 +103,11 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
# Adapted from:
# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
# transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register(name="plamo2_mamba_mixer")
# --8<-- [start:plamo2_mamba_mixer]
@CustomOp.register("plamo2_mamba_mixer")
class Plamo2MambaMixer(MambaBase, CustomOp):
# --8<-- [end:plamo2_mamba_mixer]
def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None:
super().__init__()
self.config = vllm_config.model_config.hf_config
......
......@@ -37,10 +37,13 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig
# --8<-- [start:transformers_fused_moe]
@CustomOp.register("transformers_fused_moe")
class TransformersFusedMoE(FusedMoE):
"""Custom FusedMoE for the Transformers modeling backend."""
# --8<-- [end:transformers_fused_moe]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._topk_ids: torch.Tensor = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment