Unverified commit 6f15ac5d, authored by Harry Mellor, committed by GitHub
Browse files

Don't assume `position_embedding_type` will be present for BERT and RoBERTa models (#30770)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 676db55e
...@@ -55,7 +55,9 @@ class BertEmbedding(nn.Module): ...@@ -55,7 +55,9 @@ class BertEmbedding(nn.Module):
"position_ids", "position_ids",
torch.arange(config.max_position_embeddings).unsqueeze(0), torch.arange(config.max_position_embeddings).unsqueeze(0),
) )
self.position_embedding_type = config.position_embedding_type self.position_embedding_type = getattr(
config, "position_embedding_type", "absolute"
)
if self.position_embedding_type != "absolute": if self.position_embedding_type != "absolute":
raise ValueError( raise ValueError(
"Only 'absolute' position_embedding_type" + " is supported" "Only 'absolute' position_embedding_type" + " is supported"
......
...@@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module): ...@@ -57,12 +57,6 @@ class RobertaEmbedding(nn.Module):
torch.arange(config.max_position_embeddings).unsqueeze(0), torch.arange(config.max_position_embeddings).unsqueeze(0),
) )
self.position_embedding_type = config.position_embedding_type
if self.position_embedding_type != "absolute":
raise ValueError(
"Only 'absolute' position_embedding_type" + " is supported"
)
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
...@@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel): ...@@ -135,12 +129,12 @@ class RobertaEmbeddingModel(BertEmbeddingModel):
def _build_model( def _build_model(
self, vllm_config: VllmConfig, prefix: str = "" self, vllm_config: VllmConfig, prefix: str = ""
) -> BertModel | BertWithRope: ) -> BertModel | BertWithRope:
if vllm_config.model_config.hf_config.position_embedding_type == "rotary": hf_config = vllm_config.model_config.hf_config
return JinaRobertaModel(vllm_config=vllm_config, prefix=prefix) kwargs = dict(vllm_config=vllm_config, prefix=prefix)
if getattr(hf_config, "position_embedding_type", "absolute") == "absolute":
return BertModel(**kwargs, embedding_class=RobertaEmbedding)
else: else:
return BertModel( return JinaRobertaModel(**kwargs)
vllm_config=vllm_config, prefix=prefix, embedding_class=RobertaEmbedding
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
weights_list = list(weights) weights_list = list(weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment