"docs/vscode:/vscode.git/clone" did not exist on "41c80698b3849969dcb5c5e40d0991b0eb4821cc"
Unverified commit 9f1c4eca, authored by Cyrus Leung, committed by GitHub
Browse files

[Bugfix] Token type and position embeddings fail to be applied to `inputs_embeds` (#25922)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent ef283548
......@@ -61,11 +61,13 @@ class BertEmbedding(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
token_type_ids = _decode_token_type_ids(input_ids)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
......@@ -358,11 +360,12 @@ class BertModel(nn.Module, SupportsQuant):
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embeddings(input_ids=input_ids,
position_ids=positions)
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=positions,
inputs_embeds=inputs_embeds,
)
return self.encoder(hidden_states)
def _load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
......
......@@ -56,11 +56,13 @@ class RobertaEmbedding(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
token_type_ids = _decode_token_type_ids(input_ids)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment