Unverified commit 643ecf7b, authored by Roger Wang, committed by GitHub
Browse files

[V1] Refactor model executable interface for all text-only language models (#10374)


Signed-off-by: Roger Wang <ywang@roblox.com>
parent 4fd93750
...@@ -389,6 +389,9 @@ class ArcticModel(nn.Module): ...@@ -389,6 +389,9 @@ class ArcticModel(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size)) config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed ``input_ids`` with this model's ``embed_tokens`` layer.

    Args:
        input_ids: Token ids to look up.

    Returns:
        Whatever ``self.embed_tokens`` produces for ``input_ids``.
    """
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -396,9 +399,13 @@ class ArcticModel(nn.Module): ...@@ -396,9 +399,13 @@ class ArcticModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -439,6 +446,9 @@ class ArcticForCausalLM(nn.Module, SupportsPP): ...@@ -439,6 +446,9 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate the embedding lookup to the wrapped ``self.model``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.model.get_input_embeddings(input_ids)``.
    """
    backbone = self.model
    return backbone.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -446,9 +456,11 @@ class ArcticForCausalLM(nn.Module, SupportsPP): ...@@ -446,9 +456,11 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -284,6 +284,9 @@ class BaiChuanModel(nn.Module): ...@@ -284,6 +284,9 @@ class BaiChuanModel(nn.Module):
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up ``input_ids`` in the ``embed_tokens`` layer.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The output of ``self.embed_tokens(input_ids)``.
    """
    embedded = self.embed_tokens(input_ids)
    return embedded
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -291,9 +294,13 @@ class BaiChuanModel(nn.Module): ...@@ -291,9 +294,13 @@ class BaiChuanModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -363,6 +370,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -363,6 +370,9 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the embedding request to ``self.model``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.model.get_input_embeddings(input_ids)``.
    """
    inner_model = self.model
    return inner_model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -370,9 +380,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -370,9 +380,11 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -251,6 +251,9 @@ class BloomModel(nn.Module): ...@@ -251,6 +251,9 @@ class BloomModel(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size)) config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed ``input_ids`` and apply the embedding layer norm.

    The lookup in ``word_embeddings`` is followed by
    ``word_embeddings_layernorm``, so the returned value is already
    normalized.

    Args:
        input_ids: Token ids to embed.

    Returns:
        ``word_embeddings_layernorm(word_embeddings(input_ids))``.
    """
    raw_embeddings = self.word_embeddings(input_ids)
    normalized = self.word_embeddings_layernorm(raw_embeddings)
    return normalized
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -258,10 +261,13 @@ class BloomModel(nn.Module): ...@@ -258,10 +261,13 @@ class BloomModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.word_embeddings(input_ids) if inputs_embeds is not None:
hidden_states = self.word_embeddings_layernorm(hidden_states) hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -301,6 +307,9 @@ class BloomForCausalLM(nn.Module, SupportsPP): ...@@ -301,6 +307,9 @@ class BloomForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate the embedding lookup to ``self.transformer``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.transformer.get_input_embeddings(input_ids)``.
    """
    backbone = self.transformer
    return backbone.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -308,9 +317,11 @@ class BloomForCausalLM(nn.Module, SupportsPP): ...@@ -308,9 +317,11 @@ class BloomForCausalLM(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -280,6 +280,9 @@ class CohereModel(nn.Module): ...@@ -280,6 +280,9 @@ class CohereModel(nn.Module):
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the ``embed_tokens`` embeddings for ``input_ids``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The output of ``self.embed_tokens(input_ids)``.
    """
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -287,9 +290,13 @@ class CohereModel(nn.Module): ...@@ -287,9 +290,13 @@ class CohereModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -354,6 +361,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -354,6 +361,9 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Pass the embedding lookup through to ``self.model``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.model.get_input_embeddings(input_ids)``.
    """
    wrapped = self.model
    return wrapped.get_input_embeddings(input_ids)
@torch.no_grad() @torch.no_grad()
def forward( def forward(
self, self,
...@@ -362,9 +372,11 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -362,9 +372,11 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -321,6 +321,9 @@ class DbrxModel(nn.Module): ...@@ -321,6 +321,9 @@ class DbrxModel(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.d_model)) config.d_model))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed ``input_ids`` using the ``wte`` layer.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The output of ``self.wte(input_ids)``.
    """
    token_embeddings = self.wte(input_ids)
    return token_embeddings
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -328,9 +331,13 @@ class DbrxModel(nn.Module): ...@@ -328,9 +331,13 @@ class DbrxModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids) if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else: else:
assert intermediate_tensors assert intermediate_tensors
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -376,6 +383,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -376,6 +383,9 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the embedding request to ``self.transformer``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.transformer.get_input_embeddings(input_ids)``.
    """
    inner_transformer = self.transformer
    return inner_transformer.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -383,9 +393,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ...@@ -383,9 +393,11 @@ class DbrxForCausalLM(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -353,6 +353,9 @@ class DeepseekModel(nn.Module): ...@@ -353,6 +353,9 @@ class DeepseekModel(nn.Module):
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Look up ``input_ids`` in ``embed_tokens``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The output of ``self.embed_tokens(input_ids)``.
    """
    embedded = self.embed_tokens(input_ids)
    return embedded
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -360,9 +363,13 @@ class DeepseekModel(nn.Module): ...@@ -360,9 +363,13 @@ class DeepseekModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None residual = None
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -401,6 +408,9 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): ...@@ -401,6 +408,9 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate the embedding lookup to ``self.model``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.model.get_input_embeddings(input_ids)``.
    """
    backbone = self.model
    return backbone.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -408,9 +418,11 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): ...@@ -408,9 +418,11 @@ class DeepseekForCausalLM(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -445,6 +445,9 @@ class DeepseekV2Model(nn.Module): ...@@ -445,6 +445,9 @@ class DeepseekV2Model(nn.Module):
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the ``embed_tokens`` embeddings for ``input_ids``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The output of ``self.embed_tokens(input_ids)``.
    """
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -452,9 +455,13 @@ class DeepseekV2Model(nn.Module): ...@@ -452,9 +455,13 @@ class DeepseekV2Model(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -495,6 +502,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -495,6 +502,9 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Pass the embedding lookup through to ``self.model``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.model.get_input_embeddings(input_ids)``.
    """
    inner_model = self.model
    return inner_model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -502,9 +512,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): ...@@ -502,9 +512,11 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -78,6 +78,9 @@ class EAGLE(nn.Module): ...@@ -78,6 +78,9 @@ class EAGLE(nn.Module):
def sampler(self): def sampler(self):
return self.model.sampler return self.model.sampler
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate the embedding lookup to the nested ``self.model.model``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.model.model.get_input_embeddings(input_ids)``.
    """
    # The embedding method is reached two levels down, via self.model.model.
    inner_backbone = self.model.model
    return inner_backbone.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -86,11 +89,14 @@ class EAGLE(nn.Module): ...@@ -86,11 +89,14 @@ class EAGLE(nn.Module):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
previous_hidden_states: torch.Tensor, previous_hidden_states: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
tok_embeds = self.model.model.embed_tokens(input_ids) if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
inputs_embeds = self.fc( inputs_embeds = self.fc(
torch.cat([tok_embeds, previous_hidden_states], dim=-1)) torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
inputs_embeds[positions == 0] = 0 # masking inputs at position=0 inputs_embeds[positions == 0] = 0 # masking inputs at position=0
...@@ -100,7 +106,8 @@ class EAGLE(nn.Module): ...@@ -100,7 +106,8 @@ class EAGLE(nn.Module):
positions=positions, positions=positions,
kv_caches=kv_caches, kv_caches=kv_caches,
attn_metadata=attn_metadata, attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors) intermediate_tensors=intermediate_tensors,
)
return hidden_states return hidden_states
def compute_logits(self, hidden_states: torch.Tensor, def compute_logits(self, hidden_states: torch.Tensor,
......
...@@ -479,6 +479,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -479,6 +479,9 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate the embedding lookup to ``self.transformer``.

    This class reaches its backbone through ``self.transformer`` (used for
    ``make_empty_intermediate_tensors`` and in ``forward``), so the
    embedding lookup goes through the same attribute. Delegating to
    ``self.model`` would raise ``AttributeError`` because no such attribute
    is referenced anywhere else in the class.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.transformer.get_input_embeddings(input_ids)``.
    """
    return self.transformer.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -486,9 +489,11 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -486,9 +489,11 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.transformer(input_ids, positions, kv_caches, model_output = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return model_output return model_output
def compute_logits( def compute_logits(
......
...@@ -367,6 +367,9 @@ class FalconModel(nn.Module): ...@@ -367,6 +367,9 @@ class FalconModel(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size)) config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed ``input_ids`` using the ``word_embeddings`` layer.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The output of ``self.word_embeddings(input_ids)``.
    """
    embedded = self.word_embeddings(input_ids)
    return embedded
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -374,9 +377,13 @@ class FalconModel(nn.Module): ...@@ -374,9 +377,13 @@ class FalconModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.word_embeddings(input_ids) if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
...@@ -432,6 +439,9 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -432,6 +439,9 @@ class FalconForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate the embedding lookup to ``self.transformer``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.transformer.get_input_embeddings(input_ids)``.
    """
    backbone = self.transformer
    return backbone.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.LongTensor, input_ids: torch.LongTensor,
...@@ -439,9 +449,11 @@ class FalconForCausalLM(nn.Module, SupportsPP): ...@@ -439,9 +449,11 @@ class FalconForCausalLM(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -390,6 +390,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -390,6 +390,9 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the embedding request to ``self.model``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.model.get_input_embeddings(input_ids)``.
    """
    wrapped = self.model
    return wrapped.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -397,9 +400,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -397,9 +400,11 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -272,6 +272,9 @@ class Gemma2Model(nn.Module): ...@@ -272,6 +272,9 @@ class Gemma2Model(nn.Module):
make_empty_intermediate_tensors_factory( make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size)) ["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the raw ``embed_tokens`` embeddings for ``input_ids``.

    No scaling happens here: ``forward`` multiplies the hidden states by
    ``self.normalizer`` after obtaining the embeddings.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The output of ``self.embed_tokens(input_ids)``.
    """
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
def forward( def forward(
self, self,
input_ids: Optional[torch.Tensor], input_ids: Optional[torch.Tensor],
...@@ -285,7 +288,7 @@ class Gemma2Model(nn.Module): ...@@ -285,7 +288,7 @@ class Gemma2Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.normalizer hidden_states *= self.normalizer
residual = None residual = None
else: else:
...@@ -414,6 +417,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -414,6 +417,9 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate the embedding lookup to ``self.model``.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The result of ``self.model.get_input_embeddings(input_ids)``.
    """
    backbone = self.model
    return backbone.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -421,9 +427,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -421,9 +427,11 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -209,6 +209,9 @@ class GPT2Model(nn.Module): ...@@ -209,6 +209,9 @@ class GPT2Model(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd)) config.n_embd))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed ``input_ids`` using the ``wte`` layer only.

    Position embeddings are not included; ``forward`` adds
    ``self.wpe(position_ids)`` separately.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The output of ``self.wte(input_ids)``.
    """
    token_embeddings = self.wte(input_ids)
    return token_embeddings
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -220,7 +223,7 @@ class GPT2Model(nn.Module): ...@@ -220,7 +223,7 @@ class GPT2Model(nn.Module):
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.wte(input_ids) inputs_embeds = self.get_input_embeddings(input_ids)
position_embeds = self.wpe(position_ids) position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds hidden_states = inputs_embeds + position_embeds
else: else:
...@@ -262,7 +265,7 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): ...@@ -262,7 +265,7 @@ class GPT2LMHeadModel(nn.Module, SupportsPP):
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.transformer.wte(input_ids) return self.transformer.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
......
...@@ -218,6 +218,9 @@ class GPTBigCodeModel(nn.Module): ...@@ -218,6 +218,9 @@ class GPTBigCodeModel(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd)) config.n_embd))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Embed ``input_ids`` using the ``wte`` layer only.

    Position embeddings are not included; ``forward`` adds
    ``self.wpe(position_ids)`` to the result.

    Args:
        input_ids: Token ids to embed.

    Returns:
        The output of ``self.wte(input_ids)``.
    """
    embedded = self.wte(input_ids)
    return embedded
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -225,11 +228,12 @@ class GPTBigCodeModel(nn.Module): ...@@ -225,11 +228,12 @@ class GPTBigCodeModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids) if inputs_embeds is None:
position_embeds = self.wpe(position_ids) inputs_embeds = self.get_input_embeddings(input_ids)
hidden_states = inputs_embeds + position_embeds hidden_states = inputs_embeds + self.wpe(position_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
...@@ -285,6 +289,9 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -285,6 +289,9 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate the embedding lookup to the wrapped ``self.transformer``.

    ``input_ids`` is forwarded untouched and the inner model's result is
    returned as-is.
    """
    backbone = self.transformer
    token_embeds = backbone.get_input_embeddings(input_ids)
    return token_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -292,9 +299,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -292,9 +299,11 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -201,6 +201,9 @@ class GPTJModel(nn.Module): ...@@ -201,6 +201,9 @@ class GPTJModel(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd)) config.n_embd))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Run ``input_ids`` through ``self.wte`` and return the result.

    ``forward`` uses this on the first pipeline rank when the caller did not
    pass ``inputs_embeds``.
    """
    lookup = self.wte
    embedded = lookup(input_ids)
    return embedded
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -208,9 +211,13 @@ class GPTJModel(nn.Module): ...@@ -208,9 +211,13 @@ class GPTJModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.wte(input_ids) if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
...@@ -250,6 +257,9 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -250,6 +257,9 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return ``self.transformer.get_input_embeddings(input_ids)``.

    This wrapper adds no processing of its own; it exposes the inner model's
    embedding accessor.
    """
    inner_model = self.transformer
    return inner_model.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -257,9 +267,11 @@ class GPTJForCausalLM(nn.Module, SupportsPP): ...@@ -257,9 +267,11 @@ class GPTJForCausalLM(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -214,6 +214,9 @@ class GPTNeoXModel(nn.Module): ...@@ -214,6 +214,9 @@ class GPTNeoXModel(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size)) config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Apply the ``embed_in`` layer to ``input_ids``.

    ``forward`` falls back to this on the first pipeline rank when no
    ``inputs_embeds`` are supplied.
    """
    embed_layer = self.embed_in
    return embed_layer(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -221,9 +224,13 @@ class GPTNeoXModel(nn.Module): ...@@ -221,9 +224,13 @@ class GPTNeoXModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_in(input_ids) if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
else: else:
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer): for i in range(self.start_layer, self.end_layer):
...@@ -262,6 +269,9 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -262,6 +269,9 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.gpt_neox.make_empty_intermediate_tensors) self.gpt_neox.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Forward the embedding lookup to ``self.gpt_neox``.

    Returns exactly what ``self.gpt_neox.get_input_embeddings`` produces for
    the given ``input_ids``.
    """
    core = self.gpt_neox
    result = core.get_input_embeddings(input_ids)
    return result
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -269,9 +279,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): ...@@ -269,9 +279,11 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.gpt_neox(input_ids, positions, kv_caches, hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -409,6 +409,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -409,6 +409,9 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
else: else:
self.lm_head = PPMissingLayer() self.lm_head = PPMissingLayer()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Expose the embedding accessor of the wrapped ``self.model``.

    ``input_ids`` is handed to ``self.model.get_input_embeddings`` and the
    result is returned unmodified.
    """
    wrapped = self.model
    return wrapped.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -416,9 +419,11 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -416,9 +419,11 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]: ) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(input_ids, positions, kv_caches, model_output = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return model_output return model_output
def compute_logits( def compute_logits(
......
...@@ -277,6 +277,9 @@ class GraniteMoeModel(nn.Module): ...@@ -277,6 +277,9 @@ class GraniteMoeModel(nn.Module):
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Apply ``self.embed_tokens`` to ``input_ids`` and return the output.

    No scaling happens here; ``forward`` applies ``self.embedding_multiplier``
    afterwards.
    """
    token_embedder = self.embed_tokens
    raw_embeds = token_embedder(input_ids)
    return raw_embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -284,9 +287,13 @@ class GraniteMoeModel(nn.Module): ...@@ -284,9 +287,13 @@ class GraniteMoeModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids) if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
hidden_states *= self.embedding_multiplier hidden_states *= self.embedding_multiplier
residual = None residual = None
else: else:
...@@ -366,6 +373,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -366,6 +373,9 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.sampler = get_sampler() self.sampler = get_sampler()
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Delegate to ``self.model.get_input_embeddings``.

    Thin pass-through: nothing is added to or removed from the inner model's
    result.
    """
    base_model = self.model
    embeds = base_model.get_input_embeddings(input_ids)
    return embeds
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -373,9 +383,11 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -373,9 +383,11 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -290,7 +290,7 @@ class InternLM2Model(nn.Module): ...@@ -290,7 +290,7 @@ class InternLM2Model(nn.Module):
if inputs_embeds is not None: if inputs_embeds is not None:
hidden_states = inputs_embeds hidden_states = inputs_embeds
else: else:
hidden_states = self.tok_embeddings(input_ids) hidden_states = self.get_input_embeddings(input_ids)
residual = None residual = None
else: else:
assert intermediate_tensors is not None assert intermediate_tensors is not None
...@@ -335,6 +335,9 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP): ...@@ -335,6 +335,9 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the inner ``self.model``'s embeddings for ``input_ids``.

    The lookup itself is implemented by ``self.model.get_input_embeddings``;
    this method only forwards the call.
    """
    delegate = self.model
    return delegate.get_input_embeddings(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -342,9 +345,11 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP): ...@@ -342,9 +345,11 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors], intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
