Unverified Commit 664287b2 authored by Kaichen Zhang - NTU, committed by GitHub
[Feat] Add llava qwen, llava mistral (#419)


Co-authored-by: Bo Li <drluodian@gmail.com>
parent e0ae5d42
"""
Usage:
# Install the latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
# Install the latest sglang.
# Endpoint Service CLI:
# python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --tokenizer-path lmms-lab/llama3-llava-next-8b-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4
python3 http_llama3_llava_test.py
Output:
"Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment."
"""
import argparse
import asyncio
import json
import time
import copy
import aiohttp
import requests
from llava.conversation import (
default_conversation,
conv_templates,
SeparatorStyle,
conv_llava_llama_3,
conv_qwen,
)
async def send_request(url, data, delay=0):
await asyncio.sleep(delay)
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data) as resp:
output = await resp.json()
return output
async def test_concurrent(args):
url = f"{args.host}:{args.port}"
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_llava_llama_3)
conv_template.append_message(role="user", message=prompt)
prompt_with_template = conv_template.get_prompt()
response = []
for i in range(1):
response.append(
send_request(
url + "/generate",
{
"text": prompt_with_template,
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
"sampling_params": {
"max_new_tokens": 1024,
"temperature": 0,
"top_p": 1.0,
"presence_penalty": 2,
"frequency_penalty": 2,
"stop": "<|eot_id|>",
},
},
)
)
rets = await asyncio.gather(*response)
for ret in rets:
print(ret["text"])
def test_streaming(args):
url = f"{args.host}:{args.port}"
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_llava_llama_3)
conv_template.append_message(role="user", message=prompt)
prompt_with_template = conv_template.get_prompt()
pload = {
"text": prompt_with_template,
"sampling_params": {
"max_new_tokens": 1024,
"temperature": 0,
"top_p": 1.0,
"presence_penalty": 2,
"frequency_penalty": 2,
"stop": "<|eot_id|>",
},
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
"stream": True,
}
response = requests.post(
url + "/generate",
json=pload,
stream=True,
)
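# The server streams SSE-style lines of the form "data: {...}" and finishes with
# "data: [DONE]"; the loop below parses each chunk and prints only the newly generated
# suffix by tracking the previously printed length in `prev`.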
prev = 0
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
print("")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
asyncio.run(test_concurrent(args))
test_streaming(args)
"""
Usage:
# Install the latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
# Install the latest sglang.
# Endpoint Service CLI:
# python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --tokenizer-path lmms-lab/llavanext-qwen-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4
python3 http_qwen_llava_test.py
Output:
"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."
"""
import argparse
import asyncio
import json
import time
import copy
import aiohttp
import requests
from llava.conversation import (
default_conversation,
conv_templates,
SeparatorStyle,
conv_llava_llama_3,
conv_qwen,
)
async def send_request(url, data, delay=0):
await asyncio.sleep(delay)
async with aiohttp.ClientSession() as session:
async with session.post(url, json=data) as resp:
output = await resp.json()
return output
async def test_concurrent(args):
url = f"{args.host}:{args.port}"
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_qwen)
conv_template.append_message(role="user", message=prompt)
prompt_with_template = conv_template.get_prompt()
response = []
for i in range(1):
response.append(
send_request(
url + "/generate",
{
"text": prompt_with_template,
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
"sampling_params": {
"max_new_tokens": 1024,
"temperature": 0,
"top_p": 1.0,
"presence_penalty": 2,
"frequency_penalty": 2,
"stop": "<|im_end|>",
},
},
)
)
rets = await asyncio.gather(*response)
for ret in rets:
print(ret["text"])
def test_streaming(args):
url = f"{args.host}:{args.port}"
prompt = "<image>\nPlease generate caption towards this image."
conv_template = copy.deepcopy(conv_qwen)
conv_template.append_message(role="user", message=prompt)
prompt_with_template = conv_template.get_prompt()
pload = {
"text": prompt_with_template,
"sampling_params": {
"max_new_tokens": 1024,
"temperature": 0,
"top_p": 1.0,
"presence_penalty": 2,
"frequency_penalty": 2,
"stop": "<|im_end|>",
},
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
"stream": True,
}
response = requests.post(
url + "/generate",
json=pload,
stream=True,
)
prev = 0
for chunk in response.iter_lines(decode_unicode=False):
chunk = chunk.decode("utf-8")
if chunk and chunk.startswith("data:"):
if chunk == "data: [DONE]":
break
data = json.loads(chunk[5:].strip("\n"))
output = data["text"].strip()
print(output[prev:], end="", flush=True)
prev = len(output)
print("")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="http://127.0.0.1")
parser.add_argument("--port", type=int, default=30000)
args = parser.parse_args()
# asyncio.run(test_concurrent(args))
test_streaming(args)
"""
Usage: python3 srt_example_llava.py
"""
import sglang as sgl
from sglang.srt.utils import load_image
from sglang.lang.chat_template import get_chat_template
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images
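# Note on the sglang program below: @sgl.function turns image_qa into an sglang program;
# sgl.user()/sgl.assistant() append chat turns, sgl.image() embeds the image in the user turn,
# and sgl.gen("answer") generates a completion that single()/batch() read via state["answer"]
# and stream() consumes through state.text_iter("answer").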
@sgl.function
def image_qa(s, image, question):
s += sgl.user(sgl.image(image) + question)
s += sgl.assistant(sgl.gen("answer"))
def single():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image = load_image(image_url)
state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512)
print(state["answer"], "\n")
def stream():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image = load_image(image_url)
state = image_qa.run(
image=pil_image,
question="Please generate short caption for this image.",
max_new_tokens=512,
temperature=0,
stream=True,
)
for out in state.text_iter("answer"):
print(out, end="", flush=True)
print()
def batch():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image = load_image(image_url)
states = image_qa.run_batch(
[
{"image": pil_image, "question": "What is this?"},
{"image": pil_image, "question": "What is this?"},
],
max_new_tokens=512,
)
for s in states:
print(s["answer"], "\n")
if __name__ == "__main__":
import multiprocessing as mp
mp.set_start_method("spawn", force=True)
runtime = sgl.Runtime(
model_path="lmms-lab/llama3-llava-next-8b",
tokenizer_path="lmms-lab/llama3-llava-next-8b-tokenizer",
)
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")
# runtime = sgl.Runtime(
# model_path="lmms-lab/llava-next-72b",
# tokenizer_path="lmms-lab/llavanext-qwen-tokenizer",
# )
# runtime.endpoint.chat_template = get_chat_template("chatml-llava")
sgl.set_default_backend(runtime)
print(f"chat template: {runtime.endpoint.chat_template.name}")
# Or you can use API models
# sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
# sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))
# Run a single request
print("\n========== single ==========\n")
single()
# Stream output
print("\n========== stream ==========\n")
stream()
# Run a batch of requests
print("\n========== batch ==========\n")
batch()
runtime.shutdown()
@@ -328,4 +328,4 @@ def monkey_path_clip_vision_embed_forward():
     )
-EntryClass = LlavaLlamaForCausalLM
\ No newline at end of file
+EntryClass = LlavaLlamaForCausalLM
"""Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Optional
import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, MistralConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
unpad_image_shape,
)
from sglang.srt.models.mistral import MistralForCausalLM
class LlavaMistralForCausalLM(nn.Module):
def __init__(
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.vision_tower = None
if getattr(self.config, "vision_config", None) is None:
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
if getattr(self.config, "text_config", None) is None:
self.config.text_config = MistralConfig(self.config._name_or_path)
self.config.vision_config.hidden_size = config.mm_hidden_size
self.config.text_config.hidden_size = config.hidden_size
if getattr(self.config, "projector_hidden_act", None) is None:
self.config.projector_hidden_act = "gelu"
if getattr(self.config, "image_token_index", None) is None:
self.config.image_token_index = 32000
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.language_model = MistralForCausalLM(config, quant_config=quant_config)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
)
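# pad_input_ids replaces the single image_token_index placeholder in input_ids with enough
# pad tokens to reserve space for the projected image features (extra rows are added for the
# anyres + unpad layout); it returns the padded ids plus the offset of the image slot, which
# forward() later uses (as image_offsets) to write the image embeddings into input_embeds.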
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
new_image_feature_len = self.image_feature_len
# currently only spatial_unpad + anyres is supported
if self.mm_patch_merge_type.startswith("spatial"):
height = width = self.num_patches_per_side
if pt_shape[0] > 1:
if self.image_aspect_ratio == "anyres":
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_size,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
if "unpad" in self.mm_patch_merge_type:
h = num_patch_height * height
w = num_patch_width * width
new_h, new_w = unpad_image_shape(h, w, image_size)
new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
offset = input_ids.index(self.config.image_token_index)
# new length = old_len + pad_len - 1, because the single image_token_id is replaced by the pad tokens
new_input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :]
)
return new_input_ids, offset
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# NOTE: This is not memory efficient; output_hidden_states=True keeps all hidden states.
selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
)
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
pixel_values: Optional[List[Optional[np.array]]] = None,
image_sizes: Optional[List[List[int]]] = None,
image_offsets: Optional[List[int]] = None,
) -> torch.Tensor:
if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size
# Embed text input
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input
need_vision = (
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
.cpu()
.numpy()
)
# FIXME: We need to subtract the length of the system prompt
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
need_vision = need_vision & has_pixel
if need_vision.any():
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
########## Encode Image ########
if pixel_values[0].ndim == 4:
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
# concatenate the per-image patch batches into one (total_num_patches, 3, H, W) array (ndim=4)
concat_images = torch.tensor(
np.concatenate(pixel_values, axis=0),
device=self.vision_tower.device,
)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# hd image_features: BS, num_patch, 576, 4096
else:
# normal pixel: BS, C=3, H=336, W=336
pixel_values = torch.tensor(
np.array(pixel_values), device=self.vision_tower.device
)
image_features = self.encode_images(pixel_values)
# image_features: BS, 576, 4096
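# Spatial / anyres merge (summary of the code below): the first feature map of each image is
# the base (global) view and the remaining maps are the anyres grid tiles; the tiles are laid
# out on a 2D grid, unpadded back to the original aspect ratio, an image_newline embedding is
# appended to each row, and the flattened result is concatenated after the base features.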
if self.mm_patch_merge_type.startswith("spatial"):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.num_patches_per_side
assert height * width == base_image_feature.shape[0]
if self.image_aspect_ratio == "anyres":
(
num_patch_width,
num_patch_height,
) = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
else:
raise NotImplementedError()
if "unpad" in self.mm_patch_merge_type:
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(
2, 3
)
image_feature = unpad_image(
image_feature, image_sizes[image_idx]
)
image_feature = torch.cat(
(
image_feature,
self.language_model.model.image_newline[
:, None, None
].expand(*image_feature.shape[:-1], 1),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(
0, 1
)
else:
image_feature = image_feature.permute(
0, 2, 1, 3, 4
).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
if "unpad" in self.mm_patch_merge_type:
image_feature = torch.cat(
(
image_feature,
self.language_model.model.image_newline[None],
),
dim=0,
)
new_image_features.append(image_feature)
image_features = new_image_features
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
pt = 0
for i in range(bs):
if not need_vision[i]:
continue
start_idx = extend_start_loc_cpu[i]
pad_len, pad_dim = image_features[pt].shape # 576, 4096
dim = input_embeds.shape[1]
assert (
pad_dim == dim
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
# Fill in the placeholder for the image
try:
input_embeds[
start_idx
+ image_offsets[i] : start_idx
+ image_offsets[i]
+ pad_len
] = image_features[pt]
except RuntimeError as e:
print(f"RuntimeError in llava image encoding: {e}")
print(input_embeds.shape)
print(start_idx, image_offsets[i])
pt += 1
return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds
)
elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model(input_ids, positions, input_metadata)
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
self.image_size = self.vision_tower.config.image_size
self.patch_size = self.vision_tower.config.patch_size
self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
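# e.g. for a CLIP ViT-L/14-336 tower: (336 / 14) ** 2 = 576 tokens, matching the "576"
# noted in forward(); the "cls_patch" strategy below adds one more for the CLS token.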
if self.vision_feature_select_strategy == "patch":
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
# load mm_projector
projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
# FIXME: why are the projector weights read twice?
if "projector" in name or "vision_tower" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# load language model
self.language_model.load_weights(
model_name_or_path, cache_dir, load_format, revision
)
monkey_path_clip_vision_embed_forward()
@property
def num_patches_per_side(self):
return self.image_size // self.patch_size
first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
global first_call
if first_call:
self.patch_embedding.cpu().float()
first_call = False
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
def monkey_path_clip_vision_embed_forward():
import transformers
setattr(
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
"forward",
clip_vision_embed_forward,
)
EntryClass = LlavaMistralForCausalLM
"""Inference-only LLaVa model compatible with HuggingFace weights."""
from typing import List, Optional
import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, LlavaConfig, CLIPVisionConfig, Qwen2Config
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
unpad_image_shape,
)
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
class LlavaQwenForCausalLM(nn.Module):
def __init__(
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.vision_tower = None
if getattr(self.config, "vision_config", None) is None:
self.config.vision_config = CLIPVisionConfig(self.config.mm_vision_tower)
if getattr(self.config, "text_config", None) is None:
self.config.text_config = Qwen2Config(self.config._name_or_path)
self.config.vision_config.hidden_size = config.mm_hidden_size
self.config.text_config.hidden_size = config.hidden_size
if getattr(self.config, "projector_hidden_act", None) is None:
self.config.projector_hidden_act = "gelu"
if getattr(self.config, "image_token_index", None) is None:
self.config.image_token_index = 151646
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.language_model = Qwen2ForCausalLM(config, quant_config=quant_config)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
)
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
new_image_feature_len = self.image_feature_len
# currently only spatial_unpad + anyres is supported
if self.mm_patch_merge_type.startswith("spatial"):
height = width = self.num_patches_per_side
if pt_shape[0] > 1:
if self.image_aspect_ratio == "anyres":
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_size,
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
if "unpad" in self.mm_patch_merge_type:
h = num_patch_height * height
w = num_patch_width * width
new_h, new_w = unpad_image_shape(h, w, image_size)
new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
offset = input_ids.index(self.config.image_token_index)
# new length = old_len + pad_len - 1, because the single image_token_id is replaced by the pad tokens
new_input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :]
)
return new_input_ids, offset
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# NOTE: This is not memory efficient; output_hidden_states=True keeps all hidden states.
selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
)
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
pixel_values: Optional[List[Optional[np.array]]] = None,
image_sizes: Optional[List[List[int]]] = None,
image_offsets: Optional[List[int]] = None,
) -> torch.Tensor:
if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size
# Embed text input
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input
need_vision = (
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
.cpu()
.numpy()
)
# FIXME: We need to subtract the length of the system prompt
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
need_vision = need_vision & has_pixel
if need_vision.any():
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
########## Encode Image ########
if pixel_values[0].ndim == 4:
# llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
# concatenate the per-image patch batches into one (total_num_patches, 3, H, W) array (ndim=4)
concat_images = torch.tensor(
np.concatenate(pixel_values, axis=0),
device=self.vision_tower.device,
)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# hd image_features: BS, num_patch, 576, 4096
else:
# normal pixel: BS, C=3, H=336, W=336
pixel_values = torch.tensor(
np.array(pixel_values), device=self.vision_tower.device
)
image_features = self.encode_images(pixel_values)
# image_features: BS, 576, 4096
if self.mm_patch_merge_type.startswith("spatial"):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.num_patches_per_side
assert height * width == base_image_feature.shape[0]
if self.image_aspect_ratio == "anyres":
(
num_patch_width,
num_patch_height,
) = get_anyres_image_grid_shape(
image_sizes[image_idx],
self.image_grid_pinpoints,
self.vision_tower.config.image_size,
)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height, width, -1
)
else:
raise NotImplementedError()
if "unpad" in self.mm_patch_merge_type:
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(
2, 3
)
image_feature = unpad_image(
image_feature, image_sizes[image_idx]
)
image_feature = torch.cat(
(
image_feature,
self.language_model.model.image_newline[
:, None, None
].expand(*image_feature.shape[:-1], 1),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(
0, 1
)
else:
image_feature = image_feature.permute(
0, 2, 1, 3, 4
).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0
)
else:
image_feature = image_feature[0]
if "unpad" in self.mm_patch_merge_type:
image_feature = torch.cat(
(
image_feature,
self.language_model.model.image_newline[None],
),
dim=0,
)
new_image_features.append(image_feature)
image_features = new_image_features
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
pt = 0
for i in range(bs):
if not need_vision[i]:
continue
start_idx = extend_start_loc_cpu[i]
pad_len, pad_dim = image_features[pt].shape # 576, 4096
dim = input_embeds.shape[1]
assert (
pad_dim == dim
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
# Fill in the placeholder for the image
try:
input_embeds[
start_idx
+ image_offsets[i] : start_idx
+ image_offsets[i]
+ pad_len
] = image_features[pt]
except RuntimeError as e:
print(f"RuntimeError in llava image encoding: {e}")
print(input_embeds.shape)
print(start_idx, image_offsets[i])
pt += 1
return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds
)
elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model(input_ids, positions, input_metadata)
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
self.image_size = self.vision_tower.config.image_size
self.patch_size = self.vision_tower.config.patch_size
self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
if self.vision_feature_select_strategy == "patch":
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(f"Unexpected select feature: {self.select_feature}")
# load mm_projector
projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
# FIXME: why are the projector weights read twice?
if "projector" in name or "vision_tower" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# load language model
self.language_model.load_weights(
model_name_or_path, cache_dir, load_format, revision
)
monkey_path_clip_vision_embed_forward()
@property
def num_patches_per_side(self):
return self.image_size // self.patch_size
first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
global first_call
if first_call:
self.patch_embedding.cpu().float()
first_call = False
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
def monkey_path_clip_vision_embed_forward():
import transformers
setattr(
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
"forward",
clip_vision_embed_forward,
)
EntryClass = LlavaQwenForCausalLM
@@ -303,6 +303,8 @@ class Qwen2ForCausalLM(nn.Module):
 # Skip loading extra bias for GPTQ models.
 if name.endswith(".bias") and name not in params_dict:
     continue
+if name.startswith("model.vision_tower") and name not in params_dict:
+    continue
 param = params_dict[name]
 weight_loader = param.weight_loader
 weight_loader(param, loaded_weight, shard_id)
@@ -311,6 +313,8 @@ class Qwen2ForCausalLM(nn.Module):
 # Skip loading extra bias for GPTQ models.
 if name.endswith(".bias") and name not in params_dict:
     continue
+if name.startswith("model.vision_tower") and name not in params_dict:
+    continue
 param = params_dict[name]
 weight_loader = getattr(param, "weight_loader", default_weight_loader)
 weight_loader(param, loaded_weight)