Unverified Commit 0147f940 authored by shiyi.c_98's avatar shiyi.c_98 Committed by GitHub
Browse files

fix batch error for llava-hd (#98)

parent 23950056
......@@ -112,24 +112,28 @@ class LlavaLlamaForCausalLM(nn.Module):
need_vision = need_vision & has_pixel
if need_vision.any():
pixel_values = torch.tensor(
np.array([pixel_values[i] for i in range(bs) if need_vision[i]]),
device=self.vision_tower.device,
)
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
########## Encode Image ########
if pixel_values.ndim == 5:
if pixel_values[0].ndim == 4:
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
concat_images = torch.cat(
[image for image in pixel_values], dim=0
) # ndim=4
np.concatenate(pixel_values, axis=0)
# ndim=4
concat_images = torch.tensor(
np.concatenate(pixel_values, axis=0),
device=self.vision_tower.device,
)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# hd image_features: BS, num_patch, 576, 4096
else:
# normal pixel: BS, C=3, H=336, W=336
pixel_values = torch.tensor(
np.array(pixel_values), device=self.vision_tower.device
)
image_features = self.encode_images(pixel_values)
# image_features: BS, 576, 4096
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment