Unverified Commit f1c0fc39 authored by Roy's avatar Roy Committed by GitHub
Browse files

Migrate `logits` computation and gather to `model_runner` (#3233)

parent 6e435de7
......@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
......@@ -325,7 +326,11 @@ class LlamaForCausalLM(nn.Module):
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
def forward(
self,
......@@ -338,13 +343,18 @@ class LlamaForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
......@@ -369,7 +370,9 @@ class MixtralForCausalLM(nn.Module):
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -382,13 +385,18 @@ class MixtralForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: Optional[torch.Tensor],
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -39,6 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
......@@ -344,7 +345,8 @@ class MixtralForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = MixtralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -357,13 +359,18 @@ class MixtralForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: Optional[torch.Tensor],
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -259,7 +260,8 @@ class MPTForCausalLM(nn.Module):
self.transformer = MPTModel(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -272,13 +274,18 @@ class MPTForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -7,6 +7,7 @@ from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
......@@ -25,7 +26,8 @@ class LlamaForCausalLM(nn.Module):
self.config = config
self.linear_method = linear_method
self.model = None
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -45,13 +47,18 @@ class LlamaForCausalLM(nn.Module):
start_ids=seq_ids.flatten())
return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -6,6 +6,7 @@ from torch import nn
from transformers import MistralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
......@@ -26,7 +27,8 @@ class MistralForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = None
self.lm_head = None
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -48,13 +50,18 @@ class MistralForCausalLM(nn.Module):
start_ids=seq_ids)
return logits
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -51,6 +51,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -336,7 +337,8 @@ class OLMoForCausalLM(nn.Module):
self.lm_head_weight = (self.model.transformer.wte.weight
if config.weight_tying else
self.model.transformer.ff_out.weight)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -353,13 +355,18 @@ class OLMoForCausalLM(nn.Module):
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(
......
......@@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
......@@ -292,7 +293,8 @@ class OPTForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = OPTModel(config, linear_method)
self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -305,13 +307,18 @@ class OPTForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
......@@ -256,7 +257,8 @@ class OrionForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = OrionModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -269,13 +271,18 @@ class OrionForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
......@@ -240,7 +241,8 @@ class PhiForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -254,14 +256,18 @@ class PhiForCausalLM(nn.Module):
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata, self.lm_head.bias)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
head = self.lm_head
next_tokens = self.sampler(head.weight, hidden_states,
sampling_metadata, head.bias)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
......@@ -230,7 +231,8 @@ class QWenLMHeadModel(nn.Module):
self.linear_method = linear_method
self.transformer = QWenModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -243,13 +245,18 @@ class QWenLMHeadModel(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
......@@ -300,11 +301,15 @@ class Qwen2ForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = Qwen2Model(config, linear_method)
if not config.tie_word_embeddings:
if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size)
self.lm_head_weight = self.lm_head.weight
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -317,17 +322,18 @@ class Qwen2ForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
next_tokens = self.sampler(lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
......@@ -238,7 +239,8 @@ class StablelmForCausalLM(nn.Module):
self.linear_method = linear_method
self.model = StableLMEpochModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -251,13 +253,18 @@ class StablelmForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: torch.Tensor,
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
......@@ -254,7 +255,9 @@ class Starcoder2ForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.lm_head_weight = self.lm_head.weight
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward(
self,
......@@ -267,13 +270,18 @@ class Starcoder2ForCausalLM(nn.Module):
input_metadata)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head_weight, hidden_states,
sampling_metadata)
return logits
def sample(
self,
hidden_states: Optional[torch.Tensor],
logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states,
sampling_metadata)
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
......
......@@ -613,9 +613,16 @@ class ModelRunner:
input_metadata=input_metadata,
)
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling:
return None
# Sample the next token.
output = self.model.sample(
hidden_states=hidden_states,
logits=logits,
sampling_metadata=sampling_metadata,
)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment