Unverified Commit 0992d85f authored by Yuanhan Zhang, committed by GitHub

support llava video (#426)

parent 5dc55a5f
......@@ -10,12 +10,16 @@ class ModelConfig:
trust_remote_code: bool = True,
revision: Optional[str] = None,
context_length: Optional[int] = None,
model_overide_args: Optional[dict] = None,
) -> None:
self.path = path
self.trust_remote_code = trust_remote_code
self.revision = revision
self.hf_config = get_config(self.path, trust_remote_code, revision)
if model_overide_args is not None:
self.hf_config.update(model_overide_args)
if context_length is not None:
self.context_len = context_length
else:
......
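# Hedged sketch (not part of the diff): ModelConfig applies model_overide_args by
# merging the dict into the HuggingFace config via PretrainedConfig.update(), so
# downstream model code sees the patched attributes. "gpt2" and the "num_frames"
# key below are placeholders chosen only for illustration.
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained("gpt2")  # any accessible HF repo
hf_config.update({"num_frames": 16})            # same call ModelConfig makes
assert hf_config.num_frames == 16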
......@@ -27,29 +27,25 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn.parameter import Parameter
from transformers import PretrainedConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
@torch.compile
......
......@@ -5,37 +5,31 @@ from typing import Optional
import torch
import torch.nn as nn
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.models.dbrx_config import DbrxConfig
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class DbrxRouter(nn.Module):
......@@ -291,7 +285,9 @@ class DbrxBlock(nn.Module):
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.norm_attn_norm = DbrxFusedNormAttention(config, layer_id, quant_config=quant_config)
self.norm_attn_norm = DbrxFusedNormAttention(
config, layer_id, quant_config=quant_config
)
self.ffn = DbrxExperts(config, quant_config=quant_config)
def forward(
......@@ -322,7 +318,10 @@ class DbrxModel(nn.Module):
config.d_model,
)
self.blocks = nn.ModuleList(
[DbrxBlock(config, i, quant_config=quant_config) for i in range(config.n_layers)]
[
DbrxBlock(config, i, quant_config=quant_config)
for i in range(config.n_layers)
]
)
self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
for module in self.modules():
......
......@@ -7,6 +7,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
......@@ -14,21 +15,14 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class GemmaMLP(nn.Module):
......@@ -46,7 +40,10 @@ class GemmaMLP(nn.Module):
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.act_fn = GeluAndMul()
......
......@@ -6,6 +6,7 @@ from typing import Any, Dict, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
......@@ -13,24 +14,17 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlamaMLP(nn.Module):
......@@ -49,7 +43,10 @@ class LlamaMLP(nn.Module):
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
)
if hidden_act != "silu":
raise ValueError(
......
......@@ -7,12 +7,7 @@ import torch
from torch import nn
from transformers import CLIPVisionModel, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
......@@ -22,6 +17,7 @@ from sglang.srt.mm_utils import (
unpad_image_shape,
)
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaLlamaForCausalLM(nn.Module):
......
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
import os
from typing import List, Optional
import numpy as np
import torch
from torch import nn
from transformers import CLIPVisionModel, LlamaConfig, LlavaConfig
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.router.infer_batch import ForwardMode
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.mm_utils import (
get_anyres_image_grid_shape,
unpad_image,
unpad_image_shape,
)
from sglang.srt.models.llama2 import LlamaForCausalLM
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class LlavaVidForCausalLM(nn.Module):
def __init__(
self,
config: LlavaConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
self.vision_tower = None
self.config.vision_config.hidden_size = config.mm_hidden_size
self.config.text_config.hidden_size = config.hidden_size
self.multi_modal_projector = LlavaMultiModalProjector(config)
self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
self.resampler = nn.AvgPool2d(
kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
)
self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
self.num_frames = getattr(self.config, "num_frames", 16)
if "unpad" in getattr(config, "mm_patch_merge_type", ""):
self.language_model.model.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size, dtype=torch.float16)
)
def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
new_image_feature_len = self.image_feature_len
# now only support spatial_unpad + anyres
# if self.mm_patch_merge_type.startswith("spatial"):
# height = width = self.num_patches_per_side
# if pt_shape[0] > 1:
# if self.image_aspect_ratio == "anyres":
# num_patch_width, num_patch_height = get_anyres_image_grid_shape(
# image_size,
# self.image_grid_pinpoints,
# self.vision_tower.config.image_size,
# )
# if "unpad" in self.mm_patch_merge_type:
# h = num_patch_height * height
# w = num_patch_width * width
# new_h, new_w = unpad_image_shape(h, w, image_size)
# new_image_feature_len += new_h * (new_w + 1)
pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
# print(input_ids)
offset = input_ids.index(self.config.image_token_index)
# old_len + pad_len - 1, because we need to remove image_token_id
new_input_ids = (
input_ids[:offset]
+ pad_ids[:new_image_feature_len]
+ input_ids[offset + 1 :]
)
return new_input_ids, offset
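# Hedged worked example (not part of the diff), assuming the defaults computed in
# load_weights(): image_feature_len = 2304 and a 4-element pad_value [p0, p1, p2, p3].
# Then pad_ids = pad_value * ((2304 + 4) // 4) holds 2308 ids, the single
# image_token_index in input_ids is replaced by pad_ids[:2304], and the returned
# offset marks where those placeholder ids begin.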
def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
# NOTE: This is not memory efficient. (output_hidden_states=True) will keep all the hidden states.
selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
if self.vision_feature_select_strategy in ["default", "patch"]:
selected_image_feature = selected_image_feature[:, 1:]
elif self.vision_feature_select_strategy == "full":
selected_image_feature = selected_image_feature
else:
raise ValueError(
f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
)
height = width = self.num_patches_per_side
num_of_frames = selected_image_feature.shape[0]
selected_image_feature = selected_image_feature.view(
num_of_frames, height, width, -1
)
selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
selected_image_feature = (
self.resampler(selected_image_feature)
.flatten(2)
.transpose(1, 2)
.contiguous()
)
image_features = self.multi_modal_projector(selected_image_feature)
return image_features
def forward(
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
pixel_values: Optional[List[Optional[np.array]]] = None,
image_sizes: Optional[List[List[int]]] = None,
image_offsets: Optional[List[int]] = None,
) -> torch.Tensor:
if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size
# Embed text input
input_embeds = self.language_model.model.embed_tokens(input_ids)
# Embed vision input
need_vision = (
(positions[input_metadata.extend_start_loc] < self.image_feature_len)
.cpu()
.numpy()
)
# FIXME: We need to subtract the length of the system prompt
has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
need_vision = need_vision & has_pixel
if need_vision.any():
pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
########## Encode Image ########
if pixel_values[0].ndim == 4:
# llava-hd: BS, num_patch, C=3, H=336, W=336; num_patch comes from process_images (still ndim=4 after concatenation)
concat_images = torch.tensor(
np.concatenate(pixel_values, axis=0),
device=self.vision_tower.device,
)
# image_features = self.encode_images(concat_images)
# split_sizes = [image.shape[0] for image in pixel_values]
# image_features = torch.split(image_features, split_sizes, dim=0)
image_features = self.encode_images(
concat_images
) # , prompts)#, image_counts, long_video=long_video)
split_sizes = [image.shape[0] for image in pixel_values]
image_features = torch.split(image_features, split_sizes, dim=0)
# hd image_features: BS, num_patch, 576, 4096
else:
# normal pixel: BS, C=3, H=336, W=336
pixel_values = torch.tensor(
np.array(pixel_values), device=self.vision_tower.device
)
image_features = self.encode_images(pixel_values)
# image_features: BS, 576, 4096
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
new_image_features.append(image_feature.flatten(0, 1))
image_features = new_image_features
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
pt = 0
for i in range(bs):
if not need_vision[i]:
continue
start_idx = extend_start_loc_cpu[i]
pad_len, pad_dim = image_features[pt].shape # 576, 4096
dim = input_embeds.shape[1]
assert (
pad_dim == dim
), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
# Fill in the placeholder for the image
try:
input_embeds[
start_idx
+ image_offsets[i] : start_idx
+ image_offsets[i]
+ pad_len
] = image_features[pt]
except RuntimeError as e:
print(f"RuntimeError in llava image encoding: {e}")
print(input_embeds.shape)
print(start_idx, image_offsets[i])
pt += 1
return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds
)
elif input_metadata.forward_mode == ForwardMode.DECODE:
return self.language_model(input_ids, positions, input_metadata)
def load_weights(
self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
):
# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = self.config.mm_vision_tower
self.vision_tower = CLIPVisionModel.from_pretrained(
vision_path, torch_dtype=torch.float16
).cuda()
self.vision_tower.eval()
self.vision_feature_layer = self.config.mm_vision_select_layer
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
self.image_size = self.vision_tower.config.image_size
self.patch_size = self.vision_tower.config.patch_size
self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
print(f"target_frames: {self.num_frames}")
self.image_feature_len = self.num_frames * int(
(self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
)
if self.vision_feature_select_strategy == "patch":
pass
elif self.vision_feature_select_strategy == "cls_patch":
self.image_feature_len += 1
else:
raise ValueError(f"Unexpected select feature: {self.vision_feature_select_strategy}")
# load mm_projector
projector_weights = {
"model.mm_projector.0": "multi_modal_projector.linear_1",
"model.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
"model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
}
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision
):
# FIXME: why are the projector weights loaded twice?
if "projector" in name or "vision_tower" in name:
for weight_name, param_name in projector_weights.items():
if weight_name in name:
name = name.replace(weight_name, param_name)
if name in params_dict:
param = params_dict[name]
else:
print(f"Warning: {name} not found in the model")
continue
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
# load language model
self.language_model.load_weights(
model_name_or_path, cache_dir, load_format, revision
)
monkey_path_clip_vision_embed_forward()
@property
def num_patches_per_side(self):
return self.image_size // self.patch_size
first_call = True
def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
# Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
global first_call
if first_call:
self.patch_embedding.cpu().float()
first_call = False
pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
patch_embeds = self.patch_embedding(pixel_values).cuda().half()
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings
def monkey_path_clip_vision_embed_forward():
import transformers
setattr(
transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
"forward",
clip_vision_embed_forward,
)
EntryClass = LlavaVidForCausalLM
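# --- Hedged usage sketch (not part of the diff) ---------------------------------
# The helper below reproduces the per-video token budget that load_weights()
# computes, using typical CLIP-ViT-L/336 values (image_size=336, patch_size=14 are
# assumptions, not read from any checkpoint) together with this class's defaults
# (mm_spatial_pool_stride=2, num_frames=16). The commented launch call shows how a
# caller might reach this model through the model_overide_args hook added in this
# commit; the model path and override keys are hypothetical.


def expected_image_feature_len(
    image_size: int = 336,            # assumed CLIP-ViT-L/336 resolution
    patch_size: int = 14,             # assumed CLIP-ViT-L patch size
    mm_spatial_pool_stride: int = 2,  # default used by LlavaVidForCausalLM
    num_frames: int = 16,             # default used by LlavaVidForCausalLM
) -> int:
    # Mirrors: num_frames * int((image_size / patch_size / stride) ** 2)
    return num_frames * int((image_size / patch_size / mm_spatial_pool_stride) ** 2)


if __name__ == "__main__":
    # 16 frames * (336 / 14 / 2) ** 2 = 16 * 144 = 2304 placeholder tokens per video.
    print(expected_image_feature_len())

    # Hypothetical launch through sglang.Runtime (see the server.py changes below):
    # import sglang as sgl
    # runtime = sgl.Runtime(
    #     model_path="lmms-lab/llava-next-video-7b",  # hypothetical checkpoint name
    #     model_overide_args={"architectures": ["LlavaVidForCausalLM"], "num_frames": 16},
    # )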
......@@ -8,34 +8,28 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
tensor_model_parallel_all_reduce,
)
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class MixtralMLP(nn.Module):
......
......@@ -3,6 +3,7 @@ from typing import Any, Dict, Optional
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
......@@ -10,24 +11,17 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class QWenMLP(nn.Module):
......@@ -132,7 +126,12 @@ class QWenAttention(nn.Module):
class QWenBlock(nn.Module):
def __init__(self, config: PretrainedConfig, layer_id, quant_config: Optional[QuantizationConfig] = None,):
def __init__(
self,
config: PretrainedConfig,
layer_id,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
......@@ -181,7 +180,11 @@ class QWenBlock(nn.Module):
class QWenModel(nn.Module):
def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
......@@ -218,7 +221,11 @@ class QWenModel(nn.Module):
class QWenLMHeadModel(nn.Module):
def __init__(self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.config = config
self.transformer = QWenModel(config, quant_config=quant_config)
......@@ -276,4 +283,4 @@ class QWenLMHeadModel(nn.Module):
weight_loader(param, loaded_weight)
EntryClass = QWenLMHeadModel
\ No newline at end of file
EntryClass = QWenLMHeadModel
......@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Tuple
import torch
from torch import nn
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
......@@ -12,24 +13,17 @@ from vllm.model_executor.layers.linear import (
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
Qwen2Config = None
......@@ -50,7 +44,10 @@ class Qwen2MLP(nn.Module):
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config,
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
)
if hidden_act != "silu":
raise ValueError(
......
......@@ -7,35 +7,31 @@ from typing import Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.distributed import (
get_tensor_model_parallel_world_size,
)
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.router.model_runner import InputMetadata
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class StablelmMLP(nn.Module):
def __init__(
self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.config = config
......@@ -48,7 +44,10 @@ class StablelmMLP(nn.Module):
quant_config=quant_config,
)
self.down_proj = RowParallelLinear(
config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config,
config.intermediate_size,
config.hidden_size,
bias=False,
quant_config=quant_config,
)
self.act_fn = SiluAndMul()
......@@ -181,7 +180,9 @@ class StablelmDecoderLayer(nn.Module):
class StableLMEpochModel(nn.Module):
def __init__(
self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.embed_tokens = VocabParallelEmbedding(
......
......@@ -6,16 +6,13 @@ from typing import List, Optional
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlavaConfig
from sglang.srt.weight_utils import (
default_weight_loader,
hf_model_weights_iterator,
)
from sglang.srt.models.llava import (
LlavaLlamaForCausalLM,
clip_vision_embed_forward,
monkey_path_clip_vision_embed_forward,
)
from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
class YiVLForCausalLM(LlavaLlamaForCausalLM):
......
......@@ -107,7 +107,7 @@ async def openai_v1_chat_completions(raw_request: Request):
return await v1_chat_completions(tokenizer_manager, raw_request)
def launch_server(server_args: ServerArgs, pipe_finish_writer):
def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_args=None):
global tokenizer_manager
logging.basicConfig(
......@@ -140,17 +140,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
)
# Launch processes
tokenizer_manager = TokenizerManager(server_args, port_args)
tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
pipe_router_reader, pipe_router_writer = mp.Pipe(duplex=False)
pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
proc_router = mp.Process(
target=start_router_process,
args=(
server_args,
port_args,
pipe_router_writer,
),
args=(server_args, port_args, pipe_router_writer, model_overide_args),
)
proc_router.start()
proc_detoken = mp.Process(
......@@ -170,8 +166,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
if router_init_state != "init ok" or detoken_init_state != "init ok":
proc_router.kill()
proc_detoken.kill()
print(f"Initialization failed. router_init_state: {router_init_state}", flush=True)
print(f"Initialization failed. detoken_init_state: {detoken_init_state}", flush=True)
print(
f"Initialization failed. router_init_state: {router_init_state}", flush=True
)
print(
f"Initialization failed. detoken_init_state: {detoken_init_state}",
flush=True,
)
sys.exit(1)
assert proc_router.is_alive() and proc_detoken.is_alive()
......@@ -189,6 +190,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
time.sleep(0.5)
try:
requests.get(url + "/get_model_info", timeout=5, headers=headers)
success = True # Set flag to True if request succeeds
break
except requests.exceptions.RequestException as e:
pass
......@@ -205,7 +207,7 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
},
},
headers=headers,
timeout=60,
timeout=600,
)
assert res.status_code == 200
except Exception as e:
......@@ -235,7 +237,8 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer):
class Runtime:
def __init__(
self,
log_evel="error",
log_evel: str = "error",
model_overide_args: Optional[dict] = None,
*args,
**kwargs,
):
......@@ -244,7 +247,10 @@ class Runtime:
# Pre-allocate ports
self.server_args.port, self.server_args.additional_ports = allocate_init_ports(
self.server_args.port, self.server_args.additional_ports, self.server_args.tp_size)
self.server_args.port,
self.server_args.additional_ports,
self.server_args.tp_size,
)
self.url = self.server_args.url()
self.generate_url = (
......@@ -253,7 +259,10 @@ class Runtime:
self.pid = None
pipe_reader, pipe_writer = mp.Pipe(duplex=False)
proc = mp.Process(target=launch_server, args=(self.server_args, pipe_writer))
proc = mp.Process(
target=launch_server,
args=(self.server_args, pipe_writer, model_overide_args),
)
proc.start()
pipe_writer.close()
self.pid = proc.pid
......@@ -265,7 +274,9 @@ class Runtime:
if init_state != "init ok":
self.shutdown()
raise RuntimeError("Initialization failed. Please see the error messages above.")
raise RuntimeError(
"Initialization failed. Please see the error messages above."
)
self.endpoint = RuntimeEndpoint(self.url)
......@@ -317,4 +328,4 @@ class Runtime:
pos += len(cur)
def __del__(self):
self.shutdown()
\ No newline at end of file
self.shutdown()
......@@ -80,10 +80,12 @@ class ServerArgs:
default=ServerArgs.tokenizer_path,
help="The path of the tokenizer.",
)
parser.add_argument("--host", type=str, default=ServerArgs.host,
help="The host of the server.")
parser.add_argument("--port", type=int, default=ServerArgs.port,
help="The port of the server.")
parser.add_argument(
"--host", type=str, default=ServerArgs.host, help="The host of the server."
)
parser.add_argument(
"--port", type=int, default=ServerArgs.port, help="The port of the server."
)
parser.add_argument(
"--additional-ports",
type=int,
......@@ -261,4 +263,4 @@ class PortArgs:
router_port: int
detokenizer_port: int
nccl_port: int
model_rpc_ports: List[int]
\ No newline at end of file
model_rpc_ports: List[int]
......@@ -131,11 +131,13 @@ def alloc_usable_network_port(num, used_list=()):
continue
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
try:
s.bind(("", port))
s.listen(1) # Attempt to listen on the port
port_list.append(port)
except socket.error:
pass
pass # If any error occurs, this port is not usable
if len(port_list) == num:
return port_list
......@@ -265,20 +267,102 @@ def wrap_kernel_launcher(kernel):
def is_multimodal_model(model):
if isinstance(model, str):
return "llava" in model or "yi-vl" in model
from sglang.srt.model_config import ModelConfig
if isinstance(model, str):
model = model.lower()
return "llava" in model or "yi-vl" in model or "llava-next" in model
if isinstance(model, ModelConfig):
model_path = model.path.lower()
return "llava" in model_path or "yi-vl" in model_path
raise Exception("unrecognized type")
return "llava" in model_path or "yi-vl" in model_path or "llava-next" in model_path
raise ValueError("unrecognized type")
def decode_video_base64(video_base64):
from PIL import Image
# Decode the base64 string
video_bytes = base64.b64decode(video_base64)
# Placeholder for the start indices of each PNG image
img_starts = []
frame_format = "PNG" # str(os.getenv('FRAME_FORMAT', "JPEG"))
assert frame_format in [
"PNG",
"JPEG",
], "FRAME_FORMAT must be either 'PNG' or 'JPEG'"
if frame_format == "PNG":
# Find each PNG start signature to isolate images
i = 0
while i < len(video_bytes) - 7: # Adjusted for the length of the PNG signature
# Check if we found the start of a PNG file
if (
video_bytes[i] == 0x89
and video_bytes[i + 1] == 0x50
and video_bytes[i + 2] == 0x4E
and video_bytes[i + 3] == 0x47
and video_bytes[i + 4] == 0x0D
and video_bytes[i + 5] == 0x0A
and video_bytes[i + 6] == 0x1A
and video_bytes[i + 7] == 0x0A
):
img_starts.append(i)
i += 8 # Skip the PNG signature
else:
i += 1
else:
# Find each JPEG start (0xFFD8) to isolate images
i = 0
while (
i < len(video_bytes) - 1
): # Adjusted for the length of the JPEG SOI signature
# Check if we found the start of a JPEG file
if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8:
img_starts.append(i)
# Move to the next byte to continue searching for the next image start
i += 2
else:
i += 1
frames = []
for start_idx in img_starts:
# Assuming each image is back-to-back, the end of one image is the start of another
# The last image goes until the end of the byte string
end_idx = (
img_starts[img_starts.index(start_idx) + 1]
if img_starts.index(start_idx) + 1 < len(img_starts)
else len(video_bytes)
)
img_bytes = video_bytes[start_idx:end_idx]
# Convert bytes to a PIL Image
img = Image.open(BytesIO(img_bytes))
# Convert PIL Image to a NumPy array
frame = np.array(img)
# Append the frame to the list of frames
frames.append(frame)
# Ensure there's at least one frame to avoid errors with np.stack
if frames:
return np.stack(frames, axis=0), img.size
else:
return np.array([]), (
0,
0,
) # Return an empty array and size tuple if no frames were found
def load_image(image_file):
from PIL import Image
image = None
image = image_size = None
if image_file.startswith("http://") or image_file.startswith("https://"):
timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
......@@ -289,10 +373,13 @@ def load_image(image_file):
elif image_file.startswith("data:"):
image_file = image_file.split(",")[1]
image = Image.open(BytesIO(base64.b64decode(image_file)))
elif image_file.startswith("video:"):
image_file = image_file.replace("video:", "")
image, image_size = decode_video_base64(image_file)
else:
image = Image.open(BytesIO(base64.b64decode(image_file)))
return image
return image, image_size
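# Hedged sketch (not part of the diff): load_image() now returns a (data, size)
# tuple. For ordinary images the second element stays None; for payloads prefixed
# with "video:" it is the frame size reported by decode_video_base64(). The inputs
# below are placeholders:
#
#   image, image_size = load_image("https://example.com/cat.png")     # image_size is None
#   frames, frame_size = load_image("video:" + encoded_video_base64)  # frames: np.ndarray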
def assert_pkg_version(pkg: str, min_version: str):
......@@ -304,7 +391,9 @@ def assert_pkg_version(pkg: str, min_version: str):
f"is less than the minimum required version {min_version}"
)
except PackageNotFoundError:
raise Exception(f"{pkg} with minimum required version {min_version} is not installed")
raise Exception(
f"{pkg} with minimum required version {min_version} is not installed"
)
API_KEY_HEADER_NAME = "X-API-Key"
......
......@@ -19,11 +19,12 @@ import torch
from huggingface_hub import HfFileSystem, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
get_quantization_config,
)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
logger = init_logger(__name__)
......@@ -32,17 +33,21 @@ logger = init_logger(__name__)
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = os.environ.get('TMPDIR') or os.environ.get(
'TEMP') or os.environ.get('TMP') or "/tmp/"
temp_dir = (
os.environ.get("TMPDIR")
or os.environ.get("TEMP")
or os.environ.get("TMP")
or "/tmp/"
)
def enable_hf_transfer():
"""automatically activates hf_transfer
"""
"""automatically activates hf_transfer"""
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
try:
# enable hf hub transfer if available
import hf_transfer # type: ignore # noqa
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
except ImportError:
pass
......@@ -65,8 +70,7 @@ def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
# add hash to avoid conflict with old users' lock files
lock_file_name = hash_name + model_name + ".lock"
# mode 0o666 is required for the filelock to be shared across users
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
mode=0o666)
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
return lock
......@@ -104,10 +108,12 @@ def convert_bin_to_safetensor_file(
sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size
if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(f"""The file size difference is more than 1%:
raise RuntimeError(
f"""The file size difference is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
""")
"""
)
# check if the tensors are the same
reloaded = load_file(sf_filename)
......@@ -122,8 +128,7 @@ def convert_bin_to_safetensor_file(
def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
quant_cls = get_quantization_config(model_config.quantization)
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
None)
hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
if hf_quant_config is not None:
return quant_cls.from_config(hf_quant_config)
model_name_or_path = model_config.model
......@@ -131,26 +136,29 @@ def get_quant_config(model_config: ModelConfig) -> QuantizationConfig:
if not is_local:
# Download the config files.
with get_lock(model_name_or_path, model_config.download_dir):
hf_folder = snapshot_download(model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=model_config.download_dir,
tqdm_class=Disabledtqdm)
hf_folder = snapshot_download(
model_name_or_path,
revision=model_config.revision,
allow_patterns="*.json",
cache_dir=model_config.download_dir,
tqdm_class=Disabledtqdm,
)
else:
hf_folder = model_name_or_path
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
quant_config_files = [
f for f in config_files if any(
f.endswith(x) for x in quant_cls.get_config_filenames())
f
for f in config_files
if any(f.endswith(x) for x in quant_cls.get_config_filenames())
]
if len(quant_config_files) == 0:
raise ValueError(
f"Cannot find the config file for {model_config.quantization}")
raise ValueError(f"Cannot find the config file for {model_config.quantization}")
if len(quant_config_files) > 1:
raise ValueError(
f"Found multiple config files for {model_config.quantization}: "
f"{quant_config_files}")
f"{quant_config_files}"
)
quant_config_file = quant_config_files[0]
with open(quant_config_file, "r") as f:
......@@ -166,8 +174,7 @@ def prepare_hf_model_weights(
revision: Optional[str] = None,
) -> Tuple[str, List[str], bool]:
# Download model weights from huggingface.
is_local = os.path.isdir(model_name_or_path) \
and load_format != "tensorizer"
is_local = os.path.isdir(model_name_or_path) and load_format != "tensorizer"
use_safetensors = False
# Some quantized models use .pt files for storing the weights.
if load_format == "auto":
......@@ -203,11 +210,13 @@ def prepare_hf_model_weights(
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
hf_folder = snapshot_download(model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=Disabledtqdm,
revision=revision)
hf_folder = snapshot_download(
model_name_or_path,
allow_patterns=allow_patterns,
cache_dir=cache_dir,
tqdm_class=Disabledtqdm,
revision=revision,
)
else:
hf_folder = model_name_or_path
hf_weights_files: List[str] = []
......@@ -228,16 +237,14 @@ def prepare_hf_model_weights(
"scaler.pt",
]
hf_weights_files = [
f for f in hf_weights_files
if not any(f.endswith(x) for x in blacklist)
f for f in hf_weights_files if not any(f.endswith(x) for x in blacklist)
]
if load_format == "tensorizer":
return hf_folder, hf_weights_files, use_safetensors
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`")
raise RuntimeError(f"Cannot find any model weights with `{model_name_or_path}`")
return hf_folder, hf_weights_files, use_safetensors
......@@ -254,7 +261,8 @@ def hf_model_weights_iterator(
cache_dir=cache_dir,
load_format=load_format,
fall_back_to_pt=fall_back_to_pt,
revision=revision)
revision=revision,
)
if load_format == "npcache":
# Currently np_cache only support *.bin checkpoints
......@@ -289,22 +297,25 @@ def hf_model_weights_iterator(
param = np.load(f)
yield name, torch.from_numpy(param)
elif load_format == "tensorizer":
from vllm.model_executor.tensorizer_loader import (TensorDeserializer,
open_stream,
tensorizer_warning)
from vllm.model_executor.tensorizer_loader import (
TensorDeserializer,
open_stream,
tensorizer_warning,
)
tensorizer_args = load_format.params
tensorizer_warning(
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the examples/tensorize_vllm_model.py example "
"script for serializing vLLM models.")
"script for serializing vLLM models."
)
deserializer_args = tensorizer_args.deserializer_params
stream_params = tensorizer_args.stream_params
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
with TensorDeserializer(stream, **deserializer_args,
device="cpu") as state:
with TensorDeserializer(stream, **deserializer_args, device="cpu") as state:
for name, param in state.items():
yield name, param
del state
......@@ -324,8 +335,12 @@ def hf_model_weights_iterator(
def kv_cache_scales_loader(
filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int,
model_type: Optional[str]) -> Iterable[Tuple[int, float]]:
filename: str,
tp_rank: int,
tp_size: int,
num_hidden_layers: int,
model_type: Optional[str],
) -> Iterable[Tuple[int, float]]:
"""
A simple utility to read in KV cache scaling factors that have been
previously serialized to disk. Used by the model to populate the appropriate
......@@ -343,8 +358,7 @@ def kv_cache_scales_loader(
"tp_size": tp_size,
}
schema_dct = json.load(f)
schema = QuantParamSchema.model_validate(schema_dct,
context=context)
schema = QuantParamSchema.model_validate(schema_dct, context=context)
layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
return layer_scales_map.items()
......@@ -357,9 +371,11 @@ def kv_cache_scales_loader(
# This section is reached if and only if any of the excepts are hit
# Return an empty iterable (list) => no KV cache scales are loaded
# which ultimately defaults to 1.0 scales
logger.warning("Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading.")
logger.warning(
"Defaulting to KV cache scaling factors = 1.0 "
f"for all layers in TP rank {tp_rank} "
"as an error occurred during loading."
)
return []
......@@ -378,8 +394,7 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
return x
def default_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
......@@ -399,4 +414,4 @@ def initialize_dummy_weights(
"""
for param in model.state_dict().values():
if torch.is_floating_point(param):
param.data.uniform_(low, high)
\ No newline at end of file
param.data.uniform_(low, high)
......@@ -2,13 +2,16 @@
import base64
import json
import os
import sys
import threading
import traceback
import urllib.request
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from json import dumps
import numpy as np
import requests
......@@ -110,6 +113,74 @@ def encode_image_base64(image_path):
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def encode_frame(frame):
import cv2 # pip install opencv-python-headless
from PIL import Image
# Convert the frame to RGB (OpenCV uses BGR by default)
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Convert the frame to PIL Image to easily convert to bytes
im_pil = Image.fromarray(frame)
# Convert to bytes
buffered = BytesIO()
# frame_format = str(os.getenv('FRAME_FORMAT', "JPEG"))
im_pil.save(buffered, format="PNG")
frame_bytes = buffered.getvalue()
# Return the bytes of the frame
return frame_bytes
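# Hedged sketch (not part of the diff): encode_frame() emits one standalone PNG
# per frame, so the concatenated payload can later be split on the 8-byte PNG
# signature that decode_video_base64() scans for. A self-contained check on a
# tiny synthetic frame:
def _check_png_signature():
    import numpy as np

    frame = np.zeros((4, 4, 3), dtype=np.uint8)   # dummy BGR frame
    png_bytes = encode_frame(frame)               # uses the function defined above
    assert png_bytes[:8] == b"\x89PNG\r\n\x1a\n"  # bytes matched by decode_video_base64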
def encode_video_base64(video_path, num_frames=16):
import cv2
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
raise IOError(f"Could not open video file: {video_path}")
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
print(f"target_frames: {num_frames}")
frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
frames = []
for i in range(total_frames):
ret, frame = cap.read()
if ret:
frames.append(frame)
else:
# Handle the case where the frame could not be read
# print(f"Warning: Could not read frame at index {i}.")
pass
cap.release()
# Safely select frames based on frame_indices, avoiding IndexError
frames = [frames[i] for i in frame_indices if i < len(frames)]
# If there are not enough frames, duplicate the last frame until we reach the target
while len(frames) < num_frames:
frames.append(frames[-1])
# Use ThreadPoolExecutor to process and encode frames in parallel
with ThreadPoolExecutor() as executor:
encoded_frames = list(executor.map(encode_frame, frames))
# encoded_frames = list(map(encode_frame, frames))
# Concatenate all frames bytes
video_bytes = b"".join(encoded_frames)
# Encode the concatenated bytes to base64
video_base64 = "video:" + base64.b64encode(video_bytes).decode("utf-8")
return video_base64
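# Hedged round-trip sketch (not part of the diff): the decoder expects the base64
# payload without the "video:" prefix, which is exactly what load_image() strips
# before calling it. The video path and the decoder's module location are assumptions:
#
#   from sglang.srt.utils import decode_video_base64   # assumed module for the decoder
#
#   payload = encode_video_base64("clip.mp4", num_frames=16)         # "video:<base64>"
#   frames, frame_size = decode_video_base64(payload[len("video:"):])
#   assert frames.shape[0] == 16                                      # one entry per frame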
def _is_chinese_char(cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
......@@ -170,4 +241,4 @@ def run_with_timeout(func, args=(), kwargs=None, timeout=None):
if not ret_value:
raise RuntimeError()
return ret_value[0]
\ No newline at end of file
return ret_value[0]