Unverified Commit 736569da authored by zzhxxx's avatar zzhxxx Committed by GitHub
Browse files

[Platform] Custom ops support for LMhead and LogitsProcessor (#23564)


Signed-off-by: default avatarzzhx1 <zzh_201018@outlook.com>
parent 2eb9986a
...@@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor ...@@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional from typing import Optional
import torch import torch
import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed import (tensor_model_parallel_all_gather, from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather) tensor_model_parallel_gather)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
...@@ -22,7 +22,8 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None: ...@@ -22,7 +22,8 @@ if envs.VLLM_LOGITS_PROCESSOR_THREADS is not None:
envs.VLLM_LOGITS_PROCESSOR_THREADS) envs.VLLM_LOGITS_PROCESSOR_THREADS)
class LogitsProcessor(nn.Module): @CustomOp.register("logits_processor")
class LogitsProcessor(CustomOp):
"""Process logits and apply logits processors from sampling metadata. """Process logits and apply logits processors from sampling metadata.
This layer does the following: This layer does the following:
......
...@@ -429,6 +429,7 @@ class VocabParallelEmbedding(CustomOp): ...@@ -429,6 +429,7 @@ class VocabParallelEmbedding(CustomOp):
return s return s
@CustomOp.register("parallel_lm_head")
class ParallelLMHead(VocabParallelEmbedding): class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head. """Parallelized LM head.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment