Unverified commit 1dfd64c1, authored by Hexiang Wang and committed by GitHub
Browse files

[PluggableLayer][3/N] Apply PluggableLayer to llm_head and vocab embedding layer (#33465)


Signed-off-by: whx-sjtu <2952154980@qq.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent ad720aef
......@@ -9,14 +9,14 @@ from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_gather,
)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform
# --8<-- [start:logits_processor]
@CustomOp.register("logits_processor")
class LogitsProcessor(CustomOp):
@PluggableLayer.register("logits_processor")
class LogitsProcessor(PluggableLayer):
"""Process logits and apply logits processors from sampling metadata.
This layer does the following:
......
......@@ -14,7 +14,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
......@@ -182,8 +182,8 @@ def get_masked_input_and_mask(
# --8<-- [start:vocab_parallel_embedding]
@CustomOp.register("vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp):
@PluggableLayer.register("vocab_parallel_embedding")
class VocabParallelEmbedding(PluggableLayer):
"""Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
......@@ -461,7 +461,7 @@ class VocabParallelEmbedding(CustomOp):
param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0] :].data.fill_(0)
def forward_native(self, input_):
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
masked_input, input_mask = get_masked_input_and_mask(
......@@ -483,9 +483,6 @@ class VocabParallelEmbedding(CustomOp):
output = tensor_model_parallel_all_reduce(output_parallel)
return output
def forward_cuda(self, input_):
return self.forward_native(input_)
def extra_repr(self) -> str:
s = f"num_embeddings={self.num_embeddings_per_partition}"
s += f", embedding_dim={self.embedding_dim}"
......@@ -496,7 +493,7 @@ class VocabParallelEmbedding(CustomOp):
# --8<-- [start:parallel_lm_head]
@CustomOp.register("parallel_lm_head")
@PluggableLayer.register("parallel_lm_head")
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment