Unverified Commit 64fe3115 authored by Geary.Z's avatar Geary.Z Committed by GitHub
Browse files

replace skip_embed with input_embeds (#222)

parent a7ace9c8
...@@ -227,12 +227,12 @@ class LlamaModel(nn.Module): ...@@ -227,12 +227,12 @@ class LlamaModel(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
skip_embed: bool = False, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if not skip_embed: if input_embeds is None:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
else: else:
hidden_states = input_ids hidden_states = input_embeds
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
...@@ -264,9 +264,9 @@ class LlamaForCausalLM(nn.Module): ...@@ -264,9 +264,9 @@ class LlamaForCausalLM(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
skip_embed: bool = False, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
......
...@@ -230,11 +230,11 @@ class LlavaLlamaForCausalLM(nn.Module): ...@@ -230,11 +230,11 @@ class LlavaLlamaForCausalLM(nn.Module):
pt += 1 pt += 1
return self.language_model( return self.language_model(
input_embeds, positions, input_metadata, skip_embed=True input_ids, positions, input_metadata, input_embeds=input_embeds
) )
elif input_metadata.forward_mode == ForwardMode.DECODE: elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model( return self.language_model(
input_ids, positions, input_metadata, skip_embed=False input_ids, positions, input_metadata
) )
def load_weights( def load_weights(
......
...@@ -296,12 +296,12 @@ class MixtralModel(nn.Module): ...@@ -296,12 +296,12 @@ class MixtralModel(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
skip_embed: bool = False, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if not skip_embed: if input_embeds is None:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
else: else:
hidden_states = input_ids hidden_states = input_embeds
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
...@@ -330,9 +330,9 @@ class MixtralForCausalLM(nn.Module): ...@@ -330,9 +330,9 @@ class MixtralForCausalLM(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
skip_embed: bool = False, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
......
...@@ -228,12 +228,12 @@ class Qwen2Model(nn.Module): ...@@ -228,12 +228,12 @@ class Qwen2Model(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
skip_embed: bool = False, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if not skip_embed: if input_embeds is None:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
else: else:
hidden_states = input_ids hidden_states = input_embeds
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
...@@ -265,9 +265,9 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -265,9 +265,9 @@ class Qwen2ForCausalLM(nn.Module):
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, input_metadata: InputMetadata,
skip_embed: bool = False, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, skip_embed) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment