Unverified commit 7025b11d, authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Fix weight loading for Chameleon when TP>1 (#7410)

parent 5469146b
...@@ -395,8 +395,11 @@ class DeepseekForCausalLM(nn.Module): ...@@ -395,8 +395,11 @@ class DeepseekForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -505,8 +505,11 @@ class DeepseekV2ForCausalLM(nn.Module): ...@@ -505,8 +505,11 @@ class DeepseekV2ForCausalLM(nn.Module):
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -420,8 +420,11 @@ class FalconForCausalLM(nn.Module): ...@@ -420,8 +420,11 @@ class FalconForCausalLM(nn.Module):
) )
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -287,8 +287,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision): ...@@ -287,8 +287,11 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
) )
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.language_model.logits_processor( logits = self.language_model.logits_processor(
self.language_model.lm_head, hidden_states, sampling_metadata) self.language_model.lm_head, hidden_states, sampling_metadata)
return logits return logits
......
...@@ -352,8 +352,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA): ...@@ -352,8 +352,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states, logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -343,8 +343,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -343,8 +343,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.model.embed_tokens, hidden_states, logits = self.logits_processor(self.model.embed_tokens, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -265,8 +265,11 @@ class GPT2LMHeadModel(nn.Module): ...@@ -265,8 +265,11 @@ class GPT2LMHeadModel(nn.Module):
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -279,8 +279,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): ...@@ -279,8 +279,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -246,8 +246,11 @@ class GPTJForCausalLM(nn.Module): ...@@ -246,8 +246,11 @@ class GPTJForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata, self.lm_head.bias) sampling_metadata, self.lm_head.bias)
return logits return logits
......
...@@ -258,8 +258,11 @@ class GPTNeoXForCausalLM(nn.Module): ...@@ -258,8 +258,11 @@ class GPTNeoXForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.embed_out, hidden_states, logits = self.logits_processor(self.embed_out, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -279,8 +279,11 @@ class InternLM2ForCausalLM(nn.Module): ...@@ -279,8 +279,11 @@ class InternLM2ForCausalLM(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.output, hidden_states, logits = self.logits_processor(self.output, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -466,8 +466,11 @@ class InternVLChatModel(nn.Module, SupportsVision): ...@@ -466,8 +466,11 @@ class InternVLChatModel(nn.Module, SupportsVision):
inputs_embeds=inputs_embeds) inputs_embeds=inputs_embeds)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
......
...@@ -295,8 +295,11 @@ class JAISLMHeadModel(nn.Module): ...@@ -295,8 +295,11 @@ class JAISLMHeadModel(nn.Module):
attn_metadata) attn_metadata)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -861,8 +861,11 @@ class JambaForCausalLM(nn.Module, HasInnerState): ...@@ -861,8 +861,11 @@ class JambaForCausalLM(nn.Module, HasInnerState):
dtype=dtype, dtype=dtype,
device="cuda")) device="cuda"))
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -430,8 +430,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -430,8 +430,11 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return model_output return model_output
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
...@@ -355,8 +355,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision): ...@@ -355,8 +355,11 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
......
...@@ -588,8 +588,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision): ...@@ -588,8 +588,11 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
return self.language_model.compute_logits(hidden_states, return self.language_model.compute_logits(hidden_states,
sampling_metadata) sampling_metadata)
......
...@@ -65,22 +65,28 @@ class Medusa(nn.Module): ...@@ -65,22 +65,28 @@ class Medusa(nn.Module):
def compute_logits( def compute_logits(
self, hidden_states: List[torch.Tensor], self, hidden_states: List[torch.Tensor],
sampling_metadata: SamplingMetadata) -> List[torch.Tensor]: sampling_metadata: SamplingMetadata) -> List[torch.Tensor]:
logits = [] logits_lst: List[torch.Tensor] = []
for hs, lm_head in zip(hidden_states, self.lm_heads): for hs, lm_head in zip(hidden_states, self.lm_heads):
_logits = self.logits_processor(lm_head, hs, sampling_metadata) _logits = self.logits_processor(lm_head, hs, sampling_metadata)
if _logits is None:
# _logits should only be None on rank > 0, in which case
# it should remain true for every lm_head
assert len(logits_lst) == 0
continue
if self.token_map is None: if self.token_map is None:
logits.append(_logits) logits_lst.append(_logits)
else: else:
logits.append(-torch.inf * torch.ones( logits_lst.append(-torch.inf * torch.ones(
size=(*_logits.shape[:-1], self.orig_vocab_size), size=(*_logits.shape[:-1], self.orig_vocab_size),
device=_logits.device, device=_logits.device,
dtype=_logits.dtype)) dtype=_logits.dtype))
logits[-1][..., self.token_map] = _logits logits_lst[-1][..., self.token_map] = _logits
return logits return logits_lst
def sample( def sample(
self, self,
......
...@@ -470,8 +470,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA): ...@@ -470,8 +470,11 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA):
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
hidden_states = hidden_states / self.scale_width hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings: if self.config.tie_word_embeddings:
lm_head = self.model.embed_tokens lm_head = self.model.embed_tokens
......
...@@ -630,8 +630,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision): ...@@ -630,8 +630,11 @@ class MiniCPMVBaseModel(nn.Module, SupportsVision):
) )
return output return output
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(
sampling_metadata: SamplingMetadata) -> torch.Tensor: self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[torch.Tensor]:
logits = self.logits_processor(self.lm_head, hidden_states, logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata) sampling_metadata)
return logits return logits
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment