Unverified Commit 08d954f0 authored by Shanshan Shen's avatar Shanshan Shen Committed by GitHub
Browse files

[Doc] Add developer guide for CustomOp (#30886)


Signed-off-by: default avatarshen-shanshan <467638484@qq.com>
parent ac9f9330
...@@ -9,10 +9,13 @@ from vllm.model_executor.custom_op import CustomOp ...@@ -9,10 +9,13 @@ from vllm.model_executor.custom_op import CustomOp
from .common import rotate_gptj, rotate_neox from .common import rotate_gptj, rotate_neox
# --8<-- [start:dual_chunk_rotary_embedding]
@CustomOp.register("dual_chunk_rotary_embedding") @CustomOp.register("dual_chunk_rotary_embedding")
class DualChunkRotaryEmbedding(CustomOp): class DualChunkRotaryEmbedding(CustomOp):
"""Rotary positional embedding for Dual Chunk Attention.""" """Rotary positional embedding for Dual Chunk Attention."""
# --8<-- [end:dual_chunk_rotary_embedding]
def __init__( def __init__(
self, self,
head_size: int, head_size: int,
......
...@@ -181,6 +181,7 @@ def get_masked_input_and_mask( ...@@ -181,6 +181,7 @@ def get_masked_input_and_mask(
return input_, ~vocab_mask return input_, ~vocab_mask
# --8<-- [start:vocab_parallel_embedding]
@CustomOp.register("vocab_parallel_embedding") @CustomOp.register("vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp): class VocabParallelEmbedding(CustomOp):
"""Embedding parallelized in the vocabulary dimension. """Embedding parallelized in the vocabulary dimension.
...@@ -221,6 +222,8 @@ class VocabParallelEmbedding(CustomOp): ...@@ -221,6 +222,8 @@ class VocabParallelEmbedding(CustomOp):
prefix: full name of the layer in the state dict prefix: full name of the layer in the state dict
""" # noqa: E501 """ # noqa: E501
# --8<-- [end:vocab_parallel_embedding]
def __init__( def __init__(
self, self,
num_embeddings: int, num_embeddings: int,
...@@ -492,6 +495,7 @@ class VocabParallelEmbedding(CustomOp): ...@@ -492,6 +495,7 @@ class VocabParallelEmbedding(CustomOp):
return s return s
# --8<-- [start:parallel_lm_head]
@CustomOp.register("parallel_lm_head") @CustomOp.register("parallel_lm_head")
class ParallelLMHead(VocabParallelEmbedding): class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head. """Parallelized LM head.
...@@ -509,6 +513,8 @@ class ParallelLMHead(VocabParallelEmbedding): ...@@ -509,6 +513,8 @@ class ParallelLMHead(VocabParallelEmbedding):
padding_size: padding size for the vocabulary. padding_size: padding size for the vocabulary.
""" """
# --8<-- [end:parallel_lm_head]
def __init__( def __init__(
self, self,
num_embeddings: int, num_embeddings: int,
......
...@@ -103,8 +103,11 @@ def is_mamba(config: Plamo2Config, i: int) -> bool: ...@@ -103,8 +103,11 @@ def is_mamba(config: Plamo2Config, i: int) -> bool:
# Adapted from: # Adapted from:
# vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2 # vllm.model_executor.layers.mamba.mamba_mixer2.MambaMixer2
# transformers.models.mamba.modeling_mamba.MambaMixer # transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register(name="plamo2_mamba_mixer") # --8<-- [start:plamo2_mamba_mixer]
@CustomOp.register("plamo2_mamba_mixer")
class Plamo2MambaMixer(MambaBase, CustomOp): class Plamo2MambaMixer(MambaBase, CustomOp):
# --8<-- [end:plamo2_mamba_mixer]
def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None: def __init__(self, vllm_config: VllmConfig, *, prefix: str = "", **kwargs) -> None:
super().__init__() super().__init__()
self.config = vllm_config.model_config.hf_config self.config = vllm_config.model_config.hf_config
......
...@@ -37,10 +37,13 @@ if TYPE_CHECKING: ...@@ -37,10 +37,13 @@ if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
# --8<-- [start:transformers_fused_moe]
@CustomOp.register("transformers_fused_moe") @CustomOp.register("transformers_fused_moe")
class TransformersFusedMoE(FusedMoE): class TransformersFusedMoE(FusedMoE):
"""Custom FusedMoE for the Transformers modeling backend.""" """Custom FusedMoE for the Transformers modeling backend."""
# --8<-- [end:transformers_fused_moe]
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._topk_ids: torch.Tensor = None self._topk_ids: torch.Tensor = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment