"vscode:/vscode.git/clone" did not exist on "3cfa63ad991665b2440155cd29352342024072fd"
Unverified Commit 2061f0b8 authored by Roger Wang's avatar Roger Wang Committed by GitHub
Browse files

[Bugfix] Fix img_sizes Parsing in Phi3-Vision (#5888)

parent 96354d6a
......@@ -65,12 +65,6 @@ class Phi3ImageEmbeddingBase(nn.Module):
self.type_feature: str
self.img_processor: CLIPVisionModel
def set_img_features(self, img_features: torch.FloatTensor) -> None:
self.img_features = img_features
def set_img_sizes(self, img_sizes: torch.LongTensor) -> None:
self.img_sizes = img_sizes
def get_img_features(self,
img_embeds: torch.FloatTensor) -> torch.FloatTensor:
LAYER_IDX = self.layer_idx
......@@ -144,21 +138,16 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
self.layer_idx = config.img_processor.get('layer_idx', -2)
self.type_feature = config.img_processor.get('type_feature', 'patch')
def forward(self,
input_ids: torch.LongTensor,
def forward(self, input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
image_sizes=None) -> torch.FloatTensor:
image_sizes: torch.Tensor) -> torch.FloatTensor:
"""process and merge text embeddings with image embeddings."""
# (batch_size, max_num_crops, 3, height, width)
img_embeds = pixel_values
img_sizes = image_sizes
if self.img_features is not None:
img_embeds = self.img_features.clone()
self.img_features = None
if self.img_sizes is not None:
img_sizes = self.img_sizes
# (batch_size, 2)
img_sizes = image_sizes
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
......@@ -190,11 +179,8 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
output_imgs = []
output_len = []
if isinstance(img_sizes, torch.Tensor):
img_sizes.squeeze_(0)
for _bs in range(bs):
h, w = img_sizes
h, w = img_sizes[_bs]
h = h // 336
w = w // 336
B_ = h * w
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment