## 2.0.0
* Add gradient checkpointing, FullyShardedDataParallel
* Model releases
    * (CLIP ViT-L-14 / MPT-1B)
    * (CLIP ViT-L-14 / MPT-1B Dolly)
    * (CLIP ViT-L-14 / RedPajama-3B)
    * (CLIP ViT-L-14 / RedPajama-3B Instruct)
    * (CLIP ViT-L-14 / MPT-7B)
* Remove color jitter when training
* Fix cross-attention bug when calling generate()
## 1.0.0
* Initial code release
* Early model release (CLIP ViT-L-14 / LLaMA-7B)
MIT License
Copyright (c) [year] [fullname]
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
install: ## [Local development] Upgrade pip, install requirements, install package.
python -m pip install -U pip
python -m pip install -e .
install-dev: ## [Local development] Install test requirements
python -m pip install -r requirements-dev.txt
lint: ## [Local development] Run mypy, pylint and black
python -m mypy open_flamingo
python -m pylint open_flamingo
python -m black --check -l 120 open_flamingo
black: ## [Local development] Auto-format python code using black
python -m black -l 120 .
.PHONY: help
help: # Run `make help` to get help on the make commands
@grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
# BLIP-3
## Paper
xGen-MM (BLIP-3): A Family of Open Large Multimodal Models
https://arxiv.org/pdf/2408.08872
## Model Architecture
BLIP-3, also known as xGen-MM, is a framework for developing large multimodal models (LMMs). The framework comprises carefully curated datasets, a training recipe, model architectures, and a resulting suite of LMMs. xGen-MM, short for xGen-MultiModal, extends Salesforce's xGen initiative on foundation AI models. The models have been rigorously evaluated across a range of tasks, including single-image and multi-image benchmarks. The pre-trained base model shows strong in-context learning ability, and the instruction-tuned model demonstrates competitive performance among open-source LMMs of similar size. In addition, a safety-tuned model trained with DPO is introduced to mitigate harmful behaviors such as hallucination and to improve safety.
## Environment Setup
### Docker (Option 1)
```
# Pull the Docker image from the SourceFind image registry:
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.4.1-ubuntu22.04-dtk25.04.1-py3.10
# Create and start the container:
docker run -it --network=host -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=80G --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --privileged=true --device=/dev/kfd --device=/dev/dri/ --ipc=host --group-add video --name <your_project_name> <image_id> bash
# Install dependencies:
python setup.py install
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
```
### Dockerfile (Option 2)
```
docker build --no-cache -t blip3_pytorch:latest .
docker run -it --network=host --name=blip3_pytorch --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 -v /opt/hyhal/:/opt/hyhal/:ro -v /usr/local/hyhal:/usr/local/hyhal:ro blip3_pytorch:latest bash
# Install dependencies:
python setup.py install
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
```
### Anaconda (Option 3)
```
1. Create a conda virtual environment:
conda create -n blip3_pytorch python=3.10
2. The toolkits and deep learning libraries required for the DCU GPUs used by this project can be downloaded from the developer community: https://developer.hpccube.com/tool/
DTK driver: dtk25.04.1
python: python3.10
torch: 2.4.1
```
Tip: the versions of the DTK, Python, torch, and other DCU-related toolkits listed above must correspond to one another exactly.
```
3. Install the remaining libraries according to requirements.txt:
pip install -r requirements-training.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
python setup.py install
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
```
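After installation (with any of the three methods), a quick sanity check can confirm that the environment is usable. This is a minimal sketch and assumes the DTK build of PyTorch reports DCU devices through the standard `torch.cuda` interface:
```
# check_env.py -- minimal environment sanity check (illustrative only)
import torch

print("torch version:", torch.__version__)          # expected: 2.4.1
print("device available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device name:", torch.cuda.get_device_name(0))
```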
# Training
## Dataset
The model supports LLaVA-format JSON dataset files; the JSON structure is shown below. See the reference dataset [llava_pretrain](http://113.200.138.88:18080/aimodels/llava_pretrain). Multiple different datasets can be provided.
Next, configure the [`data/example_data_config.yaml`](./data_configs/example_data_config.yaml) file with the paths of all JSON files and their image counts. If the JSON files contain relative data paths, you also need to configure the path-mapping file [`data/data_paths.py`](./data/data_paths.py).
```
YAML file:
data_path: {
    '/path/to/llava_pretrain.json': 558128,
    '/path/to/som_qa_coco20k.json': 20160,
    '/path/to/som_listing_coco10k.json': 10000,
}
```
```
JSON file:
{
"id": "000000033471",
"image": "coco/train2017/000000033471.jpg",
"conversations": [
{
"from": "human",
"value": "<image>\nWhat are the colors of the bus in the image?"
},
{
"from": "gpt",
"value": "The bus in the image is white and red."
},
...
]
}
```
The LLaVA-Pretrain dataset directory is structured as follows:
```
/path/to/LLaVA-Pretrain/
├── blip_laion_cc_sbu_558k.json
├── 00000
│ ├── 000000010.jpg
│ ├── 000000012.jpg
│ └── ...
├── 00001
├── 00002
└── ...
```
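Before launching training, it can help to verify that every image referenced by a JSON file exists under the dataset root. The snippet below is an illustrative sketch only; `json_path` and `image_root` are placeholders for your own paths:
```
# check_dataset.py -- illustrative sketch; json_path and image_root are placeholders
import json
import os

json_path = "/path/to/llava_pretrain.json"
image_root = "/path/to/LLaVA-Pretrain"

with open(json_path, "r", encoding="utf-8") as f:
    records = json.load(f)

# Each record stores a path relative to the dataset root in its "image" field.
missing = [r["image"] for r in records
           if "image" in r and not os.path.exists(os.path.join(image_root, r["image"]))]
print(f"{len(records)} records, {len(missing)} missing images")
```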
## Fine-tuning
### Pre-trained weights
Download the pre-trained model `xgen-mm-phi3-mini-base-r-v1.5` from the SCNet [download link](http://113.200.138.88:18080/aimodels/xgen-mm-phi3-mini-base-r-v1.5),
then run the following script to generate a PyTorch-native .pt file:
```
# Set dest_fn to the output path and .pt filename, and set model_name_or_path to the path of the pre-trained model weights
python convert_hf_model.py
```
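For reference, the conversion amounts to loading the Hugging Face checkpoint and saving its state dict with `torch.save`. The sketch below only illustrates this idea; the actual logic, including the exact checkpoint layout expected by the training script, lives in `convert_hf_model.py`:
```
# Illustrative sketch only -- see convert_hf_model.py for the actual conversion.
import torch
from transformers import AutoModelForVision2Seq

model_name_or_path = "/path/to/xgen-mm-phi3-mini-base-r-v1.5"  # pre-trained weights
dest_fn = "/path/to/xgen-mm-phi3-mini-base-r-v1.5.pt"          # output .pt file

# trust_remote_code is needed if the checkpoint ships custom modeling code;
# adjust the auto class if your checkpoint registers a different one.
model = AutoModelForVision2Seq.from_pretrained(model_name_or_path, trust_remote_code=True)
# The key layout expected by finetune.sh may differ; this only shows the torch.save step.
torch.save(model.state_dict(), dest_fn)
```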
### Single-node, multi-GPU
```
bash scripts/finetune.sh
```
The training script parameters are described below:
* `exp_name`: name of the training log file
* `data_path`: path to the YAML file
* `pretrained_ckpt`: path to the .pt file
* `--nproc_per_node=2`: number of GPUs for multi-GPU training
* `--nnodes=1`: number of nodes
* `--master_port 9650`: master port
* `--lm_path`: path to the language model (LM); defaults to "microsoft/Phi-3-mini-4k-instruct"
* `--tokenizer_path`: path to the tokenizer used to process text data; defaults to "microsoft/Phi-3-mini-4k-instruct"
* `--vision_encoder_path`: vision encoder; defaults to "google/siglip-so400m-patch14-384"
## Results
### Application Scenarios
### Algorithm Category
Image-to-text
### Key Application Industries
AIGC, design
## Source Repository and Issue Reporting
- https://developer.sourcefind.cn/codes/modelzoo/blip-3
## References
- https://github.com/salesforce/LAVIS/tree/xgen-mm
**Please read the following information carefully before proceeding.**
OpenFlamingo is a **research prototype** that aims to enable users to interact with AI through both language and images. AI agents equipped with both language and visual understanding can be useful on a larger variety of tasks compared to models that communicate solely via language. By releasing an open-source research prototype, we hope to help the research community better understand the risks and limitations of modern visual-language AI models and accelerate the development of safer and more reliable methods.
- [ ] I understand that OpenFlamingo is a research prototype and I will only use it for non-commercial research purposes.
**Limitations.** OpenFlamingo is built on top of the LLaMA large language model developed by Meta AI. Large language models, including LLaMA, are trained on mostly unfiltered internet data, and have been shown to be able to produce toxic, unethical, inaccurate, and harmful content. On top of this, OpenFlamingo’s ability to support visual inputs creates additional risks, since it can be used in a wider variety of applications; image+text models may carry additional risks specific to multimodality. Please use discretion when assessing the accuracy or appropriateness of the model’s outputs, and be mindful before sharing its results.
- [ ] I understand that OpenFlamingo may produce unintended, inappropriate, offensive, and/or inaccurate results. I agree to take full responsibility for any use of the OpenFlamingo outputs that I generate.
**Privacy and data collection.** This demo does NOT store any personal information on its users, and it does NOT store user queries.
**Licensing.** As OpenFlamingo is built on top of the LLaMA large language model from Meta AI, the LLaMA license agreement (as documented in the Meta request form) also applies.
- [ ] I have read and agree to the terms of the LLaMA license agreement.
from .src.xgenmm import XGenMMPerceiver
from .src.factory import create_model_and_transforms, SUPPORTED_MODEL_FAMILIES
# OpenFlamingo: Modeling
We provide modules to mix-and-match into several vision-language model architectures.
## What is a VLM?
A **vision-language model (VLM)** is a language model capable of processing a sequence of arbitrarily interleaved images/videos with text to output text.
![A VLM takes in a sequence of interleaved images/videos with text and outputs text.](../../docs/signature.png)
The forward signature of a VLM is as follows:
* `vision_x`: The batch of images / videos to process. This is a tensor of the shape `(B, T_img, F, C, H, W)`, where `B` is the batch dimension, `T_img` collates the images/videos within one input sequence, `F` collates frames within a video, and `(C, H, W)` are the channel, height, and width dimensions respectively.
* `lang_x`: The batch of input_ids (text) to process. This is a tensor of the shape `(B, T_txt)`, where `T_txt` is the number of text tokens within one input sequence.
To explain to the model how to interleave the image/text elements within a sequence, `lang_x` should include `<image>` tokens ("media tokens") that specify where the images/videos are placed. (See figure below)
![Illustration of what the inputs to a VLM look like.](../../docs/inputs.png)
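For example, the tensors below have valid shapes for a batch of two sequences, each containing three single-frame images (values are arbitrary; `model` stands for any VLM built with this repository):
```
# Illustrative shapes for the VLM forward signature (values are arbitrary).
import torch

B, T_img, F, C, H, W = 2, 3, 1, 3, 224, 224   # 2 sequences, 3 single-frame images each
T_txt = 32                                     # 32 text tokens per sequence

vision_x = torch.randn(B, T_img, F, C, H, W)   # pixel inputs
lang_x = torch.randint(0, 50_000, (B, T_txt))  # token ids; in practice these include <image> media tokens
# output = model(vision_x=vision_x, lang_x=lang_x)   # model built as described below
```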
## VLM modeling with the open_flamingo repository
This repository provides modules for constructing various VLM architectures.
All models inherit from the `VLM` (vision-language model) class defined in `src/vlm.py`. As documented there, a VLM is defined by four component modules:
1. A **vision encoder** that extracts features from pixels (e.g. CLIP). This module should take in vision inputs of the shape `(B, T_img, F, C, H, W)` and output features of the shape `(B, T_img, F, v, d)`.
2. A **vision tokenizer** that converts features from the vision encoder into token-like embeddings (e.g. PerceiverResampler). This module should take in vision features of the shape `(B, T_img, F, v, d)` and output tokens of the shape `(B, T_img, n, d)`.
3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention (as done in [Flamingo](https://arxiv.org/abs/2204.14198)), or placing the tokens directly in the language model's input sequence (as done in [Kosmos](https://arxiv.org/abs/2306.14824)).
4. A language model.
This repository allows us to construct architectures by mixing-and-matching options for all four kinds of modules.
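For instance, an xGen-MM-style model can be assembled with `create_model_and_transforms` from `src/factory.py`. This is a minimal sketch assuming the package is installed as `open_flamingo` and that the SigLIP and Phi-3 checkpoints below (the defaults referenced elsewhere in this repository) are reachable:
```
# Minimal sketch: assemble a model from a vision encoder, vision tokenizer,
# fusion method, and language model via the factory.
from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="google/siglip-so400m-patch14-384",
    clip_vision_encoder_pretrained="google",   # selects the SigLIP loading branch
    lang_model_path="microsoft/Phi-3-mini-4k-instruct",
    tokenizer_path="microsoft/Phi-3-mini-4k-instruct",
    model_family="xgenmm_v1",                  # one of SUPPORTED_MODEL_FAMILIES
)
```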
### Supported vision encoders
All CLIP-style encoders from the [OpenCLIP](https://github.com/mlfoundations/open_clip) library are supported. This includes OpenAI's models.
### Supported vision tokenizers
* [Perceiver Resampler](https://arxiv.org/abs/2103.03206)
* [Q-former](https://arxiv.org/abs/2301.12597)
* Linear projection
### Supported fusion methods
Models are further split into those that inherit from `VLMWithCrossAttention` (dense cross attention to fuse vision + language, Flamingo-style) vs. `VLMWithLanguageStream` (insert vision tokens into the language stream, Kosmos-style).
![A VLM with cross attention and a VLM with language stream represent two methods for fusing the vision and language inputs.](../../docs/xattn_langstream.png)
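Schematically, the two fusion styles differ as in the sketch below (toy tensors and a plain attention layer for illustration, not the repository's actual API):
```
# Schematic contrast of the two fusion methods (toy tensors, not the real API).
import torch
import torch.nn as nn

B, n_txt, n_vis, d = 2, 16, 8, 512
lang_hidden = torch.randn(B, n_txt, d)      # language hidden states
vision_tokens = torch.randn(B, n_vis, d)    # output of the vision tokenizer

# VLMWithCrossAttention (Flamingo-style): language states attend to vision tokens
# through extra (gated) cross-attention layers interleaved with the decoder blocks.
xattn = nn.MultiheadAttention(embed_dim=d, num_heads=8, batch_first=True)
fused, _ = xattn(query=lang_hidden, key=vision_tokens, value=vision_tokens)

# VLMWithLanguageStream (Kosmos-style): vision tokens are inserted directly into
# the language model's input sequence at the <image> positions.
language_stream = torch.cat([vision_tokens, lang_hidden], dim=1)
```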
### Supported language models
All autoregressive language models from [Huggingface Transformers](https://huggingface.co/models) are supported.
## Example architectures
Using these modules, the following architectures are implemented as examples.
|Model|Vision tokenizer|Fusion method|Trainable parameters|
|----|------------|------------|------------|
|[Flamingo](https://arxiv.org/abs/2204.14198)|Perceiver|Cross attention|Added language model embeddings, vision tokenizer|
|[Kosmos](https://arxiv.org/abs/2306.14824)|Perceiver|Language stream|Everything except the vision encoder|
|[BLIP](https://arxiv.org/abs/2301.12597)|Q-former|Language stream|Added language model embeddings, vision tokenizer|
We welcome contributions! If you'd like to add additional vision tokenizers, fusion methods, or model types, please open a PR.
from .helpers import VLMOutputWithPast
import torch.nn as nn
import torch
from .helpers import GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive
class DecoderLayerWithCrossAttention(nn.Module):
"""
DecoderLayerWithCrossAttention is a wrapper around the GatedCrossAttentionBlock and DecoderLayer.
"""
def __init__(
self, gated_cross_attn_layer, decoder_layer, gradient_checkpointing=False
):
super().__init__()
self.gated_cross_attn_layer = gated_cross_attn_layer
self.decoder_layer = decoder_layer
self.vis_x = None
self.media_locations = None
if self.gated_cross_attn_layer is not None:
self.gated_cross_attn_layer._use_gradient_checkpointing = (
gradient_checkpointing
)
self.decoder_layer._use_gradient_checkpointing = gradient_checkpointing
def is_conditioned(self) -> bool:
"""Check whether the layer is conditioned."""
return self.vis_x is not None and self.media_locations is not None
    # This conditioning approach is borrowed from the flamingo-mini implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
def condition_vis_x(self, vis_x):
self.vis_x = vis_x
def condition_media_locations(self, media_locations):
self.media_locations = media_locations
def forward(
self,
lang_x,
attention_mask=None,
**decoder_layer_kwargs,
):
# Cross attention
contains_media = (self.media_locations == 1).any()
if contains_media and self.gated_cross_attn_layer is not None:
if self.vis_x is None:
raise ValueError("vis_x must be conditioned before forward pass")
if self.media_locations is None:
raise ValueError(
"media_locations must be conditioned before forward pass"
)
lang_x = self.gated_cross_attn_layer(
lang_x,
self.vis_x,
media_locations=self.media_locations,
)
# Normal decoder layer
lang_x = self.decoder_layer(
lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
)
return lang_x
class CrossAttentionMixin(nn.Module):
"""
Mixin to add cross-attention layers to a language model.
"""
def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
self.decoder_layers_attr_name = decoder_layers_attr_name
def _get_decoder_layers(self):
return getattr_recursive(self, self.decoder_layers_attr_name)
def _set_decoder_layers(self, value):
setattr_recursive(self, self.decoder_layers_attr_name, value)
def init_cross_attention_layers(
self,
lang_hidden_size,
vis_hidden_size,
cross_attn_every_n_layers,
gradient_checkpointing,
):
"""
Add gated cross attn layers to the decoder.
"""
old_decoder_blocks = self._get_decoder_layers()
self.decoder_block_class = old_decoder_blocks[0].__class__
self.gated_cross_attn_layers = nn.ModuleList(
[
GatedCrossAttentionBlock(
dim=lang_hidden_size, dim_visual=vis_hidden_size
)
if (layer_idx + 1) % cross_attn_every_n_layers == 0
else None
for layer_idx, _ in enumerate(old_decoder_blocks)
]
)
self._set_decoder_layers(
nn.ModuleList(
[
DecoderLayerWithCrossAttention(
gated_cross_attn_layer, decoder_layer, gradient_checkpointing
)
for gated_cross_attn_layer, decoder_layer in zip(
self.gated_cross_attn_layers, old_decoder_blocks
)
]
)
)
self.initialized_cross_attention = True
def _condition_media_before_forward(
self,
input_ids: torch.Tensor,
vision_tokens: torch.Tensor = None,
past_media_locations: torch.Tensor = None,
past_vision_tokens: torch.Tensor = None,
num_beams: int = 1,
):
"""Each xattn layer needs to save the vision tokens and the locations of the media tokens in the language sequence"""
assert (
self.initialized_cross_attention
), "Cross attention layers have not been initialized. "
# concat with past
if past_media_locations is not None and past_vision_tokens is not None:
if vision_tokens is not None:
updated_vision_tokens = torch.cat(
[
past_vision_tokens,
vision_tokens,
],
dim=1,
)
else:
updated_vision_tokens = past_vision_tokens
updated_media_locations = torch.cat(
[
past_media_locations,
input_ids == self.media_token_id,
],
dim=1,
)
else:
updated_vision_tokens = vision_tokens
updated_media_locations = input_ids == self.media_token_id
# repeat the vision tokens and media locations for each beam
updated_vision_tokens = updated_vision_tokens.repeat_interleave(
num_beams, dim=0
)
updated_media_locations = updated_media_locations.repeat_interleave(
num_beams, dim=0
)
# condition
for layer in self._get_decoder_layers():
layer.condition_vis_x(updated_vision_tokens)
layer.condition_media_locations(updated_media_locations)
def is_conditioned(self) -> bool:
"""Check whether all decoder layers are already conditioned."""
return all(l.is_conditioned() for l in self._get_decoder_layers())
def clear_conditioned_layers(self):
for layer in self._get_decoder_layers():
layer.condition_vis_x(None)
layer.condition_media_locations(None)
import os
from typing import Optional
import torch
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, Lambda
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPVisionModel, CLIPImageProcessor, AutoModel, AutoProcessor
import open_clip
from .xgenmm import XGenMMPerceiver
from .utils import hasattr_recursive, setattr_recursive
from PIL import Image
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
MODEL_FAMILY_TO_CLASS = {
"xgenmm_v1": XGenMMPerceiver,
}
SUPPORTED_MODEL_FAMILIES = MODEL_FAMILY_TO_CLASS.keys()
def _convert_image_to_rgb(image):
return image.convert("RGB")
def create_model_and_transforms(
clip_vision_encoder_path: str,
clip_vision_encoder_pretrained: str,
lang_model_path: str,
tokenizer_path: str,
model_family: str = "flamingo",
pretrained_vision_tokenizer: Optional[str] = None,
use_local_files: bool = False,
decoder_layers_attr_name: str = None,
cache_dir: Optional[str] = None,
gradient_checkpointing: bool = False,
verbose: bool = True,
**model_kwargs,
):
"""
Initialize a Flamingo model from a pretrained vision encoder and language encoder.
Appends special tokens to the tokenizer and freezes backbones.
Args:
clip_vision_encoder_path (str): path to pretrained clip model (e.g. "ViT-B-32")
clip_vision_encoder_pretrained (str): name of pretraining dataset for clip model (e.g. "laion2b_s32b_b79k")
lang_model_path (str): path to pretrained language encoder
tokenizer_path (str): path to pretrained tokenizer
cross_attn_every_n_layers (int, optional): determines how often to add a cross-attention layer. Defaults to 1.
use_local_files (bool, optional): whether to use local files. Defaults to False.
decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
cache_dir (str, optional): path to cache directory for downloading OpenClip/HF weights.
gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
verbose (bool, optional): whether to print model info. Defaults to True.
Returns:
Flamingo: Flamingo model from pretrained vision and language encoders
Image processor: Pipeline to preprocess input images
Tokenizer: A tokenizer for the language model
"""
assert model_family in SUPPORTED_MODEL_FAMILIES
# load vision encoder
if clip_vision_encoder_pretrained == 'openai':
vision_encoder = CLIPVisionModel.from_pretrained(clip_vision_encoder_path)
hf_processor = CLIPImageProcessor.from_pretrained(clip_vision_encoder_path)
n_px = hf_processor.crop_size['height']
# Use torchvision processor to be consistent with other vision encoders.
# https://github.com/openai/CLIP/blob/main/clip/clip.py
image_processor = Compose([
Resize((n_px, n_px), interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
vis_hidden_dim = vision_encoder.config.hidden_size
elif clip_vision_encoder_pretrained == 'google':
# "google/siglip-so400m-patch14-384"
model = AutoModel.from_pretrained(clip_vision_encoder_path)
hf_processor = AutoProcessor.from_pretrained(clip_vision_encoder_path)
n_px = hf_processor.image_processor.size['height']
vision_encoder = model.vision_model
vis_hidden_dim = vision_encoder.config.hidden_size
# Define the transformation sequence
image_processor = Compose([
Resize((n_px, n_px), interpolation=InterpolationMode.BICUBIC, antialias=True),
Lambda(lambda x: x.convert('RGB') if x.mode != 'RGB' else x),
ToTensor(),
Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
else:
vision_encoder, _, image_processor = open_clip.create_model_and_transforms(
clip_vision_encoder_path,
pretrained=clip_vision_encoder_pretrained,
)
vision_encoder.visual.output_tokens = True
vision_encoder = vision_encoder.visual
vision_encoder_config = open_clip.get_model_config(clip_vision_encoder_path)
if "SigLIP" in clip_vision_encoder_path or "EVA" in clip_vision_encoder_path: # SigLIP models have a different config format
vis_hidden_dim = vision_encoder_config["embed_dim"]
else:
vis_hidden_dim = vision_encoder_config["vision_cfg"]["width"]
# load tokenizer and ensure there is a pad token
text_tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path,
local_files_only=use_local_files,
trust_remote_code=True,
use_fast=False,
)
if text_tokenizer.pad_token is None or text_tokenizer.pad_token == text_tokenizer.eos_token:
# add a pad token if it doesn't exist
text_tokenizer.add_special_tokens({"pad_token": "<pad>"})
added_pad_token = True
else:
added_pad_token = False
    # load language model
if ('phi3' in lang_model_path.lower()) or ('phi-3' in lang_model_path.lower()):
if 'instruct' not in lang_model_path.lower():
raise ValueError("As of now, we only support instruct models for phi3. Please use a model with 'instruct' in the path.")
trust_remote_code_flag = True # phi3 is not stable yet, so we trust the remote code
else:
        trust_remote_code_flag = False  # force the use of modeling code from local files so that the FSDP wrapper can be applied
lang_model = AutoModelForCausalLM.from_pretrained(
lang_model_path,
local_files_only=use_local_files,
trust_remote_code=trust_remote_code_flag,
)
check_embedding_fns(lang_model)
# init the model
if decoder_layers_attr_name is None:
decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_model)
model = MODEL_FAMILY_TO_CLASS[model_family](
vision_encoder=vision_encoder,
lang_model=lang_model,
vis_feature_dim=vis_hidden_dim,
initial_tokenizer_len=len(text_tokenizer),
gradient_checkpointing=gradient_checkpointing,
decoder_layers_attr_name=decoder_layers_attr_name,
pad_token_id=text_tokenizer.pad_token_id,
**model_kwargs,
)
if pretrained_vision_tokenizer is not None:
assert os.path.exists(pretrained_vision_tokenizer), "pretrained weight must exist."
vis_tok_weight = torch.load(pretrained_vision_tokenizer)
model.vision_tokenizer.load_state_dict(vis_tok_weight, strict=True)
# add special tokens to the tokenizer and language models
text_tokenizer.add_special_tokens(
{"additional_special_tokens": list(model.special_tokens.values())}
)
model.lang_model.config.vocab_size = len(text_tokenizer)
model.set_special_token_ids(
{
v: text_tokenizer.convert_tokens_to_ids(v)
for v in model.special_tokens.values()
}
)
# freeze appropriate parameters
model.set_trainable()
# log model info
if verbose:
print(
f"{model_family} model initialized with {model.num_trainable_params:,} trainable parameters"
)
print(f"==========Trainable Parameters\n{model.num_trainable_params_per_module}")
print(f"==========Total Parameters\n{model.num_params_per_module}\n==========")
return model, image_processor, text_tokenizer
def _infer_decoder_layers_attr_name(model):
"""
Infer the name of the attribute storing the decoder layers (as a ModuleList) in the model.
"""
for k in __KNOWN_DECODER_LAYERS_ATTR_NAMES:
if k.lower() in model.__class__.__name__.lower():
return __KNOWN_DECODER_LAYERS_ATTR_NAMES[k]
raise ValueError(
f"We require the attribute name for the nn.ModuleList in the decoder storing the transformer block layers. Please supply this string manually."
)
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
"opt": "model.decoder.layers",
"gptj": "transformer.h",
"gpt-j": "transformer.h",
"pythia": "gpt_neox.layers",
"llama": "model.layers",
"gptneoxforcausallm": "gpt_neox.layers",
"mpt": "transformer.blocks",
"mosaicgpt": "transformer.blocks",
"gemma": "model.layers",
"phi": "model.layers",
"minicpm": "model.layers",
"stablelm": "model.layers",
"qwen": "model.layers",
"mistral": "model.layers"
}
def check_embedding_fns(lang_model):
"""Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
if not has_fn(lang_model, "get_input_embeddings"):
if hasattr_recursive(lang_model, "transformer.wte"): # MPT
lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
            lang_model.get_input_embeddings = lambda: lang_model.model.decoder.embed_tokens
else:
raise ValueError(
"We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
)
if not has_fn(lang_model, "set_input_embeddings"):
if hasattr_recursive(lang_model, "transformer.wte"): # MPT
lang_model.set_input_embeddings = lambda x: setattr_recursive(
lang_model, "transformer.wte", x
)
elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
lang_model.set_input_embeddings = lambda x: setattr_recursive(
lang_model, "model.decoder.embed_tokens", x
)
else:
raise ValueError(
"We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
)
if not has_fn(lang_model, "get_output_embeddings"):
if hasattr_recursive(lang_model, "lm_head"):
lang_model.get_output_embeddings = lambda: lang_model.lm_head
else:
raise ValueError(
"We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
)
if not has_fn(lang_model, "set_output_embeddings"):
if hasattr_recursive(lang_model, "lm_head"):
lang_model.set_output_embeddings = lambda x: setattr_recursive(
lang_model, "lm_head", x
)
else:
raise ValueError(
"We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
)
def has_fn(model, fn_name):
"""Check if model has a function fn_name"""
return callable(getattr(model, fn_name, None))