Unverified Commit 7025b11d authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)

parent 5469146b
...@@ -375,8 +375,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA): ...@@ -375,8 +375,11 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -362,8 +362,11 @@ class MixtralForCausalLM(nn.Module): ...@@ -362,8 +362,11 @@ class MixtralForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -279,8 +279,11 @@ class MPTForCausalLM(nn.Module): ...@@ -279,8 +279,11 @@ class MPTForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -453,8 +453,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA): ...@@ -453,8 +453,11 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return model_output return model_output
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -311,8 +311,11 @@ class OlmoForCausalLM(nn.Module): ...@@ -311,8 +311,11 @@ class OlmoForCausalLM(nn.Module):
) )
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -323,8 +323,11 @@ class OPTForCausalLM(nn.Module): ...@@ -323,8 +323,11 @@ class OPTForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -277,8 +277,11 @@ class OrionForCausalLM(nn.Module): ...@@ -277,8 +277,11 @@ class OrionForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -262,8 +262,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -262,8 +262,11 @@ class PaliGemmaForConditionalGeneration(nn.Module, SupportsVision):
return hidden_states return hidden_states
# Copied from vllm/model_executor/models/gemma.py # Copied from vllm/model_executor/models/gemma.py
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.language_model.embed_tokens, logits = self.logits_processor(self.language_model.embed_tokens,
hidden_states, sampling_metadata) hidden_states, sampling_metadata)
return logits return logits
......
...@@ -285,8 +285,11 @@ class PersimmonForCausalLM(nn.Module): ...@@ -285,8 +285,11 @@ class PersimmonForCausalLM(nn.Module):
) )
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -286,8 +286,11 @@ class PhiForCausalLM(nn.Module, SupportsLoRA): ...@@ -286,8 +286,11 @@ class PhiForCausalLM(nn.Module, SupportsLoRA):
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias) sampling_metadata, self.lm_head.bias)
return logits return logits
......
...@@ -399,8 +399,11 @@ class Phi3SmallForCausalLM(nn.Module): ...@@ -399,8 +399,11 @@ class Phi3SmallForCausalLM(nn.Module):
def get_decoder(self): def get_decoder(self):
return self.model return self.model
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
if self.dummy_token_indices is not None and logits is not None: if self.dummy_token_indices is not None and logits is not None:
......
...@@ -584,8 +584,11 @@ class Phi3VForCausalLM(nn.Module, SupportsVision): ...@@ -584,8 +584,11 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -281,8 +281,11 @@ class QWenLMHeadModel(nn.Module): ...@@ -281,8 +281,11 @@ class QWenLMHeadModel(nn.Module):
device=device), device=device),
}) })
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -362,8 +362,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -362,8 +362,11 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -400,8 +400,11 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -400,8 +400,11 @@ class Qwen2MoeForCausalLM(nn.Module):
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -258,8 +258,11 @@ class StablelmForCausalLM(nn.Module): ...@@ -258,8 +258,11 @@ class StablelmForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -268,8 +268,11 @@ class Starcoder2ForCausalLM(nn.Module): ...@@ -268,8 +268,11 @@ class Starcoder2ForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -328,8 +328,11 @@ class XverseForCausalLM(nn.Module, SupportsLoRA): ...@@ -328,8 +328,11 @@ class XverseForCausalLM(nn.Module, SupportsLoRA):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple, Union from typing import List, Optional
from typing import Sequence as GenericSequence
from typing import Union
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
...@@ -28,7 +30,7 @@ class CompletionOutput: ...@@ -28,7 +30,7 @@ class CompletionOutput:
index: int index: int
text: str text: str
token_ids: Tuple[int, ...] token_ids: GenericSequence[int]
cumulative_logprob: Optional[float] cumulative_logprob: Optional[float]
logprobs: Optional[SampleLogprobs] logprobs: Optional[SampleLogprobs]
finish_reason: Optional[str] = None finish_reason: Optional[str] = None
...@@ -139,7 +141,7 @@ class RequestOutput: ...@@ -139,7 +141,7 @@ class RequestOutput:
CompletionOutput( CompletionOutput(
seqs.index(seq), seqs.index(seq),
seq.get_output_text_to_return(text_buffer_length), seq.get_output_text_to_return(text_buffer_length),
seq.data._output_token_ids, # type: ignore seq.data._output_token_ids,
seq.get_cumulative_logprob() if include_logprobs else None, seq.get_cumulative_logprob() if include_logprobs else None,
seq.output_logprobs if include_logprobs else None, seq.output_logprobs if include_logprobs else None,
SequenceStatus.get_finished_reason(seq.status), SequenceStatus.get_finished_reason(seq.status),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment