Unverified commit 9f1c4eca, authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Token type and position embeddings fail to be applied to `inputs_embeds` (#25922)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent ef283548
...@@ -61,11 +61,13 @@ class BertEmbedding(nn.Module): ...@@ -61,11 +61,13 @@ class BertEmbedding(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
token_type_ids = _decode_token_type_ids(input_ids) token_type_ids = _decode_token_type_ids(input_ids)
inputs_embeds = self.word_embeddings(input_ids) if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids)
...@@ -358,11 +360,12 @@ class BertModel(nn.Module, SupportsQuant): ...@@ -358,11 +360,12 @@ class BertModel(nn.Module, SupportsQuant):
intermediate_tensors: Optional[IntermediateTensors] = None, intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if inputs_embeds is not None: hidden_states = self.embeddings(
hidden_states = inputs_embeds input_ids=input_ids,
else: position_ids=positions,
hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds,
position_ids=positions) )
return self.encoder(hidden_states) return self.encoder(hidden_states)
def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
...@@ -56,11 +56,13 @@ class RobertaEmbedding(nn.Module): ...@@ -56,11 +56,13 @@ class RobertaEmbedding(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
position_ids: torch.Tensor, position_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
token_type_ids = _decode_token_type_ids(input_ids) token_type_ids = _decode_token_type_ids(input_ids)
inputs_embeds = self.word_embeddings(input_ids) if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids) position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids) token_type_embeddings = self.token_type_embeddings(token_type_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment