Unverified commit f1c0fc39, authored by Roy and committed by GitHub
Browse files

Migrate `logits` computation and gather to `model_runner` (#3233)

parent 6e435de7
...@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase, ...@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
...@@ -325,7 +326,11 @@ class LlamaForCausalLM(nn.Module): ...@@ -325,7 +326,11 @@ class LlamaForCausalLM(nn.Module):
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
) )
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -338,13 +343,18 @@ class LlamaForCausalLM(nn.Module): ...@@ -338,13 +343,18 @@ class LlamaForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Produce logits from hidden states using the LM head weight.

    Args:
        hidden_states: Tensor forwarded to ``self.logits_processor``.
        sampling_metadata: Forwarded unchanged to the processor.

    Returns:
        The value returned by ``self.logits_processor``.
    """
    # Delegate entirely to the logits processor; nothing is computed here.
    project = self.logits_processor
    return project(self.lm_head.weight, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase, ...@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
...@@ -369,7 +370,9 @@ class MixtralForCausalLM(nn.Module): ...@@ -369,7 +370,9 @@ class MixtralForCausalLM(nn.Module):
# compatibility # compatibility
if not lora_config else lora_config.lora_vocab_padding_size, if not lora_config else lora_config.lora_vocab_padding_size,
) )
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -382,13 +385,18 @@ class MixtralForCausalLM(nn.Module): ...@@ -382,13 +385,18 @@ class MixtralForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Run the logits processor over ``hidden_states``.

    Args:
        hidden_states: Tensor handed to ``self.logits_processor``.
        sampling_metadata: Passed through to the processor as-is.

    Returns:
        Whatever ``self.logits_processor`` yields for these inputs.
    """
    processor = self.logits_processor
    # The LM head weight is supplied as the first positional argument.
    return processor(self.lm_head.weight, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: Optional[torch.Tensor], logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -39,6 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase, ...@@ -39,6 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -344,7 +345,8 @@ class MixtralForCausalLM(nn.Module): ...@@ -344,7 +345,8 @@ class MixtralForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = MixtralModel(config, linear_method) self.model = MixtralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -357,13 +359,18 @@ class MixtralForCausalLM(nn.Module): ...@@ -357,13 +359,18 @@ class MixtralForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Map hidden states to logits through ``self.logits_processor``.

    Args:
        hidden_states: Tensor forwarded to the logits processor.
        sampling_metadata: Forwarded to the logits processor untouched.

    Returns:
        The logits processor's return value.
    """
    to_logits = self.logits_processor
    head_weight = self.lm_head.weight
    return to_logits(head_weight, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: Optional[torch.Tensor], logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -259,7 +260,8 @@ class MPTForCausalLM(nn.Module): ...@@ -259,7 +260,8 @@ class MPTForCausalLM(nn.Module):
self.transformer = MPTModel(config, linear_method) self.transformer = MPTModel(config, linear_method)
self.lm_head_weight = self.transformer.wte.weight self.lm_head_weight = self.transformer.wte.weight
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -272,13 +274,18 @@ class MPTForCausalLM(nn.Module): ...@@ -272,13 +274,18 @@ class MPTForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Produce logits using the stored ``self.lm_head_weight``.

    Args:
        hidden_states: Tensor forwarded to ``self.logits_processor``.
        sampling_metadata: Forwarded unchanged to the processor.

    Returns:
        The value returned by ``self.logits_processor``.
    """
    project = self.logits_processor
    # lm_head_weight is an attribute of the model, assigned outside this
    # method.
    return project(self.lm_head_weight, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -7,6 +7,7 @@ from torch import nn ...@@ -7,6 +7,7 @@ from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
...@@ -25,7 +26,8 @@ class LlamaForCausalLM(nn.Module): ...@@ -25,7 +26,8 @@ class LlamaForCausalLM(nn.Module):
self.config = config self.config = config
self.linear_method = linear_method self.linear_method = linear_method
self.model = None self.model = None
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -45,13 +47,18 @@ class LlamaForCausalLM(nn.Module): ...@@ -45,13 +47,18 @@ class LlamaForCausalLM(nn.Module):
start_ids=seq_ids.flatten()) start_ids=seq_ids.flatten())
return logits return logits
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Run the logits processor with the checkpoint model's ``lm_head``.

    Args:
        hidden_states: Value forwarded to ``self.logits_processor``.
        sampling_metadata: Forwarded unchanged to the processor.

    Returns:
        The value returned by ``self.logits_processor``.
    """
    # NOTE(review): the method just before this one ends with
    # `return logits`; confirm what callers really pass as hidden_states.
    project = self.logits_processor
    checkpoint_head = self.model.chkpt_model.lm_head
    return project(checkpoint_head, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head, next_tokens = self.sampler(logits, sampling_metadata)
hidden_states, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -6,6 +6,7 @@ from torch import nn ...@@ -6,6 +6,7 @@ from torch import nn
from transformers import MistralConfig from transformers import MistralConfig
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput from vllm.sequence import SamplerOutput
...@@ -26,7 +27,8 @@ class MistralForCausalLM(nn.Module): ...@@ -26,7 +27,8 @@ class MistralForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = None self.model = None
self.lm_head = None self.lm_head = None
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -48,13 +50,18 @@ class MistralForCausalLM(nn.Module): ...@@ -48,13 +50,18 @@ class MistralForCausalLM(nn.Module):
start_ids=seq_ids) start_ids=seq_ids)
return logits return logits
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Invoke the logits processor with ``chkpt_model.lm_head``.

    Args:
        hidden_states: Value handed to ``self.logits_processor``.
        sampling_metadata: Passed through to the processor as-is.

    Returns:
        Whatever ``self.logits_processor`` yields for these inputs.
    """
    # NOTE(review): the method just before this one ends with
    # `return logits`; confirm what callers really pass as hidden_states.
    processor = self.logits_processor
    head = self.model.chkpt_model.lm_head
    return processor(head, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head, next_tokens = self.sampler(logits, sampling_metadata)
hidden_states, sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -51,6 +51,7 @@ from vllm.model_executor.layers.linear import ( ...@@ -51,6 +51,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -336,7 +337,8 @@ class OLMoForCausalLM(nn.Module): ...@@ -336,7 +337,8 @@ class OLMoForCausalLM(nn.Module):
self.lm_head_weight = (self.model.transformer.wte.weight self.lm_head_weight = (self.model.transformer.wte.weight
if config.weight_tying else if config.weight_tying else
self.model.transformer.ff_out.weight) self.model.transformer.ff_out.weight)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -353,13 +355,18 @@ class OLMoForCausalLM(nn.Module): ...@@ -353,13 +355,18 @@ class OLMoForCausalLM(nn.Module):
) )
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Map hidden states to logits with ``self.lm_head_weight``.

    Args:
        hidden_states: Tensor forwarded to the logits processor.
        sampling_metadata: Forwarded to the logits processor untouched.

    Returns:
        The logits processor's return value.
    """
    to_logits = self.logits_processor
    # lm_head_weight is an attribute of the model, assigned outside this
    # method.
    return to_logits(self.lm_head_weight, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights( def load_weights(
......
...@@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
...@@ -292,7 +293,8 @@ class OPTForCausalLM(nn.Module): ...@@ -292,7 +293,8 @@ class OPTForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = OPTModel(config, linear_method) self.model = OPTModel(config, linear_method)
self.lm_head_weight = self.model.decoder.embed_tokens.weight self.lm_head_weight = self.model.decoder.embed_tokens.weight
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -305,13 +307,18 @@ class OPTForCausalLM(nn.Module): ...@@ -305,13 +307,18 @@ class OPTForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Run the logits processor with the stored LM head weight.

    Args:
        hidden_states: Tensor handed to ``self.logits_processor``.
        sampling_metadata: Passed through to the processor as-is.

    Returns:
        Whatever ``self.logits_processor`` yields for these inputs.
    """
    processor = self.logits_processor
    head_weight = self.lm_head_weight
    return processor(head_weight, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase, ...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -256,7 +257,8 @@ class OrionForCausalLM(nn.Module): ...@@ -256,7 +257,8 @@ class OrionForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = OrionModel(config, linear_method) self.model = OrionModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -269,13 +271,18 @@ class OrionForCausalLM(nn.Module): ...@@ -269,13 +271,18 @@ class OrionForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Produce logits from hidden states using ``self.lm_head.weight``.

    Args:
        hidden_states: Tensor forwarded to ``self.logits_processor``.
        sampling_metadata: Forwarded unchanged to the processor.

    Returns:
        The value returned by ``self.logits_processor``.
    """
    project = self.logits_processor
    return project(self.lm_head.weight, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -240,7 +241,8 @@ class PhiForCausalLM(nn.Module): ...@@ -240,7 +241,8 @@ class PhiForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size, config.hidden_size,
bias=True) bias=True)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -254,14 +256,18 @@ class PhiForCausalLM(nn.Module): ...@@ -254,14 +256,18 @@ class PhiForCausalLM(nn.Module):
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Produce logits using both the LM head weight and its bias.

    Args:
        hidden_states: Tensor forwarded to ``self.logits_processor``.
        sampling_metadata: Forwarded unchanged to the processor.

    Returns:
        The value returned by ``self.logits_processor``.
    """
    project = self.logits_processor
    head = self.lm_head
    # The bias travels as the fourth positional argument, after the
    # sampling metadata.
    return project(head.weight, hidden_states, sampling_metadata, head.bias)
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
head = self.lm_head next_tokens = self.sampler(logits, sampling_metadata)
next_tokens = self.sampler(head.weight, hidden_states,
sampling_metadata, head.bias)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase, ...@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -230,7 +231,8 @@ class QWenLMHeadModel(nn.Module): ...@@ -230,7 +231,8 @@ class QWenLMHeadModel(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.transformer = QWenModel(config, linear_method) self.transformer = QWenModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -243,13 +245,18 @@ class QWenLMHeadModel(nn.Module): ...@@ -243,13 +245,18 @@ class QWenLMHeadModel(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Map hidden states to logits through ``self.logits_processor``.

    Args:
        hidden_states: Tensor forwarded to the logits processor.
        sampling_metadata: Forwarded to the logits processor untouched.

    Returns:
        The logits processor's return value.
    """
    to_logits = self.logits_processor
    # First positional argument is the LM head weight.
    return to_logits(self.lm_head.weight, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase, ...@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -300,11 +301,15 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -300,11 +301,15 @@ class Qwen2ForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = Qwen2Model(config, linear_method) self.model = Qwen2Model(config, linear_method)
if not config.tie_word_embeddings: if config.tie_word_embeddings:
self.lm_head_weight = self.model.embed_tokens.weight
else:
self.lm_head = ParallelLMHead(config.vocab_size, self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size) config.hidden_size)
self.lm_head_weight = self.lm_head.weight
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -317,17 +322,18 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -317,17 +322,18 @@ class Qwen2ForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Produce logits using ``self.lm_head_weight``.

    This method does not itself check ``tie_word_embeddings``; it only
    reads the stored ``self.lm_head_weight`` attribute.

    Args:
        hidden_states: Tensor forwarded to ``self.logits_processor``.
        sampling_metadata: Forwarded unchanged to the processor.

    Returns:
        The value returned by ``self.logits_processor``.
    """
    project = self.logits_processor
    return project(self.lm_head_weight, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
if self.config.tie_word_embeddings: next_tokens = self.sampler(logits, sampling_metadata)
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
next_tokens = self.sampler(lm_head_weight, hidden_states,
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase, ...@@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead) VocabParallelEmbedding, ParallelLMHead)
...@@ -238,7 +239,8 @@ class StablelmForCausalLM(nn.Module): ...@@ -238,7 +239,8 @@ class StablelmForCausalLM(nn.Module):
self.linear_method = linear_method self.linear_method = linear_method
self.model = StableLMEpochModel(config, linear_method) self.model = StableLMEpochModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -251,13 +253,18 @@ class StablelmForCausalLM(nn.Module): ...@@ -251,13 +253,18 @@ class StablelmForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Run the logits processor over ``hidden_states``.

    Args:
        hidden_states: Tensor handed to ``self.logits_processor``.
        sampling_metadata: Passed through to the processor as-is.

    Returns:
        Whatever ``self.logits_processor`` yields for these inputs.
    """
    processor = self.logits_processor
    head_weight = self.lm_head.weight
    return processor(head_weight, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: torch.Tensor, logits: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase, LinearMethodBase,
QKVParallelLinear, QKVParallelLinear,
RowParallelLinear) RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
...@@ -254,7 +255,9 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -254,7 +255,9 @@ class Starcoder2ForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE, padding_size=DEFAULT_VOCAB_PADDING_SIZE,
) )
self.lm_head_weight = self.lm_head.weight self.lm_head_weight = self.lm_head.weight
self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -267,13 +270,18 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -267,13 +270,18 @@ class Starcoder2ForCausalLM(nn.Module):
input_metadata) input_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Map hidden states to logits with the stored ``lm_head_weight``.

    Args:
        hidden_states: Tensor forwarded to the logits processor.
        sampling_metadata: Forwarded to the logits processor untouched.

    Returns:
        The logits processor's return value.
    """
    to_logits = self.logits_processor
    # lm_head_weight is an attribute of the model, assigned outside this
    # method.
    return to_logits(self.lm_head_weight, hidden_states, sampling_metadata)
def sample( def sample(
self, self,
hidden_states: Optional[torch.Tensor], logits: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head_weight, hidden_states, next_tokens = self.sampler(logits, sampling_metadata)
sampling_metadata)
return next_tokens return next_tokens
def load_weights(self, def load_weights(self,
......
...@@ -613,9 +613,16 @@ class ModelRunner: ...@@ -613,9 +613,16 @@ class ModelRunner:
input_metadata=input_metadata, input_metadata=input_metadata,
) )
# Compute the logits.
logits = self.model.compute_logits(hidden_states, sampling_metadata)
# Only perform sampling in the driver worker.
if not sampling_metadata.perform_sampling:
return None
# Sample the next token. # Sample the next token.
output = self.model.sample( output = self.model.sample(
hidden_states=hidden_states, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
) )
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment