# gme-Qwen2-VL
## Paper
`GME: Improving Universal Multimodal Retrieval by Multimodal LLMs`
* https://arxiv.org/pdf/2412.16855
## Model Architecture
The model is built on the Qwen2-VL family of MLLMs and consists of a vision encoder, a text encoder, and a cross-modal projection module.
Vision encoder: a ViT-like structure that splits each image into patches and encodes them as a sequence of visual tokens, capping each image at 1024 visual tokens to balance efficiency and performance.
Text encoder: a Transformer-based encoder that supports long text inputs (maximum length of 1800 tokens).
Cross-modal projection layer: aligns visual tokens and text tokens in a unified semantic space.
![alt text](readme_imgs/arch.png)
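For intuition, the visual-token budget translates directly into the processor's pixel limits: each visual token covers a 28×28 patch. The sketch below mirrors the defaults in `gme_inference.py` (256–1280 image tokens); it is illustrative only.

```python
# Illustrative only: how gme_inference.py turns image-token budgets into pixel budgets.
# Each visual token corresponds to a 28x28 pixel patch.
PATCH = 28

def pixel_budget(min_image_tokens: int = 256, max_image_tokens: int = 1280) -> tuple[int, int]:
    min_pixels = min_image_tokens * PATCH * PATCH
    max_pixels = max_image_tokens * PATCH * PATCH
    return min_pixels, max_pixels

print(pixel_budget())  # (200704, 1003520) -- passed to AutoProcessor as min_pixels / max_pixels
```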
## Algorithm
The model is trained with contrastive learning; inputs may be text, images, or text-image pairs.
![alt text](readme_imgs/alg.png)
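As a rough illustration of such an objective, here is a generic in-batch InfoNCE-style contrastive loss over L2-normalized embeddings. This is a sketch, not the paper's exact implementation; the temperature value and batch construction are assumptions.

```python
import torch
import torch.nn.functional as F

def info_nce_loss(query_emb: torch.Tensor, cand_emb: torch.Tensor, temperature: float = 0.05) -> torch.Tensor:
    """Generic in-batch contrastive loss: the i-th query should match the i-th candidate.

    query_emb, cand_emb: (batch, dim) embeddings of text, image, or text-image inputs.
    """
    q = F.normalize(query_emb, p=2, dim=1)
    c = F.normalize(cand_emb, p=2, dim=1)
    logits = q @ c.T / temperature               # scaled cosine similarities
    labels = torch.arange(q.size(0), device=q.device)
    return F.cross_entropy(logits, labels)       # other in-batch items act as negatives
```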
## Environment Setup
### Docker (Option 1)
```bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04-py3.10
docker run --shm-size 100g --network=host --name=gme --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to this project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
```
### Dockerfile (Option 2)
```bash
docker build -t <IMAGE_NAME>:<TAG> .
docker run --shm-size 100g --network=host --name=gme --privileged --device=/dev/kfd --device=/dev/dri --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v <absolute path to this project>:/home/ -v /opt/hyhal:/opt/hyhal:ro -it <your IMAGE ID> bash
pip install -r requirements.txt
```
### Anaconda (Option 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded and installed from the 光合 developer community: https://developer.hpccube.com/tool/
```
DTK driver: dtk24.04
python: python3.10
torch: 2.4.0
torchvision: 0.19.1
```
2. Other dependencies
```bash
pip install -r requirements.txt
```
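Optionally, verify the setup before running (a minimal sketch; it only checks that the PyTorch stack imports and that an accelerator is visible, which is what `gme_inference.py` relies on via `torch.cuda.is_available()`):

```python
# Minimal environment sanity check (illustrative; versions should match the table above).
import torch
import torchvision
import transformers

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("transformers:", transformers.__version__)
print("accelerator visible:", torch.cuda.is_available())  # gme_inference.py selects "cuda" when this is True
```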
## Dataset
## Training
## Inference
```bash
python gme_inference.py
```
Note: adjust the relevant parameters in the script before running.
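The parameters to adjust are in the `__main__` block at the bottom of `gme_inference.py`: the checkpoint path passed to `GmeQwen2VL` and the example `texts` / `images` lists. A minimal standalone sketch (paths are placeholders; a local path, URL, or `PIL.Image` all work for images):

```python
from gme_inference import GmeQwen2VL

# Point this at the checkpoint directory you downloaded (see "Pretrained Weights" below).
gme = GmeQwen2VL("ckpts/gme-Qwen2-VL-2B-Instruct")

texts = ["What kind of car is this?"]
images = ["./examples/image1.png"]

e_text = gme.get_text_embeddings(texts=texts)
e_image = gme.get_image_embeddings(images=images)
print((e_text * e_image).sum(-1))  # embeddings are L2-normalized, so this is cosine similarity
```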
## Results
![alt text](readme_imgs/result.png)
### Accuracy
## Application Scenarios
### Algorithm Category
`Multimodal embedding`
### Key Application Industries
`E-commerce, education, transportation, energy`
## Pretrained Weights
|model|save_path|url|
|:---:|:---:|:---:|
|gme-Qwen2-VL-2B-Instruct|ckpts/gme-Qwen2-VL-2B-Instruct|[hf](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct) \| [SCNet](http://113.200.138.88:18080/aimodels/alibaba-nlp/gme-Qwen2-VL-2B-Instruct) |
|gme-Qwen2-VL-7B-Instruct|ckpts/gme-Qwen2-VL-7B-Instruct|[hf](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct) \| [SCNet](http://113.200.138.88:18080/aimodels/alibaba-nlp/gme-Qwen2-VL-7B-Instruct) |
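After downloading a checkpoint into the `save_path` listed above, the local directory can be passed via the `model_path` argument of `GmeQwen2VL` (a sketch; in the constructor, `model_path` takes precedence over the default `model_name`):

```python
from gme_inference import GmeQwen2VL

# Loads weights from the local save_path instead of pulling from the Hugging Face hub.
gme = GmeQwen2VL(model_path="ckpts/gme-Qwen2-VL-7B-Instruct")
```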
## Source Repository and Issue Feedback
* https://developer.sourcefind.cn/codes/modelzoo/gme-qwen2-vl_pytorch
## References
* https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct
from __future__ import annotations
import logging
import math
import os
from typing import Dict, List, Optional
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
from transformers import AutoModelForVision2Seq, AutoProcessor
class GmeQwen2VL:
def __init__(
self,
model_name: str = "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",
model_path: Optional[str] = None,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
min_image_tokens=256,
max_image_tokens=1280,
max_length=1800,
**kwargs,
) -> None:
model_name = model_path or model_name
self.base = AutoModelForVision2Seq.from_pretrained(
model_name, torch_dtype=torch.float16, **kwargs
)
self.base.eval()
self.normalize = True
self.device = device
min_pixels = min_image_tokens * 28 * 28
max_pixels = max_image_tokens * 28 * 28
self.max_length = max_length
self.processor = AutoProcessor.from_pretrained(
model_name, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs
)
self.processor.tokenizer.padding_side = 'right'
self.defualt_instruction = 'You are a helpful assistant.'
self.sep = ' '
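    # forward(): build token embeddings, splice the visual features into the image-placeholder
    # positions, run the base language model, then pool a single embedding per sequence.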
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.Tensor] = None,
# pixel_values_videos: Optional[torch.FloatTensor] = None,
image_grid_thw: Optional[torch.LongTensor] = None,
# video_grid_thw: Optional[torch.LongTensor] = None,
pooling_mask: Optional[torch.LongTensor] = None,
**kwargs
) -> torch.Tensor:
if inputs_embeds is None:
inputs_embeds = self.base.model.embed_tokens(input_ids)
if pixel_values is not None:
pixel_values = pixel_values.type(self.base.visual.get_dtype())
image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)
image_mask = input_ids == self.base.config.image_token_id
inputs_embeds[image_mask] = image_embeds
# if pixel_values_videos is not None:
# pixel_values_videos = pixel_values_videos.type(self.base.visual.get_dtype())
# video_embeds = self.base.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)
# video_mask = input_ids == self.base.config.video_token_id
# inputs_embeds[video_mask] = video_embeds
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)
outputs = self.base.model(
input_ids=None,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
)
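        # Last-token pooling: with left padding every sequence ends at position -1;
        # with right padding, index each sequence's last non-padded token instead.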
pooling_mask = attention_mask if pooling_mask is None else pooling_mask
left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0]) # TODO
if left_padding:
embeddings = outputs.last_hidden_state[:, -1]
else:
sequence_lengths = pooling_mask.sum(dim=1) - 1
batch_size = outputs.last_hidden_state.shape[0]
embeddings = outputs.last_hidden_state[torch.arange(
batch_size, device=outputs.last_hidden_state.device
), sequence_lengths]
if self.normalize:
embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
return embeddings.contiguous()
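    # embed(): wrap each (text, image) pair in the Qwen2-VL chat template
    # (system/user markers plus vision placeholders) and return pooled, L2-normalized embeddings.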
def embed(self, texts: list[str], images: list[Image.Image], is_query=True, instruction=None, **kwargs):
self.base.to(self.device)
# Inputs must be batched
input_texts, input_images = list(), list()
for t, i in zip(texts, images):
if not is_query or instruction is None:
instruction = self.defualt_instruction
input_str = ''
if i is None:
input_images = None # All examples in the same batch are consistent
else:
input_str += '<|vision_start|><|image_pad|><|vision_end|>'
i = fetch_image(i)
input_images.append(i)
if t is not None:
input_str += t
msg = f'<|im_start|>system\n{instruction}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>'
input_texts.append(msg)
inputs = self.processor(
text=input_texts,
images=input_images,
padding=True,
truncation=True,
max_length=self.max_length,
return_tensors='pt'
)
inputs = {k: v.to(self.device) for k, v in inputs.items()} # TODO
with torch.no_grad():
embeddings = self.forward(**inputs)
return embeddings
def encode(self, sentences: list[str], *, prompt_name=None, **kwargs):
return self.get_fused_embeddings(texts=sentences, prompt_name=prompt_name, **kwargs)
def encode_queries(self, queries: List[str], **kwargs):
embeddings = self.encode(queries, **kwargs)
return embeddings
def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
if type(corpus) is dict:
sentences = [
(corpus["title"][i] + self.sep + corpus["text"][i]).strip()
if "title" in corpus
else corpus["text"][i].strip()
for i in range(len(corpus["text"]))
]
else:
sentences = [
(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip()
for doc in corpus
]
embeddings = self.encode(sentences, is_query=False, **kwargs)
return embeddings
def get_image_embeddings(self, images: list[Image.Image] | DataLoader, **kwargs):
return self.get_fused_embeddings(images=images, **kwargs)
def get_text_embeddings(self, texts: list[str], **kwargs):
return self.get_fused_embeddings(texts=texts, **kwargs)
def get_fused_embeddings(self, texts: list[str] = None, images: list[Image.Image] | DataLoader = None, **kwargs):
if isinstance(images, DataLoader):
image_loader = images
batch_size = image_loader.batch_size
image_loader.dataset.transform = None
else:
batch_size = kwargs.pop('batch_size', 32)
if images is None:
image_loader = None
else:
image_loader = DataLoader(
images,
batch_size=batch_size,
shuffle=False,
collate_fn=custom_collate_fn,
num_workers=min(math.floor(os.cpu_count() / 2), 8),
)
if texts is None:
assert image_loader is not None
n_batch = len(image_loader)
else:
n_batch = len(texts) // batch_size + int(len(texts) % batch_size > 0)
image_loader = image_loader or [None] * n_batch
all_embeddings = list()
none_batch = [None] * batch_size
show_progress_bar = kwargs.pop('show_progress_bar', True)
pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc='encode')
for n, img_batch in zip(range(0, n_batch * batch_size, batch_size), image_loader):
text_batch = none_batch if texts is None else texts[n: n+batch_size]
img_batch = none_batch if img_batch is None else img_batch
embeddings = self.embed(texts=text_batch, images=img_batch, **kwargs)
pbar.update(1)
all_embeddings.append(embeddings.cpu())
pbar.close()
all_embeddings = torch.cat(all_embeddings, dim=0)
return all_embeddings
def custom_collate_fn(batch):
return batch
### Copied from qwen_vl_utils.vision_process.py
import base64
from io import BytesIO
import requests
IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
def round_by_factor(number: int, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: int, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: int, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
logging.warning(
f"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}"
)
if h_bar > w_bar:
h_bar = w_bar * MAX_RATIO
else:
w_bar = h_bar * MAX_RATIO
return h_bar, w_bar
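# fetch_image(): accepts a PIL.Image, an http(s) URL, a file:// URI, a base64 data URI, or a plain
# local path, then resizes the image onto the 28-pixel patch grid via smart_resize().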
def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:
image_obj = None
if isinstance(image, Image.Image):
image_obj = image
elif image.startswith("http://") or image.startswith("https://"):
image_obj = Image.open(requests.get(image, stream=True).raw)
elif image.startswith("file://"):
image_obj = Image.open(image[7:])
elif image.startswith("data:image"):
if "base64," in image:
_, base64_data = image.split("base64,", 1)
data = base64.b64decode(base64_data)
image_obj = Image.open(BytesIO(data))
else:
image_obj = Image.open(image)
if image_obj is None:
raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
image = image_obj.convert("RGB")
## resize
# if "resized_height" in ele and "resized_width" in ele:
# resized_height, resized_width = smart_resize(
# ele["resized_height"],
# ele["resized_width"],
# factor=size_factor,
# )
# else:
width, height = image.size
# min_pixels = ele.get("min_pixels", MIN_PIXELS)
# max_pixels = ele.get("max_pixels", MAX_PIXELS)
resized_height, resized_width = smart_resize(
height,
width,
factor=size_factor,
min_pixels=MIN_PIXELS,
max_pixels=MAX_PIXELS,
)
image = image.resize((resized_width, resized_height))
return image
###
if __name__ == '__main__':
texts = [
"What kind of car is this?",
"The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023."
]
images = [
'./examples/image1.png',
'./examples/image2.png',
]
gme = GmeQwen2VL("ckpts/gme-Qwen2-VL-7B-Instruct")
# Single-modal embedding
e_text = gme.get_text_embeddings(texts=texts)
e_image = gme.get_image_embeddings(images=images)
print((e_text * e_image).sum(-1))
## tensor([0.2281, 0.6001], dtype=torch.float16)
# How to set embedding instruction
e_query = gme.get_text_embeddings(texts=texts, instruction='Find an image that matches the given text.')
# If is_query=False, we always use the default instruction.
e_corpus = gme.get_image_embeddings(images=images, is_query=False)
print((e_query * e_corpus).sum(-1))
## tensor([0.2433, 0.7051], dtype=torch.float16)
# Fused-modal embedding
e_fused = gme.get_fused_embeddings(texts=texts, images=images)
print((e_fused[0] * e_fused[1]).sum())
## tensor(0.6108, dtype=torch.float16)
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"import logging\n",
"import math\n",
"import os\n",
"from typing import Dict, List, Optional\n",
"\n",
"import torch\n",
"from PIL import Image\n",
"from torch.utils.data import DataLoader\n",
"from tqdm.autonotebook import tqdm\n",
"from transformers import AutoModelForVision2Seq, AutoProcessor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class GmeQwen2VL:\n",
" def __init__(\n",
" self,\n",
" model_name: str = \"Alibaba-NLP/gme-Qwen2-VL-2B-Instruct\",\n",
" model_path: Optional[str] = None,\n",
" device: str = \"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
" min_image_tokens=256,\n",
" max_image_tokens=1280,\n",
" max_length=1800,\n",
" **kwargs,\n",
" ) -> None:\n",
" model_name = model_path or model_name\n",
" self.base = AutoModelForVision2Seq.from_pretrained(\n",
" model_name, torch_dtype=torch.float16, **kwargs\n",
" )\n",
" self.base.eval()\n",
" self.normalize = True\n",
" self.device = device\n",
" min_pixels = min_image_tokens * 28 * 28\n",
" max_pixels = max_image_tokens * 28 * 28\n",
" self.max_length = max_length\n",
" self.processor = AutoProcessor.from_pretrained(\n",
" model_name, min_pixels=min_pixels, max_pixels=max_pixels, **kwargs\n",
" )\n",
" self.processor.tokenizer.padding_side = 'right'\n",
" self.defualt_instruction = 'You are a helpful assistant.'\n",
" self.sep = ' '\n",
"\n",
" def forward(\n",
" self,\n",
" input_ids: Optional[torch.LongTensor] = None,\n",
" attention_mask: Optional[torch.Tensor] = None,\n",
" position_ids: Optional[torch.LongTensor] = None,\n",
" past_key_values: Optional[List[torch.FloatTensor]] = None,\n",
" inputs_embeds: Optional[torch.FloatTensor] = None,\n",
" pixel_values: Optional[torch.Tensor] = None,\n",
" # pixel_values_videos: Optional[torch.FloatTensor] = None,\n",
" image_grid_thw: Optional[torch.LongTensor] = None,\n",
" # video_grid_thw: Optional[torch.LongTensor] = None,\n",
" pooling_mask: Optional[torch.LongTensor] = None,\n",
" **kwargs\n",
" ) -> torch.Tensor:\n",
" if inputs_embeds is None:\n",
" inputs_embeds = self.base.model.embed_tokens(input_ids)\n",
" if pixel_values is not None:\n",
" pixel_values = pixel_values.type(self.base.visual.get_dtype())\n",
" image_embeds = self.base.visual(pixel_values, grid_thw=image_grid_thw).to(inputs_embeds.device)\n",
" image_mask = input_ids == self.base.config.image_token_id\n",
" inputs_embeds[image_mask] = image_embeds\n",
" # if pixel_values_videos is not None:\n",
" # pixel_values_videos = pixel_values_videos.type(self.base.visual.get_dtype())\n",
" # video_embeds = self.base.visual(pixel_values_videos, grid_thw=video_grid_thw).to(inputs_embeds.device)\n",
" # video_mask = input_ids == self.base.config.video_token_id\n",
" # inputs_embeds[video_mask] = video_embeds\n",
" if attention_mask is not None:\n",
" attention_mask = attention_mask.to(inputs_embeds.device)\n",
"\n",
" outputs = self.base.model(\n",
" input_ids=None,\n",
" position_ids=position_ids,\n",
" attention_mask=attention_mask,\n",
" past_key_values=past_key_values,\n",
" inputs_embeds=inputs_embeds,\n",
" )\n",
"\n",
" pooling_mask = attention_mask if pooling_mask is None else pooling_mask\n",
" left_padding = (pooling_mask[:, -1].sum() == pooling_mask.shape[0]) # TODO\n",
" if left_padding:\n",
" embeddings = outputs.last_hidden_state[:, -1]\n",
" else:\n",
" sequence_lengths = pooling_mask.sum(dim=1) - 1\n",
" batch_size = outputs.last_hidden_state.shape[0]\n",
" embeddings = outputs.last_hidden_state[torch.arange(\n",
" batch_size, device=outputs.last_hidden_state.device\n",
" ), sequence_lengths]\n",
" if self.normalize:\n",
" embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)\n",
" return embeddings.contiguous()\n",
"\n",
" def embed(self, texts: list[str], images: list[Image.Image], is_query=True, instruction=None, **kwargs):\n",
" self.base.to(self.device)\n",
" # Inputs must be batched\n",
" input_texts, input_images = list(), list()\n",
" for t, i in zip(texts, images):\n",
" if not is_query or instruction is None:\n",
" instruction = self.defualt_instruction\n",
" input_str = ''\n",
" if i is None:\n",
" input_images = None # All examples in the same batch are consistent\n",
" else:\n",
" input_str += '<|vision_start|><|image_pad|><|vision_end|>'\n",
" i = fetch_image(i)\n",
" input_images.append(i)\n",
" if t is not None:\n",
" input_str += t\n",
" msg = f'<|im_start|>system\\n{instruction}<|im_end|>\\n<|im_start|>user\\n{input_str}<|im_end|>\\n<|im_start|>assistant\\n<|endoftext|>'\n",
" input_texts.append(msg)\n",
"\n",
" inputs = self.processor(\n",
" text=input_texts,\n",
" images=input_images,\n",
" padding=True,\n",
" truncation=True,\n",
" max_length=self.max_length,\n",
" return_tensors='pt'\n",
" )\n",
" inputs = {k: v.to(self.device) for k, v in inputs.items()} # TODO\n",
" with torch.no_grad():\n",
" embeddings = self.forward(**inputs)\n",
" return embeddings\n",
"\n",
" def encode(self, sentences: list[str], *, prompt_name=None, **kwargs):\n",
" return self.get_fused_embeddings(texts=sentences, prompt_name=prompt_name, **kwargs)\n",
"\n",
" def encode_queries(self, queries: List[str], **kwargs):\n",
" embeddings = self.encode(queries, **kwargs)\n",
" return embeddings\n",
"\n",
" def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):\n",
" if type(corpus) is dict:\n",
" sentences = [\n",
" (corpus[\"title\"][i] + self.sep + corpus[\"text\"][i]).strip()\n",
" if \"title\" in corpus\n",
" else corpus[\"text\"][i].strip()\n",
" for i in range(len(corpus[\"text\"]))\n",
" ]\n",
" else:\n",
" sentences = [\n",
" (doc[\"title\"] + self.sep + doc[\"text\"]).strip() if \"title\" in doc else doc[\"text\"].strip()\n",
" for doc in corpus\n",
" ]\n",
" embeddings = self.encode(sentences, is_query=False, **kwargs)\n",
" return embeddings\n",
"\n",
" def get_image_embeddings(self, images: list[Image.Image] | DataLoader, **kwargs):\n",
" return self.get_fused_embeddings(images=images, **kwargs)\n",
"\n",
" def get_text_embeddings(self, texts: list[str], **kwargs):\n",
" return self.get_fused_embeddings(texts=texts, **kwargs)\n",
"\n",
" def get_fused_embeddings(self, texts: list[str] = None, images: list[Image.Image] | DataLoader = None, **kwargs):\n",
" if isinstance(images, DataLoader):\n",
" image_loader = images\n",
" batch_size = image_loader.batch_size\n",
" image_loader.dataset.transform = None\n",
" else:\n",
" batch_size = kwargs.pop('batch_size', 32)\n",
" if images is None:\n",
" image_loader = None\n",
" else:\n",
" image_loader = DataLoader(\n",
" images,\n",
" batch_size=batch_size,\n",
" shuffle=False,\n",
" collate_fn=custom_collate_fn,\n",
" num_workers=min(math.floor(os.cpu_count() / 2), 8),\n",
" )\n",
"\n",
" if texts is None:\n",
" assert image_loader is not None\n",
" n_batch = len(image_loader)\n",
" else:\n",
" n_batch = len(texts) // batch_size + int(len(texts) % batch_size > 0)\n",
" image_loader = image_loader or [None] * n_batch\n",
"\n",
" all_embeddings = list()\n",
" none_batch = [None] * batch_size\n",
" show_progress_bar = kwargs.pop('show_progress_bar', True)\n",
" pbar = tqdm(total=n_batch, disable=not show_progress_bar, mininterval=1, miniters=10, desc='encode')\n",
" for n, img_batch in zip(range(0, n_batch * batch_size, batch_size), image_loader):\n",
" text_batch = none_batch if texts is None else texts[n: n+batch_size]\n",
" img_batch = none_batch if img_batch is None else img_batch\n",
" embeddings = self.embed(texts=text_batch, images=img_batch, **kwargs)\n",
" pbar.update(1)\n",
" all_embeddings.append(embeddings.cpu())\n",
" pbar.close()\n",
" all_embeddings = torch.cat(all_embeddings, dim=0)\n",
" return all_embeddings\n",
"\n",
"\n",
"def custom_collate_fn(batch):\n",
" return batch\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"from io import BytesIO\n",
"import requests\n",
"\n",
"IMAGE_FACTOR = 28\n",
"MIN_PIXELS = 4 * 28 * 28\n",
"MAX_PIXELS = 16384 * 28 * 28\n",
"MAX_RATIO = 200\n",
"\n",
"\n",
"def round_by_factor(number: int, factor: int) -> int:\n",
" \"\"\"Returns the closest integer to 'number' that is divisible by 'factor'.\"\"\"\n",
" return round(number / factor) * factor\n",
"\n",
"\n",
"def ceil_by_factor(number: int, factor: int) -> int:\n",
" \"\"\"Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.\"\"\"\n",
" return math.ceil(number / factor) * factor\n",
"\n",
"\n",
"def floor_by_factor(number: int, factor: int) -> int:\n",
" \"\"\"Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.\"\"\"\n",
" return math.floor(number / factor) * factor\n",
"\n",
"\n",
"def smart_resize(\n",
" height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS\n",
") -> tuple[int, int]:\n",
" \"\"\"\n",
" Rescales the image so that the following conditions are met:\n",
" 1. Both dimensions (height and width) are divisible by 'factor'.\n",
" 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].\n",
" 3. The aspect ratio of the image is maintained as closely as possible.\n",
" \"\"\"\n",
" h_bar = max(factor, round_by_factor(height, factor))\n",
" w_bar = max(factor, round_by_factor(width, factor))\n",
" if h_bar * w_bar > max_pixels:\n",
" beta = math.sqrt((height * width) / max_pixels)\n",
" h_bar = floor_by_factor(height / beta, factor)\n",
" w_bar = floor_by_factor(width / beta, factor)\n",
" elif h_bar * w_bar < min_pixels:\n",
" beta = math.sqrt(min_pixels / (height * width))\n",
" h_bar = ceil_by_factor(height * beta, factor)\n",
" w_bar = ceil_by_factor(width * beta, factor)\n",
"\n",
" if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:\n",
" logging.warning(\n",
" f\"Absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(h_bar, w_bar) / min(h_bar, w_bar)}\"\n",
" )\n",
" if h_bar > w_bar:\n",
" h_bar = w_bar * MAX_RATIO\n",
" else:\n",
" w_bar = h_bar * MAX_RATIO\n",
" return h_bar, w_bar\n",
"\n",
"\n",
"def fetch_image(image: str | Image.Image, size_factor: int = IMAGE_FACTOR) -> Image.Image:\n",
" image_obj = None\n",
" if isinstance(image, Image.Image):\n",
" image_obj = image\n",
" elif image.startswith(\"http://\") or image.startswith(\"https://\"):\n",
" image_obj = Image.open(requests.get(image, stream=True).raw)\n",
" elif image.startswith(\"file://\"):\n",
" image_obj = Image.open(image[7:])\n",
" elif image.startswith(\"data:image\"):\n",
" if \"base64,\" in image:\n",
" _, base64_data = image.split(\"base64,\", 1)\n",
" data = base64.b64decode(base64_data)\n",
" image_obj = Image.open(BytesIO(data))\n",
" else:\n",
" image_obj = Image.open(image)\n",
" if image_obj is None:\n",
" raise ValueError(f\"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}\")\n",
" image = image_obj.convert(\"RGB\")\n",
" ## resize\n",
" # if \"resized_height\" in ele and \"resized_width\" in ele:\n",
" # resized_height, resized_width = smart_resize(\n",
" # ele[\"resized_height\"],\n",
" # ele[\"resized_width\"],\n",
" # factor=size_factor,\n",
" # )\n",
" # else:\n",
" width, height = image.size\n",
" # min_pixels = ele.get(\"min_pixels\", MIN_PIXELS)\n",
" # max_pixels = ele.get(\"max_pixels\", MAX_PIXELS)\n",
" resized_height, resized_width = smart_resize(\n",
" height,\n",
" width,\n",
" factor=size_factor,\n",
" min_pixels=MIN_PIXELS,\n",
" max_pixels=MAX_PIXELS,\n",
" )\n",
" image = image.resize((resized_width, resized_height))\n",
"\n",
" return image"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"gme = GmeQwen2VL(\"ckpts/gme-Qwen2-VL-2B-Instruct\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"texts = [\n",
" \"What kind of car is this?\",\n",
" \"The Tesla Cybertruck is a battery electric pickup truck built by Tesla, Inc. since 2023.\"\n",
" ]\n",
"\n",
"images = [\n",
" './examples/image1.png',\n",
" './examples/image2.png',\n",
"]\n",
"\n",
"# Single-modal embedding\n",
"e_text = gme.get_text_embeddings(texts=texts)\n",
"e_image = gme.get_image_embeddings(images=images)\n",
"print((e_text * e_image).sum(-1))\n",
"## tensor([0.2281, 0.6001], dtype=torch.float16)\n",
"\n",
"# How to set embedding instruction\n",
"e_query = gme.get_text_embeddings(texts=texts, instruction='Find an image that matches the given text.')\n",
"# If is_query=False, we always use the default instruction.\n",
"e_corpus = gme.get_image_embeddings(images=images, is_query=False)\n",
"print((e_query * e_corpus).sum(-1))\n",
"## tensor([0.2433, 0.7051], dtype=torch.float16)\n",
"\n",
"# Fused-modal embedding\n",
"e_fused = gme.get_fused_embeddings(texts=texts, images=images)\n",
"print((e_fused[0] * e_fused[1]).sum())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Unique model identifier
modelCode=1474
# Model name
modelName=gme-Qwen2-VL_pytorch
# Model description
modelDescription=A multimodal embedding model developed by Alibaba
# Application scenarios
appScenario=Inference,Multimodal embedding,E-commerce,Education,Transportation,Energy
# Framework type
frameType=Pytorch
transformers==4.46.3