Unverified commit 1dfd64c1, authored by Hexiang Wang, committed by GitHub
Browse files

[PluggableLayer][3/N] Apply PluggableLayer to llm_head and vocab embedding layer (#33465)


Signed-off-by: whx-sjtu <2952154980@qq.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
parent ad720aef
...@@ -9,14 +9,14 @@ from vllm.distributed import ( ...@@ -9,14 +9,14 @@ from vllm.distributed import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
tensor_model_parallel_gather, tensor_model_parallel_gather,
) )
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform from vllm.platforms import current_platform
# --8<-- [start:logits_processor] # --8<-- [start:logits_processor]
@CustomOp.register("logits_processor") @PluggableLayer.register("logits_processor")
class LogitsProcessor(CustomOp): class LogitsProcessor(PluggableLayer):
"""Process logits and apply logits processors from sampling metadata. """Process logits and apply logits processors from sampling metadata.
This layer does the following: This layer does the following:
......
...@@ -14,7 +14,7 @@ from vllm.distributed import ( ...@@ -14,7 +14,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
...@@ -182,8 +182,8 @@ def get_masked_input_and_mask( ...@@ -182,8 +182,8 @@ def get_masked_input_and_mask(
# --8<-- [start:vocab_parallel_embedding] # --8<-- [start:vocab_parallel_embedding]
@CustomOp.register("vocab_parallel_embedding") @PluggableLayer.register("vocab_parallel_embedding")
class VocabParallelEmbedding(CustomOp): class VocabParallelEmbedding(PluggableLayer):
"""Embedding parallelized in the vocabulary dimension. """Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
...@@ -461,7 +461,7 @@ class VocabParallelEmbedding(CustomOp): ...@@ -461,7 +461,7 @@ class VocabParallelEmbedding(CustomOp):
param[: loaded_weight.shape[0]].data.copy_(loaded_weight) param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0] :].data.fill_(0) param[loaded_weight.shape[0] :].data.fill_(0)
def forward_native(self, input_): def forward(self, input_):
if self.tp_size > 1: if self.tp_size > 1:
# Build the mask. # Build the mask.
masked_input, input_mask = get_masked_input_and_mask( masked_input, input_mask = get_masked_input_and_mask(
...@@ -483,9 +483,6 @@ class VocabParallelEmbedding(CustomOp): ...@@ -483,9 +483,6 @@ class VocabParallelEmbedding(CustomOp):
output = tensor_model_parallel_all_reduce(output_parallel) output = tensor_model_parallel_all_reduce(output_parallel)
return output return output
def forward_cuda(self, input_):
return self.forward_native(input_)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"num_embeddings={self.num_embeddings_per_partition}" s = f"num_embeddings={self.num_embeddings_per_partition}"
s += f", embedding_dim={self.embedding_dim}" s += f", embedding_dim={self.embedding_dim}"
...@@ -496,7 +493,7 @@ class VocabParallelEmbedding(CustomOp): ...@@ -496,7 +493,7 @@ class VocabParallelEmbedding(CustomOp):
# --8<-- [start:parallel_lm_head] # --8<-- [start:parallel_lm_head]
@CustomOp.register("parallel_lm_head") @PluggableLayer.register("parallel_lm_head")
class ParallelLMHead(VocabParallelEmbedding): class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head. """Parallelized LM head.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment