Unverified Commit 0f5b5260 authored by zhang-prog's avatar zhang-prog Committed by GitHub
Browse files

[Fix] Remove unused packing_position_embedding from PaddleOCRVL for better...


[Fix] Remove unused packing_position_embedding from PaddleOCRVL for better checkpoint compatibility (#38232)
Signed-off-by: default avatarzhangyue66 <zhangyue66@baidu.com>
parent be1a85b7
...@@ -409,7 +409,6 @@ class SiglipVisionEmbeddings(nn.Module): ...@@ -409,7 +409,6 @@ class SiglipVisionEmbeddings(nn.Module):
self.cache_position_embedding = dict() self.cache_position_embedding = dict()
self.cache_position_count = dict() self.cache_position_count = dict()
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.packing_position_embedding = nn.Embedding(32768, self.embed_dim)
self.register_buffer( self.register_buffer(
"position_ids", "position_ids",
...@@ -501,7 +500,6 @@ class SiglipVisionEmbeddings(nn.Module): ...@@ -501,7 +500,6 @@ class SiglipVisionEmbeddings(nn.Module):
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
embeddings = patch_embeds.flatten(-2).squeeze(-1) embeddings = patch_embeds.flatten(-2).squeeze(-1)
if interpolate_pos_encoding and image_grid_thw is not None:
start = 0 start = 0
tmp_embeddings = list() tmp_embeddings = list()
for image_grid in image_grid_thw: for image_grid in image_grid_thw:
...@@ -517,8 +515,7 @@ class SiglipVisionEmbeddings(nn.Module): ...@@ -517,8 +515,7 @@ class SiglipVisionEmbeddings(nn.Module):
tmp_embeddings.append(image_embeddings) tmp_embeddings.append(image_embeddings)
start = end start = end
embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0) embeddings = torch.concat(tmp_embeddings, dim=0).unsqueeze(0)
else:
embeddings = embeddings + self.packing_position_embedding(position_ids)
return embeddings return embeddings
else: else:
raise ValueError( raise ValueError(
...@@ -939,6 +936,8 @@ class SiglipVisionModel(nn.Module): ...@@ -939,6 +936,8 @@ class SiglipVisionModel(nn.Module):
continue continue
if "head.mlp" in name or "head.probe" in name: if "head.mlp" in name or "head.probe" in name:
continue continue
if "packing_position_embedding" in name:
continue
if self.quant_config is not None and ( if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name) scale_name := self.quant_config.get_cache_scale(name)
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment