Commit 07dbc76b authored by dongchy920
MiniGemini_pytorch
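# ---------------------------------------------------------------------------
# Vision-tower builder (multimodal encoder factory). The module path is not
# shown in this commit view; judging from the relative imports below, it sits
# next to clip_encoder.py / eva_encoder.py / openclip_encoder.py.
# ---------------------------------------------------------------------------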
import os
from .clip_encoder import CLIPVisionTower
from .eva_encoder import EVAVisionTower
from .openclip_encoder import OpenCLIPVisionTower
def build_vision_tower(vision_tower_cfg, **kwargs):
vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
image_processor = getattr(vision_tower_cfg, 'image_processor', "../processor/clip-patch14-224")
if not os.path.exists(vision_tower):
raise ValueError(f'Vision tower not found: {vision_tower}')
if "openai" in vision_tower.lower() or "ShareGPT4V" in vision_tower:
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
elif "lavis" in vision_tower.lower() or "eva" in vision_tower.lower():
return EVAVisionTower(vision_tower, image_processor, args=vision_tower_cfg, **kwargs)
else:
raise ValueError(f'Unknown vision tower: {vision_tower}')
def build_vision_tower_aux(vision_tower_cfg, **kwargs):
vision_tower_aux = getattr(vision_tower_cfg, 'mm_vision_tower_aux', getattr(vision_tower_cfg, 'vision_tower_aux', None))
if not os.path.exists(vision_tower_aux):
raise ValueError(f'Vision tower not found: {vision_tower_aux}')
if "openclip" in vision_tower_aux.lower():
return OpenCLIPVisionTower(vision_tower_aux, args=vision_tower_cfg, **kwargs)
elif "openai" in vision_tower_aux.lower():
return CLIPVisionTower(vision_tower_aux, args=vision_tower_cfg, **kwargs)
else:
raise ValueError(f'Unknown vision tower: {vision_tower_aux}')
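# Illustrative sketch only (not part of the original file): both factories read
# checkpoint paths from a config object and require them to exist on disk. The
# directories below are placeholders, not paths shipped with this repo.
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(
#         mm_vision_tower='./checkpoints/openai-clip-vit-large',        # hypothetical path
#         mm_vision_tower_aux='./checkpoints/openclip-convnext-large',  # hypothetical path
#         mm_vision_select_layer=-2,
#     )
#     tower = build_vision_tower(cfg)          # dispatches to CLIPVisionTower ("openai" in path)
#     tower_aux = build_vision_tower_aux(cfg)  # dispatches to OpenCLIPVisionTower ("openclip" in path)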
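# ---------------------------------------------------------------------------
# clip_encoder.py -- CLIP vision tower wrapper (file name taken from the
# `from .clip_encoder import CLIPVisionTower` import above).
# ---------------------------------------------------------------------------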
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
from ..processor.video_processor import VideoFramesProcessor
class CLIPVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.select_layer = args.mm_vision_select_layer
self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
self.is_optimize = getattr(args, 'optimize_vision_tower', False)
if not delay_load:
self.load_model()
elif getattr(args, 'unfreeze_mm_vision_tower', False):
self.load_model()
else:
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
def load_model(self):
self.image_processor = VideoFramesProcessor.from_pretrained(self.vision_tower_name)
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
def feature_select(self, image_forward_outs):
image_features = image_forward_outs.hidden_states[self.select_layer]
if self.select_feature == 'patch':
image_features = image_features[:, 1:]
elif self.select_feature == 'cls_patch':
image_features = image_features
else:
raise ValueError(f'Unexpected select feature: {self.select_feature}')
return image_features
def image_forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
image_feature = self.feature_select(image_forward_out).to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
image_features = self.feature_select(image_forward_outs).to(images.dtype)
return image_features
def forward(self, images):
if not self.is_optimize:
with torch.no_grad():
image_features = self.image_forward(images)
else:
image_features = self.image_forward(images)
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
if self.is_loaded:
return self.vision_tower.config
else:
return self.cfg_only
@property
def hidden_size(self):
return self.config.hidden_size
@property
def num_patches(self):
return (self.config.image_size // self.config.patch_size) ** 2
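# Note (added commentary): feature_select() indexes hidden_states with
# mm_vision_select_layer (e.g. -2 for the penultimate block) and drops the CLS
# token when select_feature == 'patch', so a 336px / patch-14 CLIP tower yields
# (336 // 14) ** 2 = 576 patch tokens of width hidden_size per image.
# ---------------------------------------------------------------------------
# eva_encoder.py -- EVA ViT backbone and EVAVisionTower wrapper (file name
# taken from the `from .eva_encoder import EVAVisionTower` import above).
# ---------------------------------------------------------------------------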
# Based on EVA, BEIT, timm and DeiT code bases
# https://github.com/baaivision/EVA
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
import math
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from transformers import CLIPImageProcessor, CLIPVisionConfig
from ..processor.video_processor import VideoFramesProcessor
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
**kwargs
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commented out here, following the original BERT implementation
x = self.fc2(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., window_size=None, attn_head_dim=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
if window_size:
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token to cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
else:
self.window_size = None
self.relative_position_bias_table = None
self.relative_position_index = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, rel_pos_bias=None):
B, N, C = x.shape
qkv_bias = None
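# Added comment: EVA/BEiT leave the key projection bias-free, so when qkv_bias
# is enabled a frozen zero bias is concatenated between q_bias and v_bias to
# form the fused bias used by F.linear below.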
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
if self.relative_position_bias_table is not None:
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if rel_pos_bias is not None:
attn = attn + rel_pos_bias
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
window_size=None, attn_head_dim=None):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values is not None and init_values > 0:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x, rel_pos_bias=None):
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x, **kwargs):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class RelativePositionBias(nn.Module):
def __init__(self, window_size, num_heads):
super().__init__()
self.window_size = window_size
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
self.relative_position_bias_table = nn.Parameter(
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# cls to token & token to cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = \
torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = self.num_relative_distance - 3
relative_position_index[0:, 0] = self.num_relative_distance - 2
relative_position_index[0, 0] = self.num_relative_distance - 1
self.register_buffer("relative_position_index", relative_position_index)
# trunc_normal_(self.relative_position_bias_table, std=.02)
def forward(self):
relative_position_bias = \
self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1] + 1,
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
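# Worked example (added commentary): for a 16x16 patch grid the table holds
# (2*16-1) * (2*16-1) + 3 = 964 learnable rows (one per relative offset plus the
# three cls-to-token / token-to-cls / cls-to-cls slots), and forward() returns a
# (num_heads, 16*16 + 1, 16*16 + 1) bias that is added to the attention logits.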
class VisionTransformer(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
super().__init__()
self.image_size = img_size
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
if use_abs_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
else:
self.pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
if use_shared_rel_pos_bias:
self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
else:
self.rel_pos_bias = None
self.use_checkpoint = use_checkpoint
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.use_rel_pos_bias = use_rel_pos_bias
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
for i in range(depth)])
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if self.pos_embed is not None:
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02)
# trunc_normal_(self.mask_token, std=.02)
# if isinstance(self.head, nn.Linear):
# trunc_normal_(self.head.weight, std=.02)
self.apply(self._init_weights)
self.fix_init_weight()
# if isinstance(self.head, nn.Linear):
# self.head.weight.data.mul_(init_scale)
# self.head.bias.data.mul_(init_scale)
def fix_init_weight(self):
def rescale(param, layer_id):
param.div_(math.sqrt(2.0 * layer_id))
for layer_id, layer in enumerate(self.blocks):
rescale(layer.attn.proj.weight.data, layer_id + 1)
rescale(layer.mlp.fc2.weight.data, layer_id + 1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x, rel_pos_bias)
else:
x = blk(x, rel_pos_bias)
return x
# x = self.norm(x)
# if self.fc_norm is not None:
# t = x[:, 1:, :]
# return self.fc_norm(t.mean(1))
# else:
# return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
# x = self.head(x)
return x
def get_intermediate_layers(self, x):
x = self.patch_embed(x)
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
x = torch.cat((cls_tokens, x), dim=1)
if self.pos_embed is not None:
x = x + self.pos_embed
x = self.pos_drop(x)
features = []
rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
for blk in self.blocks:
x = blk(x, rel_pos_bias)
features.append(x)
return features
@property
def dtype(self):
return self.cls_token.dtype
@property
def device(self):
return self.cls_token.device
def get_num_layer(self, var_name=""):
if var_name in ("cls_token", "mask_token", "pos_embed"):
return 0
elif var_name.startswith("patch_embed"):
return 0
elif var_name.startswith("rel_pos_bias"):
return len(self.blocks) - 1
elif var_name.startswith("blocks"):
layer_id = int(var_name.split('.')[1])
return layer_id + 1
else:
return len(self.blocks)
def interpolate_pos_embed(model, checkpoint_model):
if 'pos_embed' in checkpoint_model:
pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
embedding_size = pos_embed_checkpoint.shape[-1]
num_patches = model.patch_embed.num_patches
num_extra_tokens = model.pos_embed.shape[-2] - num_patches
# height (== width) for the checkpoint position embedding
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
# height (== width) for the new position embedding
new_size = int(num_patches ** 0.5)
# class_token and dist_token are kept unchanged
if orig_size != new_size:
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
checkpoint_model['pos_embed'] = new_pos_embed
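# Worked example (added commentary): a checkpoint trained at 224px with 14px
# patches stores a 16x16 grid of position tokens (plus the extra cls token);
# when the target model uses 336px inputs, num_patches = 24*24, so the grid is
# bicubically resized from 16x16 to 24x24 and re-flattened before loading.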
def convert_weights_to_fp16(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
# if isinstance(l, (nn.MultiheadAttention, Attention)):
# for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
# tensor = getattr(l, attr)
# if tensor is not None:
# tensor.data = tensor.data.half()
model.apply(_convert_weights_to_fp16)
class EVAVisionTower(nn.Module):
def __init__(self, vision_tower, image_processor, args, use_checkpoint=False, drop_path_rate=0.0, delay_load=False, dtype=torch.float32):
super().__init__()
self.is_loaded = False
self.use_checkpoint = use_checkpoint
self.vision_tower_name = vision_tower
self.image_processor_name = image_processor
self.drop_path_rate = drop_path_rate
self.patch_size = 14
self.out_channel = 1408
if not delay_load:
self.load_model()
self.vision_config = CLIPVisionConfig.from_pretrained(image_processor)
def load_model(self):
# self.image_processor = CLIPImageProcessor.from_pretrained(self.image_processor_name)
self.image_processor = VideoFramesProcessor.from_pretrained(self.image_processor_name)
self.vision_tower = VisionTransformer(
img_size=self.image_processor.size['shortest_edge'],
patch_size=self.patch_size,
use_mean_pooling=False,
embed_dim=self.out_channel,
depth=39,
num_heads=self.out_channel//88,
mlp_ratio=4.3637,
qkv_bias=True,
drop_path_rate=self.drop_path_rate,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
use_checkpoint=self.use_checkpoint,
)
state_dict = torch.load(self.vision_tower_name, map_location="cpu")
interpolate_pos_embed(self.vision_tower, state_dict)
incompatible_keys = self.vision_tower.load_state_dict(state_dict, strict=False)
print(incompatible_keys)
self.vision_tower.requires_grad_(False)
self.is_loaded = True
@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
image_feature = image_forward_out.to(image.dtype)
image_features.append(image_feature)
else:
image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
image_features = image_forward_outs.to(images.dtype)
return image_features
def feature_select(self, image_features):
# NOTE: not called from forward(); relies on self.select_feature, which EVAVisionTower never sets
# image_features = image_features.hidden_states[self.select_layer]
if self.select_feature == 'patch':
image_features = image_features[:, 1:]
elif self.select_feature == 'cls_patch':
image_features = image_features
else:
raise ValueError(f'Unexpected select feature: {self.select_feature}')
return image_features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_tower.dtype
@property
def device(self):
return self.vision_tower.device
@property
def config(self):
return self.vision_config
@property
def hidden_size(self):
return self.out_channel
@property
def num_patches(self):
return (self.image_processor.size['shortest_edge'] // self.patch_size) ** 2
def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,model_path=None,precision="fp16"):
model = VisionTransformer(
img_size=img_size,
patch_size=14,
use_mean_pooling=False,
embed_dim=1408,
depth=39,
num_heads=1408//88,
mlp_ratio=4.3637,
qkv_bias=True,
drop_path_rate=drop_path_rate,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
use_checkpoint=use_checkpoint,
)
# url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
# cached_file = download_cached_file(
# url, check_hash=False, progress=True
# )
state_dict = torch.load(model_path, map_location="cpu")
interpolate_pos_embed(model,state_dict)
incompatible_keys = model.load_state_dict(state_dict, strict=False)
print(incompatible_keys)
if precision == "fp16":
convert_weights_to_fp16(model)
return model
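# Added commentary: these hyper-parameters match the EVA-ViT-g/14 configuration
# used by BLIP-2/LAVIS (see the commented checkpoint URL above): embed_dim 1408,
# depth 39, 1408 // 88 = 16 heads (head dim 88), and mlp_ratio 4.3637 giving an
# MLP hidden size of int(1408 * 4.3637) = 6144.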
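# ---------------------------------------------------------------------------
# openclip_encoder.py -- OpenCLIP ConvNeXt auxiliary vision tower (file name
# taken from the `from .openclip_encoder import OpenCLIPVisionTower` import in
# the builder above).
# ---------------------------------------------------------------------------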
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import json
import logging
import deepspeed
from pathlib import Path
from open_clip.factory import load_state_dict, get_model_config
from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower, convert_to_custom_text_state_dict, resize_pos_embed
from typing import Dict, Optional
from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
class OpenCLIPVisionTower(nn.Module):
def __init__(self, vision_tower, args, delay_load=False):
super().__init__()
self.is_loaded = False
self.vision_tower_name = vision_tower
self.vision_config = json.load(open(os.path.join(vision_tower,'open_clip_config.json'), 'r'))
self.is_optimize = getattr(args, 'optimize_vision_tower_aux', False)
self.is_droppath = getattr(args, 'drop_path', True)
if not delay_load:
self.load_model()
def load_model(self):
ckpt_path = os.path.join(self.vision_tower_name, 'open_clip_pytorch_model.bin')
if 'convnext' in self.vision_tower_name:
if 'large' in self.vision_tower_name and 'd-320' in self.vision_tower_name:
self.model_type = 'convnext_large_d_320'
self.model_channel = [192, 384, 768, 1536] # stage 0-3
elif 'base' in self.vision_tower_name and 'w-320' in self.vision_tower_name:
self.model_type = 'convnext_base_w_320'
self.model_channel = [128, 256, 512, 1024]
elif 'xxlarge' in self.vision_tower_name:
self.model_type = 'convnext_xxlarge'
self.model_channel = [384, 768, 1536, 3072]
clip_model = CLIP(**get_model_config(self.model_type), drop_path=self.is_droppath)
clip_model.visual.trunk.norm_pre = None
clip_model.visual.trunk.head = None
clip_model.visual.head = None
print(f'Loading pretrained weights ({self.model_type}).')
load_checkpoint(clip_model, ckpt_path, strict=False)
self.is_loaded = True
# decompose stem and stages blocks in vision tower
self.vision_stem = clip_model.visual.trunk.stem
self.vision_stages = clip_model.visual.trunk.stages
self.vision_stem.requires_grad_(False)
self.vision_stages.requires_grad_(False)
def forward(self, images):
if type(images) is list:
image_features = []
for image in images:
image_feature = self.backbone(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
image_features.append(image_feature)
else:
image_features = self.backbone(images.to(device=self.device, dtype=self.dtype))
return image_features
def backbone(self, images):
if not self.is_optimize:
with torch.no_grad():
results = self.basic_forward(images)
else:
results = self.basic_forward(images)
target_size = (results['stage_0'].shape[-2], results['stage_0'].shape[-1])
result_cat = []
for _stage in results:
if _stage == 'stage_0':
result_cat.append(results[_stage].contiguous())
else:
result_cat.append(F.interpolate(results[_stage].float().contiguous(),
size=target_size,
mode='bilinear',
align_corners=False).to(dtype=results[_stage].dtype))
result_cat = torch.cat(result_cat, dim=1)
return result_cat.contiguous()
def basic_forward(self, images):
results = {}
x = self.vision_stem(images)
for _idx in range(len(self.vision_stages)):
x = self.vision_stages[_idx](x)
results[f'stage_{_idx}'] = x
return results
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
@property
def dtype(self):
return self.vision_stem[0].weight.dtype
@property
def device(self):
return self.vision_stem[0].weight.device
@property
def config(self):
return self.vision_config
@property
def hidden_size(self):
return sum(self.model_channel)
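# Added commentary: backbone() upsamples every ConvNeXt stage to the stage_0
# resolution and concatenates them along the channel dim, so hidden_size equals
# sum(model_channel): 192+384+768+1536 = 2880 for convnext_large_d_320,
# 128+256+512+1024 = 1920 for convnext_base_w_320, and
# 384+768+1536+3072 = 5760 for convnext_xxlarge.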
# modified function from open_clip to support zero3 stage
def load_checkpoint(model, checkpoint_path, strict=True):
if Path(checkpoint_path).suffix in ('.npz', '.npy'):
from open_clip.big_vision import load_big_vision_weights
load_big_vision_weights(model, checkpoint_path)
return {}
state_dict = load_state_dict(checkpoint_path)
# detect old format and make compatible with new format
if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
state_dict = convert_to_custom_text_state_dict(state_dict)
# If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
# if 'logit_bias' not in state_dict and model.logit_bias is not None:
# state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
# Certain text transformers no longer expect position_ids after transformers==4.31
position_id_key = 'text.transformer.embeddings.position_ids'
if position_id_key in state_dict and not hasattr(model, position_id_key):
del state_dict[position_id_key]
resize_pos_embed(state_dict, model)
# resize_text_pos_embed(state_dict, model)
#incompatible_keys = model.load_state_dict(state_dict, strict=strict)
if is_deepspeed_zero3_enabled():
error_msgs = []
def load(module: nn.Module, state_dict, prefix=""):
metadata = None
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
# Parameters of module and children will start with prefix. We can exit early if there are none in this
# state_dict
if len([key for key in state_dict if key.startswith(prefix)]) > 0:
if is_deepspeed_zero3_enabled():
# In sharded models, each shard has only part of the full state_dict, so only gather
# parameters that are in the current state_dict.
named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
if len(params_to_gather) > 0:
# because zero3 puts placeholders in model params, this context
# manager gathers (unpartitions) the params of the current layer, then loads from
# the state dict and then re-partitions them again
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)
else:
module._load_from_state_dict(*args)
for name, child in module._modules.items():
if child is not None:
load(child, state_dict, prefix + name + ".")
load(model, state_dict)
incompatible_keys = []
else:
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
return incompatible_keys
class CLIP(nn.Module):
output_dict: torch.jit.Final[bool]
def __init__(
self,
embed_dim: int,
vision_cfg: CLIPVisionCfg,
text_cfg: CLIPTextCfg,
quick_gelu: bool = False,
cast_dtype: Optional[torch.dtype] = None,
output_dict: bool = False,
drop_path: bool = False,
):
super().__init__()
self.output_dict = output_dict
# Fix drop path during training
if not drop_path:
print('Not using drop path during training.')
vision_cfg['timm_drop_path'] = 0.0
self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
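# ---------------------------------------------------------------------------
# Vision-projector builder (multimodal projector factory); the exact module
# path is not shown in this commit view.
# ---------------------------------------------------------------------------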
import torch
import torch.nn as nn
import re
class IdentityMap(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, *args, **kwargs):
return x
@property
def config(self):
return {"mm_projector_type": 'identity'}
class SimpleResBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.pre_norm = nn.LayerNorm(channels)
self.proj = nn.Sequential(
nn.Linear(channels, channels),
nn.GELU(),
nn.Linear(channels, channels)
)
def forward(self, x):
x = self.pre_norm(x)
return x + self.proj(x)
def build_vision_projector(config, delay_load=False, **kwargs):
projector_type = getattr(config, 'mm_projector_type', 'linear')
if projector_type == 'linear':
return nn.Linear(config.mm_hidden_size, config.hidden_size)
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
if projector_type == 'identity':
return IdentityMap()
raise ValueError(f'Unknown projector type: {projector_type}')
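# Illustrative sketch only (not part of the original file): with
# mm_projector_type = 'mlp2x_gelu' the regex yields mlp_depth = 2, so the
# projector expands to
#
#     nn.Sequential(
#         nn.Linear(config.mm_hidden_size, config.hidden_size),
#         nn.GELU(),
#         nn.Linear(config.hidden_size, config.hidden_size),
#     )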
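# ---------------------------------------------------------------------------
# video_processor.py -- VideoFramesProcessor (file name taken from the
# `from ..processor.video_processor import VideoFramesProcessor` imports above).
# ---------------------------------------------------------------------------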
from transformers import CLIPImageProcessor
from transformers.image_processing_utils import BatchFeature, get_size_dict
from transformers.image_transforms import get_resize_output_image_size
import torch
import torch.nn.functional as F
import numpy as np
class VideoFramesProcessor(CLIPImageProcessor):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def preprocess(self, images, **kwargs):
if not isinstance(images, np.ndarray):
return super().preprocess(images=images, **kwargs)
do_resize = kwargs.get('do_resize', self.do_resize)
size = kwargs.get('size', self.size)
size = get_size_dict(size, param_name="size", default_to_square=False)
do_center_crop = kwargs.get('do_center_crop', self.do_center_crop)
crop_size = kwargs.get('crop_size', self.crop_size)
crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
do_rescale = kwargs.get('do_rescale', self.do_rescale)
rescale_factor = kwargs.get('rescale_factor', self.rescale_factor)
do_normalize = kwargs.get('do_normalize', self.do_normalize)
image_mean = kwargs.get('image_mean', self.image_mean)
image_std = kwargs.get('image_std', self.image_std)
return_tensors = kwargs.get('return_tensors', None)
def resize(images, output_size):
images = images.permute((0, 3, 1, 2))
images = F.interpolate(images, size=output_size, mode='bicubic')
images = images.permute((0, 2, 3, 1))
return images
def center_crop(images, crop_size):
crop_width, crop_height = crop_size["width"], crop_size["height"]
img_width, img_height = images.shape[1:3]
x = (img_width - crop_width) // 2
y = (img_height - crop_height) // 2
images = images[:, x:x+crop_width, y:y+crop_height]
return images
def rescale(images, rescale_factor):
images = images * rescale_factor
return images
def normalize(images, mean, std):
mean = torch.tensor(mean)
std = torch.tensor(std)
images = (images - mean) / std
return images
images = torch.from_numpy(images).float()
if do_resize:
output_size = get_resize_output_image_size(images[0], size=size["shortest_edge"], default_to_square=False)
images = resize(images, output_size)
if do_center_crop:
images = center_crop(images, crop_size)
if do_rescale:
images = rescale(images, rescale_factor)
if do_normalize:
images = normalize(images, image_mean, image_std)
images = images.permute((0, 3, 1, 2))
data = {"pixel_values": images}
return BatchFeature(data=data, tensor_type=return_tensors)
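# Illustrative sketch only (not part of the original file): PIL inputs fall
# through to CLIPImageProcessor.preprocess, while a numpy stack of video frames
# shaped (T, H, W, C) takes the tensor path above.
#
#     frames = np.random.randint(0, 256, (8, 480, 640, 3), dtype=np.uint8)  # dummy frames
#     processor = VideoFramesProcessor.from_pretrained('openai/clip-vit-large-patch14-336')
#     batch = processor.preprocess(frames, return_tensors='pt')
#     batch['pixel_values'].shape  # (8, 3, crop_size, crop_size), e.g. (8, 3, 336, 336)
# ---------------------------------------------------------------------------
# cli.py -- interactive CLI demo (serving script; exact path not shown in this
# commit view).
# ---------------------------------------------------------------------------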
import argparse
import torch
from mgm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from mgm.conversation import conv_templates, SeparatorStyle
from mgm.model.builder import load_pretrained_model
from mgm.utils import disable_torch_init
from mgm.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
try:
from diffusers import StableDiffusionXLPipeline
except ImportError:
print('Please install diffusers==0.26.3')
try:
from paddleocr import PaddleOCR
except ImportError:
print('Please install paddleocr following https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/README_en.md')
def load_image(image_file):
if image_file.startswith('http://') or image_file.startswith('https://'):
response = requests.get(image_file)
image = Image.open(BytesIO(response.content)).convert('RGB')
else:
image = Image.open(image_file).convert('RGB')
return image
def main(args):
# Model
disable_torch_init()
if args.ocr and args.image_file is not None:
ocr = PaddleOCR(use_angle_cls=True, use_gpu=True, lang="ch")
result = ocr.ocr(args.image_file)
str_in_image = ''
if result[0] is not None:
result = [res[1][0] for res in result[0] if res[1][1] > 0.1]
if len(result) > 0:
str_in_image = ', '.join(result)
print('OCR Token: ' + str_in_image)
if args.gen:
# import pdb
# pdb.set_trace()
pipe = StableDiffusionXLPipeline.from_pretrained(
"stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
).to("cuda")
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
if '8x7b' in model_name.lower():
conv_mode = "mistral_instruct"
elif '34b' in model_name.lower():
conv_mode = "chatml_direct"
elif '2b' in model_name.lower():
conv_mode = "gemma"
else:
conv_mode = "vicuna_v1"
if args.conv_mode is not None and conv_mode != args.conv_mode:
print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
else:
args.conv_mode = conv_mode
conv = conv_templates[args.conv_mode].copy()
if "mpt" in model_name.lower():
roles = ('user', 'assistant')
else:
roles = conv.roles
if args.image_file is not None:
images = []
if ',' in args.image_file:
images = args.image_file.split(',')
else:
images = [args.image_file]
image_convert = []
for _image in images:
image_convert.append(load_image(_image))
if hasattr(model.config, 'image_size_aux'):
if not hasattr(image_processor, 'image_size_raw'):
image_processor.image_size_raw = image_processor.crop_size.copy()
image_processor.crop_size['height'] = model.config.image_size_aux
image_processor.crop_size['width'] = model.config.image_size_aux
image_processor.size['shortest_edge'] = model.config.image_size_aux
# Similar operation in model_worker.py
image_tensor = process_images(image_convert, image_processor, model.config)
image_grid = getattr(model.config, 'image_grid', 1)
if hasattr(model.config, 'image_size_aux'):
raw_shape = [image_processor.image_size_raw['height'] * image_grid,
image_processor.image_size_raw['width'] * image_grid]
image_tensor_aux = image_tensor
image_tensor = torch.nn.functional.interpolate(image_tensor,
size=raw_shape,
mode='bilinear',
align_corners=False)
else:
image_tensor_aux = []
if image_grid >= 2:
raw_image = image_tensor.reshape(3,
image_grid,
image_processor.image_size_raw['height'],
image_grid,
image_processor.image_size_raw['width'])
raw_image = raw_image.permute(1, 3, 0, 2, 4)
raw_image = raw_image.reshape(-1, 3,
image_processor.image_size_raw['height'],
image_processor.image_size_raw['width'])
if getattr(model.config, 'image_global', False):
global_image = image_tensor
if len(global_image.shape) == 3:
global_image = global_image[None]
global_image = torch.nn.functional.interpolate(global_image,
size=[image_processor.image_size_raw['height'],
image_processor.image_size_raw['width']],
mode='bilinear',
align_corners=False)
# [image_crops, image_global]
raw_image = torch.cat([raw_image, global_image], dim=0)
image_tensor = raw_image.contiguous()
image_tensor = image_tensor.unsqueeze(0)
if type(image_tensor) is list:
image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
image_tensor_aux = [image.to(model.device, dtype=torch.float16) for image in image_tensor_aux]
else:
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
image_tensor_aux = image_tensor_aux.to(model.device, dtype=torch.float16)
else:
images = None
image_tensor = None
image_tensor_aux = []
while True:
try:
inp = input(f"{roles[0]}: ")
except EOFError:
inp = ""
if not inp:
print("exit...")
break
print(f"{roles[1]}: ", end="")
if args.ocr and len(str_in_image) > 0:
inp = inp + '\nReference OCR Token: ' + str_in_image + '\n'
if args.gen:
inp = inp + ' <GEN>'
# print(inp, '====')
if images is not None:
# first message
if model.config.mm_use_im_start_end:
inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
else:
inp = (DEFAULT_IMAGE_TOKEN + '\n')*len(images) + inp
conv.append_message(conv.roles[0], inp)
images = None
else:
# later messages
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
# add image split string
if prompt.count(DEFAULT_IMAGE_TOKEN) >= 2:
final_str = ''
sent_split = prompt.split(DEFAULT_IMAGE_TOKEN)
for _idx, _sub_sent in enumerate(sent_split):
if _idx == len(sent_split) - 1:
final_str = final_str + _sub_sent
else:
final_str = final_str + _sub_sent + f'Image {_idx+1}:' + DEFAULT_IMAGE_TOKEN
prompt = final_str
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
images_aux=image_tensor_aux if len(image_tensor_aux)>0 else None,
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
max_new_tokens=args.max_new_tokens,
bos_token_id=tokenizer.bos_token_id, # Begin of sequence token
eos_token_id=tokenizer.eos_token_id, # End of sequence token
pad_token_id=tokenizer.pad_token_id, # Pad token
streamer=streamer,
use_cache=True)
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
conv.messages[-1][-1] = outputs
if args.gen and '<h>' in outputs and '</h>' in outputs:
common_neg_prompt = "out of frame, lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature"
prompt = outputs.split("</h>")[-2].split("<h>")[-1]
output_img = pipe(prompt, negative_prompt=common_neg_prompt).images[0]
output_img.save(args.output_file)
print(f'Generate an image, save at {args.output_file}')
if args.debug:
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-file", type=str, default=None) # file_0.jpg,file_1.jpg for multi image
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default=None)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--ocr", action="store_true")
parser.add_argument("--gen", action="store_true")
parser.add_argument("--output-file", type=str, default='generate.png')
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
main(args)
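# ---------------------------------------------------------------------------
# controller.py -- FastAPI controller for distributed model workers (serving
# script; exact path not shown in this commit view).
# ---------------------------------------------------------------------------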
"""
A controller manages distributed workers.
It sends worker addresses to clients.
"""
import argparse
import asyncio
import dataclasses
from enum import Enum, auto
import json
import logging
import time
from typing import List, Union
import threading
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import numpy as np
import requests
import uvicorn
from mgm.constants import CONTROLLER_HEART_BEAT_EXPIRATION
from mgm.utils import build_logger, server_error_msg
logger = build_logger("controller", "controller.log")
class DispatchMethod(Enum):
LOTTERY = auto()
SHORTEST_QUEUE = auto()
@classmethod
def from_str(cls, name):
if name == "lottery":
return cls.LOTTERY
elif name == "shortest_queue":
return cls.SHORTEST_QUEUE
else:
raise ValueError(f"Invalid dispatch method")
@dataclasses.dataclass
class WorkerInfo:
model_names: List[str]
speed: int
queue_length: int
check_heart_beat: bool
last_heart_beat: float
def heart_beat_controller(controller):
while True:
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
controller.remove_stable_workers_by_expiration()
class Controller:
def __init__(self, dispatch_method: str):
# Dict[str -> WorkerInfo]
self.worker_info = {}
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
self.heart_beat_thread = threading.Thread(
target=heart_beat_controller, args=(self,))
self.heart_beat_thread.start()
logger.info("Init controller")
def register_worker(self, worker_name: str, check_heart_beat: bool,
worker_status: dict):
if worker_name not in self.worker_info:
logger.info(f"Register a new worker: {worker_name}")
else:
logger.info(f"Register an existing worker: {worker_name}")
if not worker_status:
worker_status = self.get_worker_status(worker_name)
if not worker_status:
return False
self.worker_info[worker_name] = WorkerInfo(
worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
check_heart_beat, time.time())
logger.info(f"Register done: {worker_name}, {worker_status}")
return True
def get_worker_status(self, worker_name: str):
try:
r = requests.post(worker_name + "/worker_get_status", timeout=5)
except requests.exceptions.RequestException as e:
logger.error(f"Get status fails: {worker_name}, {e}")
return None
if r.status_code != 200:
logger.error(f"Get status fails: {worker_name}, {r}")
return None
return r.json()
def remove_worker(self, worker_name: str):
del self.worker_info[worker_name]
def refresh_all_workers(self):
old_info = dict(self.worker_info)
self.worker_info = {}
for w_name, w_info in old_info.items():
if not self.register_worker(w_name, w_info.check_heart_beat, None):
logger.info(f"Remove stale worker: {w_name}")
def list_models(self):
model_names = set()
for w_name, w_info in self.worker_info.items():
model_names.update(w_info.model_names)
return list(model_names)
def get_worker_address(self, model_name: str):
if self.dispatch_method == DispatchMethod.LOTTERY:
worker_names = []
worker_speeds = []
for w_name, w_info in self.worker_info.items():
if model_name in w_info.model_names:
worker_names.append(w_name)
worker_speeds.append(w_info.speed)
worker_speeds = np.array(worker_speeds, dtype=np.float32)
norm = np.sum(worker_speeds)
if norm < 1e-4:
return ""
worker_speeds = worker_speeds / norm
if True: # Directly return address
pt = np.random.choice(np.arange(len(worker_names)),
p=worker_speeds)
worker_name = worker_names[pt]
return worker_name
# Check status before returning
while True:
pt = np.random.choice(np.arange(len(worker_names)),
p=worker_speeds)
worker_name = worker_names[pt]
if self.get_worker_status(worker_name):
break
else:
self.remove_worker(worker_name)
worker_speeds[pt] = 0
norm = np.sum(worker_speeds)
if norm < 1e-4:
return ""
worker_speeds = worker_speeds / norm
continue
return worker_name
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
worker_names = []
worker_qlen = []
for w_name, w_info in self.worker_info.items():
if model_name in w_info.model_names:
worker_names.append(w_name)
worker_qlen.append(w_info.queue_length / w_info.speed)
if len(worker_names) == 0:
return ""
min_index = np.argmin(worker_qlen)
w_name = worker_names[min_index]
self.worker_info[w_name].queue_length += 1
logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
return w_name
else:
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
def receive_heart_beat(self, worker_name: str, queue_length: int):
if worker_name not in self.worker_info:
logger.info(f"Receive unknown heart beat. {worker_name}")
return False
self.worker_info[worker_name].queue_length = queue_length
self.worker_info[worker_name].last_heart_beat = time.time()
logger.info(f"Receive heart beat. {worker_name}")
return True
def remove_stable_workers_by_expiration(self):
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
to_delete = []
for worker_name, w_info in self.worker_info.items():
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
to_delete.append(worker_name)
for worker_name in to_delete:
self.remove_worker(worker_name)
def worker_api_generate_stream(self, params):
worker_addr = self.get_worker_address(params["model"])
if not worker_addr:
logger.info(f"no worker: {params['model']}")
ret = {
"text": server_error_msg,
"error_code": 2,
}
yield json.dumps(ret).encode() + b"\0"
try:
response = requests.post(worker_addr + "/worker_generate_stream",
json=params, stream=True, timeout=5)
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
yield chunk + b"\0"
except requests.exceptions.RequestException as e:
logger.info(f"worker timeout: {worker_addr}")
ret = {
"text": server_error_msg,
"error_code": 3,
}
yield json.dumps(ret).encode() + b"\0"
# Let the controller act as a worker to achieve hierarchical
# management. This can be used to connect isolated sub networks.
def worker_api_get_status(self):
model_names = set()
speed = 0
queue_length = 0
for w_name in self.worker_info:
worker_status = self.get_worker_status(w_name)
if worker_status is not None:
model_names.update(worker_status["model_names"])
speed += worker_status["speed"]
queue_length += worker_status["queue_length"]
return {
"model_names": list(model_names),
"speed": speed,
"queue_length": queue_length,
}
app = FastAPI()
@app.post("/register_worker")
async def register_worker(request: Request):
data = await request.json()
controller.register_worker(
data["worker_name"], data["check_heart_beat"],
data.get("worker_status", None))
@app.post("/refresh_all_workers")
async def refresh_all_workers():
models = controller.refresh_all_workers()
@app.post("/list_models")
async def list_models():
models = controller.list_models()
return {"models": models}
@app.post("/get_worker_address")
async def get_worker_address(request: Request):
data = await request.json()
addr = controller.get_worker_address(data["model"])
return {"address": addr}
@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
data = await request.json()
exist = controller.receive_heart_beat(
data["worker_name"], data["queue_length"])
return {"exist": exist}
@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
params = await request.json()
generator = controller.worker_api_generate_stream(params)
return StreamingResponse(generator)
@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
return controller.worker_api_get_status()
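# Illustrative sketch only (not part of the original file): once the controller
# is running, workers register themselves and clients query it over plain HTTP.
# The host/port below match the argparse defaults; the worker address and model
# name are placeholders.
#
#     import requests
#     requests.post('http://localhost:21001/register_worker', json={
#         'worker_name': 'http://localhost:40000',   # hypothetical worker address
#         'check_heart_beat': True,
#         'worker_status': {'model_names': ['MGM-7B'],  # hypothetical model name
#                           'speed': 1, 'queue_length': 0},
#     })
#     requests.post('http://localhost:21001/list_models').json()   # {'models': ['MGM-7B']}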
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21001)
parser.add_argument("--dispatch-method", type=str, choices=[
"lottery", "shortest_queue"], default="shortest_queue")
args = parser.parse_args()
logger.info(f"args: {args}")
controller = Controller(args.dispatch_method)
uvicorn.run(app, host=args.host, port=args.port, log_level="info")