...@@ -250,6 +250,9 @@ class JAISModel(nn.Module): ...@@ -250,6 +250,9 @@ class JAISModel(nn.Module):
make_empty_intermediate_tensors_factory(["hidden_states"], make_empty_intermediate_tensors_factory(["hidden_states"],
config.n_embd)) config.n_embd))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Apply the ``wte`` layer to ``input_ids``.

    Position embeddings are not included; ``forward`` adds them separately
    when ``self.wpe`` is not ``None``.
    """
    word_embedding = self.wte
    return word_embedding(input_ids)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -257,9 +260,11 @@ class JAISModel(nn.Module): ...@@ -257,9 +260,11 @@ class JAISModel(nn.Module):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[IntermediateTensors, torch.Tensor]: ) -> Union[IntermediateTensors, torch.Tensor]:
if get_pp_group().is_first_rank: if get_pp_group().is_first_rank:
inputs_embeds = self.wte(input_ids) if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings(input_ids)
if self.wpe is not None: if self.wpe is not None:
position_embeds = self.wpe(position_ids) position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds hidden_states = inputs_embeds + position_embeds
...@@ -311,6 +316,9 @@ class JAISLMHeadModel(nn.Module, SupportsPP): ...@@ -311,6 +316,9 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.transformer.make_empty_intermediate_tensors) self.transformer.make_empty_intermediate_tensors)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Pass ``input_ids`` to ``self.transformer.get_input_embeddings``.

    The inner model's return value is handed back to the caller unchanged.
    """
    transformer = self.transformer
    output = transformer.get_input_embeddings(input_ids)
    return output
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -318,9 +326,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP): ...@@ -318,9 +326,11 @@ class JAISLMHeadModel(nn.Module, SupportsPP):
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[IntermediateTensors, torch.Tensor]: ) -> Union[IntermediateTensors, torch.Tensor]:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.transformer(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors) attn_metadata, intermediate_tensors,
inputs_embeds)
return hidden_states return hidden_states
def compute_logits( def compute_logits(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment