Commit 58d33d4c authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
Pipeline #1904 canceled with stages
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import shutil
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
from transformers.models.clip.image_processing_clip import CLIPImageProcessor
import torch
from mplug_docowl.model import *
from icecream import ic
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
kwargs = {"device_map": device_map}
if device != "cuda":
kwargs['device_map'] = {"": device}
if load_8bit:
kwargs['load_in_8bit'] = True
elif load_4bit:
kwargs['load_in_4bit'] = True
kwargs['quantization_config'] = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type='nf4'
)
else:
kwargs['torch_dtype'] = torch.float16
if 'paperowl' or 'docowl' in model_name.lower():
if model_base is not None:
# this may be mm projector only
print('Loading mPLUG-DocOwl from base model...')
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
model = MPLUGDocOwlLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = MPLUGDocOwlLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
else:
# Load language model
if model_base is not None:
# PEFT model
from peft import PeftModel
tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
print(f"Loading LoRA weights from {model_path}")
model = PeftModel.from_pretrained(model, model_path)
print(f"Merging weights")
model = model.merge_and_unload()
print('Convert to FP16...')
model.to(torch.float16)
else:
use_fast = False
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
# vision_tower = model.get_model().vision_model
# vision_tower.to(device=device, dtype=torch.float16)
# image_processor = CLIPImageProcessor.from_pretrained(model_path)
image_processor = None
if hasattr(model.config, "max_sequence_length"):
context_len = model.config.max_sequence_length
else:
context_len = 2048
return tokenizer, model, image_processor, context_len
\ No newline at end of file
# Copyright (c) Alibaba.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import copy
import os
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import logging
from transformers.models.auto import CONFIG_MAPPING
class LlamaConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the LLaMA-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32000):
Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`LlamaModel`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
Llama 2 up to 4096, CodeLlama up to 16384.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
pretraining_tp (`int`, *optional*, defaults to 1):
Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
issue](https://github.com/pytorch/pytorch/issues/76232).
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
```python
>>> from transformers import LlamaModel, LlamaConfig
>>> # Initializing a LLaMA llama-7b style configuration
>>> configuration = LlamaConfig()
>>> # Initializing a model from the llama-7b style configuration
>>> model = LlamaModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "llama"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.attention_bias = attention_bias
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
class MplugOwlVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`MplugOwlVisionModel`]. It is used to instantiate
a
mPLUG-Owl vision encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration defaults will yield a similar configuration to that of the mPLUG-Owl
[x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b) architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_size (`int`, *optional*, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
intermediate_size (`int`, *optional*, defaults to 3072):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
num_hidden_layers (`int`, *optional*, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
patch_size (`int`, *optional*, defaults to 32):
The size (resolution) of each patch.
hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
```"""
model_type = "mplug_owl_vision_model"
def __init__(
self,
hidden_size=1024,
intermediate_size=4096,
projection_dim=768,
num_hidden_layers=24,
num_attention_heads=16,
num_channels=3,
image_size=448,
patch_size=14,
hidden_act="quick_gelu",
layer_norm_eps=1e-6,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
use_flash_attn=False,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.projection_dim = projection_dim
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.use_flash_attn = use_flash_attn
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the vision config dict if we are loading from MplugOwlConfig
if config_dict.get("model_type") == "mplug-owl":
config_dict = config_dict["vision_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
class MplugDocOwlHReducerConfig(PretrainedConfig):
model_type = "mplug_docowl_hreducer"
def __init__(
self,
hidden_size=1024,
initializer_range=0.02,
layer_norm_eps=1e-6,
conv_shape='1x4',
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.conv_shape = conv_shape
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
# get the visual_abstractor config dict if we are loading from MplugOwlConfig
if config_dict.get("model_type") == "mplug-docowl":
config_dict = config_dict["hreducer_config"]
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
DEFAULT_VISUAL_CONFIG = {
"visual_model": MplugOwlVisionConfig().to_dict(),
"visual_hreducer": MplugDocOwlHReducerConfig().to_dict()
}
class MPLUGDocOwlConfig(LlamaConfig):
model_type = "mplug_docowl"
def __init__(self, visual_config=None, **kwargs):
if visual_config is None:
self.visual_config = DEFAULT_VISUAL_CONFIG
else:
self.visual_config = visual_config
super().__init__(
**kwargs,
)
if __name__ == "__main__":
print(MplugOwlVisionConfig().to_dict())
\ No newline at end of file
# Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import math
import os
import shutil
import warnings
import torch
from transformers import LlamaTokenizer
from .configuration_mplug_docowl import MPLUGDocOwlConfig
from icecream import ic
try:
from transformers import LlamaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
LlamaTokenizerFast = None
"""
Sample usage:
```
python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \
--input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B
```
Thereafter, models can be loaded via:
```py
from transformers import LlamaForCausalLM, LlamaTokenizer
model = LlamaForCausalLM.from_pretrained("/output/path")
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
"""
llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
70: 28672} # should be (2/3)*4*d, but it isn't exaclty that
llama_s2hidden = {7: 4096, 13: 5120, 32: 6656, 65: 8192, 70: 8192}
def compute_intermediate_size(n):
return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(model_path,
input_base_path,
model_size,
num_input_shards=1,
num_output_shards=2,
skip_permute=True,
norm_eps=1e-05):
# if os.path.exists(model_path):
# shutil.rmtree(model_path)
os.makedirs(model_path, exist_ok=True)
# tmp_model_path = os.path.join(model_path, "tmp")
tmp_model_path = model_path
os.makedirs(tmp_model_path, exist_ok=True)
num_shards = num_input_shards
n_layers = llama_s2layer[model_size]
n_heads = llama_s2heads[model_size]
n_heads_per_shard = n_heads // num_shards
n_dense = llama_s2dense[model_size]
n_hidden = llama_s2hidden[model_size]
hidden_per_head = n_hidden // n_heads
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))
# permute for sliced rotary
def permute(w, skip_permute=skip_permute):
if skip_permute:
return w
return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden)
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
# Load weights
if num_shards==1:
# Not sharded
# (The sharded implementation would also work, but this is simpler.)
# /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt
if os.path.exists(os.path.join(input_base_path, 'release')):
filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt')
elif input_base_path.split('/')[-1].startswith('iter_'):
iteration = eval(input_base_path.split('/')[-1].replace('iter_', '').lstrip('0'))
load_dir = '/'.join(input_base_path.split('/')[:-1])
filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt')
if not os.path.exists(filename):
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
else:
tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
iteration = 'iter_{:07d}'.format(int(metastring))
filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt')
if not os.path.exists(filename):
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
original_filename = filename
loaded = torch.load(filename, map_location="cpu")['model']['language_model']
else:
# Sharded
filenames = []
for i in range(num_shards):
if os.path.exists(os.path.join(input_base_path, 'release')):
filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt')
else:
tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
iteration = 'iter_{:07d}'.format(int(metastring))
filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt')
if not os.path.exists(filename):
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
filenames.append(filename)
loaded = [
torch.load(filenames[i], map_location="cpu")['model']['language_model']
for i in range(num_shards)
]
print('Llama-Megatron Loaded!')
param_count = 0
index_dict = {"weight_map": {}}
print(f'Weighted Converting for {n_layers} layers...')
for layer_i in range(n_layers):
print(layer_i)
filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
if num_shards == 1:
# Unsharded
state_dict = {
f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"],
f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"],
f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"],
f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"],
f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"],
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"],
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"],
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"],
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"],
f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"],
f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"],
f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"],
f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"],
}
else:
raise NotImplemented
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
print(f'Sharded file saved to {filename}')
filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
if num_shards==1:
# Unsharded
state_dict = {
"model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'],
"model.norm.weight": loaded['encoder']['norm.weight'],
"lm_head.weight": loaded['encoder']['lm_head.weight'],
}
else:
state_dict = {
"model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'],
"model.norm.weight": loaded[0]['encoder']['norm.weight'],
"lm_head.weight": loaded[0]['encoder']['lm_head.weight'],
}
loaded_all = torch.load(original_filename, map_location="cpu")['model']
# Vision Part
state_dict.update({
"model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'],
"model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'],
"model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'],
"model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'],
"model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'],
"model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'],
"model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'],
})
for v_layer_idx in range(24):
state_dict.update({
f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'],
})
# Vision2Text Part: HReducer
state_dict.update({
"model.vision2text.ln_q.weight": loaded_all['hreducer3']['ln_q']['weight'],
"model.vision2text.ln_q.bias": loaded_all['hreducer3']['ln_q']['bias'],
"model.vision2text.visual_fc.bias": loaded_all['hreducer3']['visual_fc']['bias'],
"model.vision2text.visual_fc.weight": loaded_all['hreducer3']['visual_fc']['weight'],
"model.vision2text.vit_eos": loaded_all['hreducer3']['vit_eos'],
})
# reducer_before conv (layer 0) + gleu (layer 1)
state_dict.update({
f"model.vision2text.reducer_before.0.weight": loaded_all['hreducer3']['reducer_before']["0.weight"],
f"model.vision2text.reducer_before.0.bias": loaded_all['hreducer3']['reducer_before']["0.bias"],
})
# reducer conv
state_dict.update({
f"model.vision2text.reducer.weight": loaded_all['hreducer3']['reducer']["weight"],
f"model.vision2text.reducer.bias": loaded_all['hreducer3']['reducer']["bias"],
})
for k, v in state_dict.items():
# ic(k, v)
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
# Write configs
index_dict["metadata"] = {"total_size": param_count * 2}
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
config = MPLUGDocOwlConfig()
config.save_pretrained(tmp_model_path)
# Make space so we can load the model properly now.
del state_dict
del loaded
del loaded_all
gc.collect()
def write_tokenizer(tokenizer_path, input_tokenizer_path):
# Initialize the tokenizer based on the `spm` model
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
tokenizer = tokenizer_class(input_tokenizer_path)
tokenizer.save_pretrained(tokenizer_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of LLaMA_Megatron weights",
)
parser.add_argument(
"--model_size",
type=int,
default=7,
choices=[7, 13, 30, 65, 70],
)
parser.add_argument(
"--num_input_shards",
type=int,
default=1,
)
parser.add_argument(
"--num_output_shards",
type=int,
default=1,
)
parser.add_argument('--skip_permute', action='store_true')
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
model_size=args.model_size,
num_input_shards=args.num_input_shards,
num_output_shards=args.num_output_shards,
skip_permute=args.skip_permute
)
if __name__ == "__main__":
main()
# Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import gc
import json
import math
import os
import shutil
import warnings
import torch
from transformers import LlamaTokenizer
from .configuration_mplug_docowl import MPLUGDocOwlConfig
from icecream import ic
try:
from transformers import LlamaTokenizerFast
except ImportError as e:
warnings.warn(e)
warnings.warn(
"The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
)
LlamaTokenizerFast = None
"""
Sample usage:
```
python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \
--input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B
```
Thereafter, models can be loaded via:
```py
from transformers import LlamaForCausalLM, LlamaTokenizer
model = LlamaForCausalLM.from_pretrained("/output/path")
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
"""
llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
70: 28672} # should be (2/3)*4*d, but it isn't exaclty that
llama_s2hidden = {7: 4096, 13: 5120, 32: 6656, 65: 8192, 70: 8192}
def compute_intermediate_size(n):
return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
def read_json(path):
with open(path, "r") as f:
return json.load(f)
def write_json(text, path):
with open(path, "w") as f:
json.dump(text, f)
def write_model(model_path,
input_base_path,
model_size,
num_input_shards=1,
num_output_shards=2,
skip_permute=True,
norm_eps=1e-05):
# if os.path.exists(model_path):
# shutil.rmtree(model_path)
os.makedirs(model_path, exist_ok=True)
# tmp_model_path = os.path.join(model_path, "tmp")
tmp_model_path = model_path
os.makedirs(tmp_model_path, exist_ok=True)
num_shards = num_input_shards
n_layers = llama_s2layer[model_size]
n_heads = llama_s2heads[model_size]
n_heads_per_shard = n_heads // num_shards
n_dense = llama_s2dense[model_size]
n_hidden = llama_s2hidden[model_size]
hidden_per_head = n_hidden // n_heads
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))
# permute for sliced rotary
def permute(w, skip_permute=skip_permute):
if skip_permute:
return w
return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden)
print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
# Load weights
if num_shards==1:
# Not sharded
# (The sharded implementation would also work, but this is simpler.)
# /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt
if os.path.exists(os.path.join(input_base_path, 'release')):
filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt')
elif input_base_path.split('/')[-1].startswith('iter_'):
iteration = eval(input_base_path.split('/')[-1].replace('iter_', '').lstrip('0'))
load_dir = '/'.join(input_base_path.split('/')[:-1])
filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt')
if not os.path.exists(filename):
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
else:
tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
iteration = 'iter_{:07d}'.format(int(metastring))
filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt')
if not os.path.exists(filename):
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
original_filename = filename
loaded = torch.load(filename, map_location="cpu")['model']['language_model']
else:
# Sharded
filenames = []
for i in range(num_shards):
if os.path.exists(os.path.join(input_base_path, 'release')):
filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt')
else:
tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
iteration = 'iter_{:07d}'.format(int(metastring))
filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt')
if not os.path.exists(filename):
filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
filenames.append(filename)
loaded = [
torch.load(filenames[i], map_location="cpu")['model']['language_model']
for i in range(num_shards)
]
print('Llama-Megatron Loaded!')
param_count = 0
index_dict = {"weight_map": {}}
state_dict = {}
print(f'Weighted Converting for {n_layers} layers...')
for layer_i in range(n_layers):
print(layer_i)
# filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
if num_shards == 1:
# Unsharded
state_dict.update({
f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"],
f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"],
f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"],
f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"],
f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"],
f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"],
f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"],
f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"],
f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"],
f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"],
f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"],
f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"],
f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"],
})
else:
raise NotImplemented
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
for k, v in state_dict.items():
index_dict["weight_map"][k] = filename
param_count += v.numel()
# torch.save(state_dict, os.path.join(tmp_model_path, filename))
# print(f'Sharded file saved to {filename}')
# filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
filename = "pytorch_model.bin"
if num_shards==1:
# Unsharded
state_dict.update({
"model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'],
"model.norm.weight": loaded['encoder']['norm.weight'],
"lm_head.weight": loaded['encoder']['lm_head.weight'],
})
else:
state_dict.update({
"model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'],
"model.norm.weight": loaded[0]['encoder']['norm.weight'],
"lm_head.weight": loaded[0]['encoder']['lm_head.weight'],
})
loaded_all = torch.load(original_filename, map_location="cpu")['model']
# Vision Part
state_dict.update({
"model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'],
"model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'],
"model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'],
"model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'],
"model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'],
"model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'],
"model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'],
})
for v_layer_idx in range(24):
state_dict.update({
f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'],
f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'],
})
# Vision2Text Part: HReducer
state_dict.update({
"model.vision2text.ln_q.weight": loaded_all['hreducer3']['ln_q']['weight'],
"model.vision2text.ln_q.bias": loaded_all['hreducer3']['ln_q']['bias'],
"model.vision2text.visual_fc.bias": loaded_all['hreducer3']['visual_fc']['bias'],
"model.vision2text.visual_fc.weight": loaded_all['hreducer3']['visual_fc']['weight'],
"model.vision2text.vit_eos": loaded_all['hreducer3']['vit_eos'],
})
# reducer_before conv (layer 0) + gleu (layer 1)
state_dict.update({
f"model.vision2text.reducer_before.0.weight": loaded_all['hreducer3']['reducer_before']["0.weight"],
f"model.vision2text.reducer_before.0.bias": loaded_all['hreducer3']['reducer_before']["0.bias"],
})
# reducer conv
state_dict.update({
f"model.vision2text.reducer.weight": loaded_all['hreducer3']['reducer']["weight"],
f"model.vision2text.reducer.bias": loaded_all['hreducer3']['reducer']["bias"],
})
for k, v in state_dict.items():
# ic(k, v)
index_dict["weight_map"][k] = filename
param_count += v.numel()
torch.save(state_dict, os.path.join(tmp_model_path, filename))
print(f'save to {os.path.join(tmp_model_path, filename)}')
# Write configs
index_dict["metadata"] = {"total_size": param_count * 2}
write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
config = MPLUGDocOwlConfig()
config.save_pretrained(tmp_model_path)
# Make space so we can load the model properly now.
del state_dict
del loaded
del loaded_all
gc.collect()
def write_tokenizer(tokenizer_path, input_tokenizer_path):
# Initialize the tokenizer based on the `spm` model
tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
tokenizer = tokenizer_class(input_tokenizer_path)
tokenizer.save_pretrained(tokenizer_path)
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir",
help="Location of LLaMA_Megatron weights",
)
parser.add_argument(
"--model_size",
type=int,
default=7,
choices=[7, 13, 30, 65, 70],
)
parser.add_argument(
"--num_input_shards",
type=int,
default=1,
)
parser.add_argument(
"--num_output_shards",
type=int,
default=1,
)
parser.add_argument('--skip_permute', action='store_true')
parser.add_argument(
"--output_dir",
help="Location to write HF model and tokenizer",
)
args = parser.parse_args()
write_model(
model_path=args.output_dir,
input_base_path=args.input_dir,
model_size=args.model_size,
num_input_shards=args.num_input_shards,
num_output_shards=args.num_output_shards,
skip_permute=args.skip_permute
)
if __name__ == "__main__":
main()
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple, Union
import torch
class AttentionMaskConverter:
"""
A utility attention mask class that allows one to:
- Create a causal 4d mask
- Create a causal 4d mask with slided window
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
key_value_length) that can be multiplied with attention scores
Parameters:
is_causal (`bool`):
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
sliding_window (`int`, *optional*):
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
"""
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
self.is_causal = is_causal
self.sliding_window = sliding_window
if self.sliding_window is not None and self.sliding_window <= 0:
raise ValueError(
f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
)
def to_causal_4d(
self,
batch_size: int,
query_length: int,
key_value_length: int,
dtype: torch.dtype = torch.float32,
device: Union[torch.device, "str"] = "cpu",
) -> torch.Tensor:
"""
Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
bias to upper right hand triangular matrix (causal mask).
"""
if not self.is_causal:
raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
# If shape is not cached, create a new causal mask and cache it
input_shape = (batch_size, query_length)
past_key_values_length = key_value_length - query_length
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if input_shape[-1] > 1 or self.sliding_window is not None:
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
return causal_4d_mask
def to_4d(
self,
attention_mask_2d: torch.Tensor,
query_length: int,
key_value_length: Optional[int] = None,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
causal, a causal mask will be added.
"""
input_shape = (attention_mask_2d.shape[0], query_length)
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
causal_4d_mask = None
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
if key_value_length is None:
raise ValueError(
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
)
past_key_values_length = key_value_length - query_length
causal_4d_mask = self._make_causal_mask(
input_shape,
dtype,
device=attention_mask_2d.device,
past_key_values_length=past_key_values_length,
sliding_window=self.sliding_window,
)
elif self.sliding_window is not None:
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
attention_mask_2d.device
)
expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
return expanded_4d_mask
@staticmethod
def _make_causal_mask(
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
# add lower triangular sliding window mask if necessary
if sliding_window is not None:
diagonal = past_key_values_length - sliding_window + 1
context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
@staticmethod
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def _prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
input_shape: Union[torch.Size, Tuple, List],
inputs_embeds: torch.Tensor,
past_key_values_length: int,
sliding_window: Optional[int] = None,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
attention_mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
The input shape should be a tuple that defines `(batch_size, query_length)`.
inputs_embeds (`torch.Tensor`):
The embedded inputs as a torch Tensor.
past_key_values_length (`int`):
The length of the key value cache.
sliding_window (`int`, *optional*):
If the model uses windowed attention, a sliding window should be passed.
"""
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
key_value_length = input_shape[-1] + past_key_values_length
# 4d mask is passed through the layers
if attention_mask is not None:
attention_mask = attn_mask_converter.to_4d(
attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
)
else:
attention_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
)
return attention_mask
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`
Args:
mask (`torch.Tensor` or `None`):
A 2D attention mask of shape `(batch_size, key_value_length)`
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
tgt_len (`int`):
The target length or query length the created mask shall have.
"""
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
def _create_4d_causal_attention_mask(
input_shape: Union[torch.Size, Tuple, List],
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: Optional[int] = None,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
Args:
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
The input shape should be a tuple that defines `(batch_size, query_length)`.
dtype (`torch.dtype`):
The torch dtype the created mask shall have.
device (`int`):
The torch device the created mask shall have.
sliding_window (`int`, *optional*):
If the model uses windowed attention, a sliding window should be passed.
"""
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
key_value_length = past_key_values_length + input_shape[-1]
attention_mask = attn_mask_converter.to_causal_4d(
input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
)
return attention_mask
\ No newline at end of file
import math
import warnings
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import transformers
from transformers.models.llama.modeling_llama import *
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from .configuration_mplug_docowl import LlamaConfig
class MultiwayNetwork(nn.Module):
def __init__(self, module_provider, num_multiway=2):
super(MultiwayNetwork, self).__init__()
self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
def forward(self, hidden_states, multiway_indices):
if len(self.multiway) == 1:
return self.multiway[0](hidden_states)
output_hidden_states = torch.empty_like(hidden_states)
for idx, subway in enumerate(self.multiway):
local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
if hidden.numel():
output = subway(hidden)
if isinstance(output, tuple):
output = output[0]
output = output.squeeze(1)
output_hidden_states[local_indices] = output
return output_hidden_states.contiguous()
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = MultiwayNetwork(module_provider=partial(
nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
)
self.v_proj = MultiwayNetwork(module_provider=partial(
nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
modality_indicators: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states, )
key_states = self.k_proj(hidden_states, modality_indicators)
value_states = self.v_proj(hidden_states, modality_indicators)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaAttention(config=config)
self.mlp = LlamaMLP(config)
self.input_layernorm = MultiwayNetwork(module_provider=partial(
LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
))
self.post_attention_layernorm = MultiwayNetwork(module_provider=partial(
LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
))
def forward(
self,
hidden_states: torch.Tensor,
modality_indicators: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states, modality_indicators)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
modality_indicators=modality_indicators,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
def model_forward(
self,
input_ids: torch.LongTensor = None,
modality_indicators: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
)
hidden_states = inputs_embeds
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
modality_indicators,
attention_mask,
position_ids,
)
else:
layer_outputs = decoder_layer(
hidden_states,
modality_indicators=modality_indicators,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
def causal_model_forward(
self,
input_ids: torch.LongTensor = None,
modality_indicators: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
modality_indicators=modality_indicators,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def replace_llama_modality_adaptive():
transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward
if __name__ == "__main__":
replace_llama_modality_adaptive()
config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/')
model = transformers.LlamaForCausalLM(config)
print(model)
\ No newline at end of file
# Copyright 2023 Haotian Liu & Qinghao Ye (Modified from LLaVA)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoConfig, AutoModelForCausalLM, LlamaConfig, LlamaModel, LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_mplug_docowl import (MPLUGDocOwlConfig, MplugOwlVisionConfig, MplugDocOwlHReducerConfig)
from .visual_encoder import MplugOwlVisionModel, MplugDocOwlHReducerModel
from .modeling_llama2 import replace_llama_modality_adaptive
from mplug_docowl.constants import IMAGE_TOKEN_INDEX, IGNORE_INDEX
from icecream import ic
class MPLUGDocOwlMetaModel:
def __init__(self, config):
super(MPLUGDocOwlMetaModel, self).__init__(config)
self.vision_model = MplugOwlVisionModel(
MplugOwlVisionConfig(**config.visual_config["visual_model"])
)
self.vision2text = MplugDocOwlHReducerModel(
MplugDocOwlHReducerConfig(**config.visual_config["visual_hreducer"]), config.hidden_size
)
def get_vision_tower(self):
vision_model = getattr(self, 'vision_model', None)
if type(vision_model) is list:
vision_model = vision_model[0]
return vision_model
def get_vision2text(self):
vision2text = getattr(self, 'vision2text', None)
if type(vision2text) is list:
vision2text = vision2text[0]
return vision2text
class MPLUGDocOwlMetaForCausalLM(ABC):
@abstractmethod
def get_model(self):
pass
def encode_images(self, images, patch_positions):
image_features = self.get_model().vision_model(images).last_hidden_state
image_features = self.get_model().vision2text(encoder_hidden_states=image_features)
return image_features
def prepare_inputs_labels_for_multimodal(
self, input_ids, attention_mask, past_key_values, labels, images, patch_positions
):
if images is None or input_ids.shape[1] == 1:
if past_key_values is not None and images is not None and input_ids.shape[1] == 1:
attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device)
multiway_indices = torch.zeros_like(input_ids).long().to(self.device)
return input_ids, multiway_indices, attention_mask, past_key_values, None, labels
if type(images) is list or images.ndim == 5:
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images, patch_positions)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
image_features = [x.flatten(0, 1) for x in image_features]
else:
image_features = self.encode_images(images, patch_positions) # Sum(Crop Image Number) x L x d
new_input_embeds = []
new_modality_indicators = []
new_labels = [] if labels is not None else None
cur_image_idx = 0
for batch_idx, cur_input_ids in enumerate(input_ids):
if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0:
# multimodal LLM, but the current sample is not multimodal
# FIXME: this is a hacky fix, for deepspeed zero3 to work
half_len = cur_input_ids.shape[0] // 2
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len])
cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:])
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0)
new_input_embeds.append(cur_input_embeds)
cur_modality_indicators = torch.zeros(len(cur_input_embeds)).long().to(self.device)
new_modality_indicators.append(cur_modality_indicators)
if labels is not None:
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
cur_new_input_embeds = []
cur_modality_indicators = []
if labels is not None:
cur_labels = labels[batch_idx]
cur_new_labels = []
assert cur_labels.shape == cur_input_ids.shape
while image_token_indices.numel() > 0:
cur_image_features = image_features[cur_image_idx]
image_token_start = image_token_indices[0]
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start]))
cur_new_input_embeds.append(cur_image_features)
# Add modality indicator
assert image_token_start == len(cur_input_ids[:image_token_start])
cur_modality_indicators.append(torch.zeros(len(cur_input_ids[:image_token_start])).long())
cur_modality_indicators.append(torch.ones(len(cur_image_features)).long())
if labels is not None:
cur_new_labels.append(cur_labels[:image_token_start])
cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype))
cur_labels = cur_labels[image_token_start+1:]
cur_image_idx += 1
cur_input_ids = cur_input_ids[image_token_start+1:]
image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0]
if cur_input_ids.numel() > 0:
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
cur_modality_indicators.append(torch.zeros(len(cur_input_ids)).long())
if labels is not None:
cur_new_labels.append(cur_labels)
cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds]
cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0)
new_input_embeds.append(cur_new_input_embeds)
# Modality
cur_modality_indicators = [x.to(device=self.device) for x in cur_modality_indicators]
cur_modality_indicators = torch.cat(cur_modality_indicators, dim=0)
new_modality_indicators.append(cur_modality_indicators)
if labels is not None:
cur_new_labels = torch.cat(cur_new_labels, dim=0)
new_labels.append(cur_new_labels)
if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds):
max_len = max(x.shape[0] for x in new_input_embeds)
# Embedding
new_input_embeds_align = []
for cur_new_embed in new_input_embeds:
cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0)
new_input_embeds_align.append(cur_new_embed)
new_input_embeds = torch.stack(new_input_embeds_align, dim=0)
# Modality
new_modality_indicators_align = []
for cur_modality_indicator in new_modality_indicators:
cur_new_embed = torch.cat((cur_modality_indicator, torch.zeros(max_len - cur_modality_indicator.shape[0], dtype=cur_modality_indicator.dtype, device=cur_modality_indicator.device)), dim=0)
new_modality_indicators_align.append(cur_new_embed)
new_modality_indicators = torch.stack(new_modality_indicators_align, dim=0)
# Label
if labels is not None:
new_labels_align = []
_new_labels = new_labels
for cur_new_label in new_labels:
cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0)
new_labels_align.append(cur_new_label)
new_labels = torch.stack(new_labels_align, dim=0)
# Attention Mask
if attention_mask is not None:
new_attention_mask = []
for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels):
new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device)
new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device)
cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0)
new_attention_mask.append(cur_new_attention_mask)
attention_mask = torch.stack(new_attention_mask, dim=0)
assert attention_mask.shape == new_labels.shape
else:
new_input_embeds = torch.stack(new_input_embeds, dim=0)
new_modality_indicators = torch.stack(new_modality_indicators, dim=0)
if labels is not None:
new_labels = torch.stack(new_labels, dim=0)
if attention_mask is not None:
new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device)
attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1)
assert attention_mask.shape == new_input_embeds.shape[:2]
return None, new_modality_indicators, attention_mask, past_key_values, new_input_embeds, new_labels
class MPLUGDocOwlLlamaModel(MPLUGDocOwlMetaModel, LlamaModel):
config_class = MPLUGDocOwlConfig
def __init__(self, config: MPLUGDocOwlConfig):
super(MPLUGDocOwlLlamaModel, self).__init__(config)
class MPLUGDocOwlLlamaForCausalLM(LlamaForCausalLM, MPLUGDocOwlMetaForCausalLM):
config_class = MPLUGDocOwlConfig
def __init__(self, config):
super(LlamaForCausalLM, self).__init__(config)
self.model = MPLUGDocOwlLlamaModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_model(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
# modality_indicators: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
images: Optional[torch.FloatTensor] = None,
patch_positions: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
# print('modeling_mplug_docow2.py patch_positions:', patch_positions)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_ids, modality_indicators, attention_mask, past_key_values, inputs_embeds, labels = \
self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images, patch_positions)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
modality_indicators=modality_indicators,
attention_mask=attention_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model/pipeline parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values:
input_ids = input_ids[:, -1:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"images": kwargs.get("images", None),
"patch_positions": kwargs.get("patch_positions", None),
}
)
return model_inputs
AutoConfig.register("mplug_docowl", MPLUGDocOwlConfig)
AutoModelForCausalLM.register(MPLUGDocOwlConfig, MPLUGDocOwlLlamaForCausalLM)
replace_llama_modality_adaptive()
\ No newline at end of file
from transformers import AutoConfig
def auto_upgrade(config):
cfg = AutoConfig.from_pretrained(config)
if 'mplug_owl2' in config and 'mplug_owl2' not in cfg.model_type:
assert cfg.model_type == 'mplug_owl2'
print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
if confirm.lower() in ["y", "yes"]:
print("Upgrading checkpoint...")
assert len(cfg.architectures) == 1
setattr(cfg.__class__, "model_type", "mplug_owl2")
cfg.architectures[0] = 'LlavaLlamaForCausalLM'
cfg.save_pretrained(config)
print("Checkpoint upgraded.")
else:
print("Checkpoint upgrade aborted.")
exit(1)
\ No newline at end of file
import math
from typing import Any, Optional, Tuple, Union
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, BaseModelOutputWithPastAndCrossAttentions
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
import numpy as np
import torch
import torch.nn as nn
import torch.utils.checkpoint
from icecream import ic
import einops
from einops import rearrange
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class MplugOwlVisionEmbeddings(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_size))
self.patch_embed = nn.Conv2d(
in_channels=3,
out_channels=self.hidden_size,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.position_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.hidden_size))
self.pre_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.size(0)
image_embeds = self.patch_embed(pixel_values)
image_embeds = image_embeds.flatten(2).transpose(1, 2)
class_embeds = self.cls_token.expand(batch_size, 1, -1).to(image_embeds.dtype)
embeddings = torch.cat([class_embeds, image_embeds], dim=1)
embeddings = embeddings + self.position_embedding[:, : embeddings.size(1)].to(image_embeds.dtype)
embeddings = self.pre_layernorm(embeddings)
return embeddings
class MplugOwlVisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
if self.head_dim * self.num_heads != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = nn.Dropout(config.attention_dropout)
self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size)
self.dense = nn.Linear(self.hidden_size, self.hidden_size)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
head_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
bsz, seq_len, embed_dim = hidden_states.size()
mixed_qkv = self.query_key_value(hidden_states)
mixed_qkv = mixed_qkv.reshape(bsz, seq_len, self.num_heads, 3, embed_dim // self.num_heads).permute(
3, 0, 2, 1, 4
) # [3, b, np, sq, hn]
query_states, key_states, value_states = (
mixed_qkv[0],
mixed_qkv[1],
mixed_qkv[2],
)
# if self.config.use_flash_attn and flash_attn_func is not None:
if False:
# [b*sq, np, hn]
query_states = query_states.permute(0, 2, 1, 3).contiguous()
query_states = query_states.view(query_states.size(0) * query_states.size(1), query_states.size(2), -1)
key_states = key_states.permute(0, 2, 1, 3).contiguous()
key_states = key_states.view(key_states.size(0) * key_states.size(1), key_states.size(2), -1)
value_states = value_states.permute(0, 2, 1, 3).contiguous()
value_states = value_states.view(value_states.size(0) * value_states.size(1), value_states.size(2), -1)
cu_seqlens = torch.arange(
0, (bsz + 1) * seq_len, step=seq_len, dtype=torch.int32, device=query_states.device
)
context_layer = flash_attn_func(
query_states,
key_states,
value_states,
cu_seqlens,
cu_seqlens,
seq_len,
seq_len,
self.dropout if self.training else 0.0,
softmax_scale=self.scale,
causal=False,
return_attn_probs=False,
)
# [b*sq, np, hn] => [b, sq, np, hn]
context_layer = context_layer.view(bsz, seq_len, context_layer.size(1), context_layer.size(2))
else:
# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
attention_scores = attention_scores * self.scale
# Normalize the attention scores to probabilities.
attention_probs = torch.softmax(attention_scores, dim=-1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)
# Mask heads if we want to
if head_mask is not None:
attention_probs = attention_probs * head_mask
context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
context_layer = context_layer.reshape(new_context_layer_shape)
output = self.dense(context_layer)
outputs = (output, attention_probs) if output_attentions else (output, None)
return outputs
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class MplugOwlMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = QuickGELU()
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class MplugOwlVisionEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MplugOwlVisionAttention(config)
self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
self.mlp = MplugOwlMLP(config)
self.post_attention_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
`(config.encoder_attention_heads,)`.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
head_mask=attention_mask,
output_attentions=output_attentions,
)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states + residual
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class MplugOwlVisionEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`MplugOwlVisionEncoderLayer`].
Args:
config (`MplugOwlVisionConfig`):
The corresponding vision configuration for the `MplugOwlEncoder`.
"""
def __init__(self, config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([MplugOwlVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = True
def forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
class MplugOwlVisionModel(PreTrainedModel):
main_input_name = "pixel_values"
def __init__(self, config):
super().__init__(config)
self.config = config
self.hidden_size = config.hidden_size
self.embeddings = MplugOwlVisionEmbeddings(config)
self.encoder = MplugOwlVisionEncoder(config)
self.post_layernorm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
self.post_init()
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
hidden_states = self.embeddings(pixel_values)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.post_layernorm(last_hidden_state)
pooled_output = last_hidden_state[:, 0, :]
pooled_output = self.post_layernorm(pooled_output)
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.embeddings
class MplugDocOwlHReducerModel(PreTrainedModel):
def __init__(self, config, language_hidden_size):
super().__init__(config)
self.config = config
self.ln_q = torch.nn.LayerNorm(self.config.hidden_size, eps=1e-6)
self.conv_shape = (int(self.config.conv_shape.split('x')[0]), int(self.config.conv_shape.split('x')[1])) #
self.conv_patch=self.conv_shape[0]*self.conv_shape[1]
## feature interaction with a conv layer
self.reducer_before = torch.nn.Sequential(
nn.Conv2d(self.config.hidden_size, self.conv_patch*self.config.hidden_size, kernel_size=self.conv_shape, stride=self.conv_shape, bias=True),
nn.GELU()
)
## reduce visual feature length with a conv layer
self.reducer = nn.Conv2d(self.config.hidden_size, self.config.hidden_size, kernel_size=self.conv_shape, stride=self.conv_shape, bias=True)
## align visual features with language embedding with fc
self.visual_fc = torch.nn.Linear(self.config.hidden_size, language_hidden_size)
self.vit_eos = torch.nn.Parameter(torch.randn(1, 1, language_hidden_size))
self.post_init()
def forward(
self,
encoder_hidden_states=None
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
batch_size is the number of all images (global+crop) in a batch
Sequence of hidden-states at the output of the last layer of the encoder.
"""
encoder_hidden_states = encoder_hidden_states[:,1:,:] # remove the first cls token
B, L, C = encoder_hidden_states.shape # B, 1024=(448/14)^2, 1024
## feature interaction with a conv layer
encoder_hidden_states = rearrange(encoder_hidden_states, 'B (H W) D -> B D H W', H=int(math.sqrt(L)))
hidden_states = self.reducer_before(encoder_hidden_states) # B 4D H W/4
## reduce seq length with a conv layer
hidden_states = rearrange(hidden_states, 'B (X D) H W -> B D H (W X)', X=self.conv_patch) # B 4D H W/4 -> B D H W
sequence_output = self.reducer(hidden_states) # B,C,H,W -> B,C,H/conv_shape[0],W/(conv_shape[1])
sequence_output = sequence_output.flatten(2).transpose(1, 2) # B,C,H/conv_shape[0],W/(conv_shape[1]) -> B,C,L/conv_patch -> B,L/conv_patch,C
sequence_output = sequence_output.transpose(0, 1).contiguous() # L/conv_patch, B, C
## align visual features with language embedding with fc
sequence_output = self.visual_fc(sequence_output) # L/conv_patch, B, h
sequence_output = sequence_output.transpose(0, 1).contiguous() # B, s/4, h
sequence_output = torch.cat([sequence_output, self.vit_eos.repeat(B, 1, 1)], dim=1)
return sequence_output
from einops import rearrange, repeat
import torch
from torchvision import transforms
from PIL import Image, ImageFile
import random
from torchvision.ops.boxes import box_area
from torchvision.transforms.transforms import InterpolationMode
from torchvision.transforms import functional as F
import numpy as np
from icecream import ic
ImageFile.LOAD_TRUNCATED_IMAGES = True
ImageFile.MAX_IMAGE_PIXELS = None
Image.MAX_IMAGE_PIXELS = None
def box_iou(boxes1, area1, boxes2, eps=1e-5):
area2 = box_area(boxes2)
lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
wh = (rb - lt).clamp(min=0) # [N,M,2]
inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
union = area1[:, None] + area2 - inter
iou = inter / (union+eps)
return iou, union
def anchor_rank(anchors, anchors_areas, input_image_size, eps=1e-5):
# anchors x1 y1 x2 y2
# image_size: (h, w)
# xyxy
input_image_bbox = torch.tensor([0, 0, input_image_size[1], input_image_size[0]]).unsqueeze(0)
boxes1 = anchors
boxes2 = input_image_bbox
boxes3 = anchors.clone()
# y2
boxes3[:,3] = input_image_size[0]/input_image_size[1]*anchors[:,2] # 用于算分辨率无关的iou
area1 = anchors_areas
iou, _ = box_iou(boxes1, area1, boxes2)
iou = iou.squeeze(1)
shape_iou, _ = box_iou(boxes1, area1, boxes3)
shape_iou = shape_iou.diag()
# 优先匹配形状接近 再匹配分辨率接近
index = torch.argmax(shape_iou*100+iou,dim=0)
return index
class AnchorResize(torch.nn.Module):
def __init__(self, image_size, anchors, interpolation=InterpolationMode.BILINEAR, antialias=None):
super().__init__()
# xyxy
self.anchors = torch.tensor(
[[0, 0, _[1]*image_size[1], _[0]*image_size[0]]
for _ in anchors], requires_grad=False
)
self.anchor_areas = box_area(self.anchors)
self.interpolation = interpolation
self.antialias = antialias
def forward(self, img, skip_resize=False):
"""
Args:
img (PIL Image or Tensor): Image to be scaled.
Returns:
PIL Image or Tensor: Rescaled image.
"""
selected_anchor = anchor_rank(self.anchors, self.anchor_areas, (img.size[1], img.size[0]))
target_size = self.anchors[selected_anchor][2:].tolist() # w,h
if skip_resize:
# for debug
return selected_anchor
return F.resize(img, [target_size[1],target_size[0]], self.interpolation, max_size=None, antialias=self.antialias), selected_anchor
def __repr__(self) -> str:
detail = f"(size={self.image_size}, anchor={self.anchors}, interpolation={self.interpolation.value}, antialias={self.antialias})"
return f"{self.__class__.__name__}{detail}"
grid_dict = {
'grid_1':[
(1,1)],
'grid_4':[
(1,1),
(1,2),(2,1),
(1,3),(3,1),
(2,2),(1,4),(4,1)],
'grid_9':[
(1,1),
(1,2),(2,1),
(1,3),(3,1),
(2,2),(1,4),(4,1),
(1,5),(5,1),
(1,6),(6,1),(2,3),(3,2),
(1,7),(7,1),
(4,2),(2,4),(1,8),(8,1),
(3,3),(1,9),(9,1)],
'grid_3x3':[
(3,3)],
'grid_20':[
(1, 1),
(1, 2), (2, 1),
(1, 3), (3, 1), (1, 4), (2, 2), (4, 1),
(1, 5), (5, 1),
(1, 6), (2, 3), (3, 2), (6, 1),
(1, 7), (7, 1),
(1, 8), (2, 4), (4, 2), (8, 1),
(1, 9), (3, 3), (9, 1),
(1, 10), (2, 5), (5, 2), (10, 1),
(1, 11), (11, 1),
(2, 6), (3, 4), (4, 3), (6, 2),
(2, 7), (7, 2),
(3, 5), (5, 3),
(2, 8), (4, 4), (8, 2),
(2, 9), (3, 6), (6, 3), (9, 2),
(2, 10), (4, 5), (5, 4), (10, 2)]
}
class DocProcessor():
def __init__(self, image_size=224, anchors='grid_9', add_global_img=True, add_textual_crop_indicator=False):
self.add_global_img = add_global_img
self.add_textual_crop_indicator = add_textual_crop_indicator
self.media_token= "<|image|>"
# h,w
if isinstance(image_size, int):
image_size = (image_size, image_size)
self.image_size = image_size
# h,w
anchors = grid_dict[anchors]
self.anchors = [tuple(_) for _ in anchors]
self.anchor_max = max([max(_) for _ in self.anchors])
# xywh -> xyxy
self.resizer = AnchorResize(image_size=image_size, anchors=anchors, interpolation=InterpolationMode.BICUBIC)
self.old_resizer = transforms.Resize(image_size,interpolation=InterpolationMode.BICUBIC)
self.image_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
def _process_image(self, images):
new_images = []
new_patch_position = []
num_image_mult = []
for image in images:
if self.add_global_img:
nocut_image = self.image_transform(self.old_resizer(image)).unsqueeze(0)
image, selected_anchor = self.resizer(image)
image_input = self.image_transform(image) # h,w,3 -> 3,h,w
# rearrange(x,'B C (n1 h) (n2 w) -> (B n1 n2) C h w', n1=self.down_sample[0], n2=self.down_sample[1])
image_input = rearrange(image_input, 'C (num_h h) (num_w w) -> (num_h num_w) C h w', h=self.image_size[0], w=self.image_size[1])
if self.add_global_img:
image_input = torch.cat([nocut_image, image_input], dim=0)
anchor = self.anchors[selected_anchor] # w,h
patch_position = torch.cat([
repeat(torch.arange(anchor[0]), 'num_h -> num_h num_w 1', num_w=anchor[1]),
repeat(torch.arange(anchor[1]), 'num_w -> num_h num_w 1', num_h=anchor[0])],dim=2)
patch_position = rearrange(patch_position, 'num_h num_w p-> (num_h num_w) p', p=2) # num_patch, (ph,pw)
if self.add_global_img:
patch_position = torch.cat([torch.ones(1,2).long()*self.anchor_max, patch_position], dim=0)
new_images.append(image_input)
new_patch_position.append(patch_position)
num_image_mult.append(patch_position.shape[0])
new_images = torch.cat(new_images,dim=0)
new_patch_position = torch.cat(new_patch_position, dim=0)
return new_images, new_patch_position, num_image_mult
def __call__(self, images=None, query=None):
assert images is not None
if not isinstance(images, list):
images = [images]
image_pils = []
for image in images:
if isinstance(image, str):
image = Image.open(image).convert('RGB')
else:
image = image.convert('RGB')
# ic(image.size)
image_pils.append(image)
image_data, patch_position, num_image_mult = self._process_image(image_pils)
assert self.media_token in query
text_list = query.split(self.media_token)
text = text_list[0]
image_token_ptr = 0
for next_text in text_list[1:]:
if self.add_textual_crop_indicator:
# generate image placeholders with interleaved texutual crop indicator
# e.g. <global_img><|image|><crop_img_row0_col0><|image|><crop_img_row0_col1><|image|>...
for patch_pos in patch_position.tolist():
# global non-crop image
if patch_pos[0] == self.anchor_max and patch_pos[1] == self.anchor_max:
text += '<global_img><|image|>'
else:
row_col = 'row'+str(patch_pos[0])+'_col'+str(patch_pos[1])
text += '<crop_img_'+row_col+'><|image|>'
else:
# generate successive image placeholders for a image, 1 crop img == 1 <|image|>
text += '<|image|>'*num_image_mult[image_token_ptr]
text += next_text
image_token_ptr += 1
return image_data, patch_position, text
\ No newline at end of file
from typing import Optional, Tuple
import warnings
import torch
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
except ImportError:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input
def forward(
self,
hidden_states: torch.Tensor,
modality_indicators: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: bool = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states, modality_indicators)
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states, modality_indicators)
.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
.transpose(1, 2)
) # shape: (b, num_heads, s, head_dim)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
if past_key_value is not None:
# reuse k, v
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Transform the data into the format required by flash attention
qkv = torch.stack([query_states, key_states, value_states], dim=2)
qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
cu_q_lens = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
)
max_s = q_len
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = output.view(bsz, q_len, -1)
else:
qkv = qkv.reshape(bsz, q_len, -1)
qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
output_unpad = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
output = pad_input(output_unpad, indices, bsz, q_len)
return self.o_proj(output), None, past_key_value
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
cuda_major, cuda_minor = torch.cuda.get_device_capability()
if cuda_major < 8:
warnings.warn(
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
)
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
\ No newline at end of file
import os
import torch
from torch.utils.data import Sampler
from transformers import Trainer
from transformers.trainer import (
is_sagemaker_mp_enabled,
get_parameter_names,
has_length,
ALL_LAYERNORM_LAYERS,
ShardedDDPOption,
logger,
)
from typing import List, Optional
from icecream import ic
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
print(name, 'no ignore status')
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
return to_return
def split_to_even_chunks(indices, lengths, num_chunks):
"""
Split a list of indices into `chunks` chunks of roughly equal lengths.
"""
if len(indices) % num_chunks != 0:
return [indices[i::num_chunks] for i in range(num_chunks)]
num_indices_per_chunk = len(indices) // num_chunks
chunks = [[] for _ in range(num_chunks)]
chunks_lengths = [0 for _ in range(num_chunks)]
for index in indices:
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
chunks[shortest_chunk].append(index)
chunks_lengths[shortest_chunk] += lengths[index]
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
chunks_lengths[shortest_chunk] = float("inf")
return chunks
def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
assert all(l != 0 for l in lengths), "Should not have zero length."
if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
# all samples are in the same modality
return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
megabatch_size = world_size * batch_size
mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
last_mm = mm_megabatches[-1]
last_lang = lang_megabatches[-1]
additional_batch = last_mm + last_lang
megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
megabatch_indices = torch.randperm(len(megabatches), generator=generator)
megabatches = [megabatches[i] for i in megabatch_indices]
if len(additional_batch) > 0:
megabatches.append(sorted(additional_batch))
return [i for megabatch in megabatches for i in megabatch]
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
megabatch_size = world_size * batch_size
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
return [i for megabatch in megabatches for batch in megabatch for i in batch]
class LengthGroupedSampler(Sampler):
r"""
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
keeping a bit of randomness.
"""
def __init__(
self,
batch_size: int,
world_size: int,
lengths: Optional[List[int]] = None,
generator=None,
group_by_modality: bool = False,
):
if lengths is None:
raise ValueError("Lengths must be provided.")
self.batch_size = batch_size
self.world_size = world_size
self.lengths = lengths
self.generator = generator
self.group_by_modality = group_by_modality
def __len__(self):
return len(self.lengths)
def __iter__(self):
if self.group_by_modality:
indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
else:
indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
return iter(indices)
class MPLUGDocOwlTrainer(Trainer):
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
if self.args.group_by_modality_length:
lengths = self.train_dataset.modality_lengths
return LengthGroupedSampler(
self.args.train_batch_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps,
lengths=lengths,
group_by_modality=True,
)
else:
return super()._get_train_sampler()
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
if is_sagemaker_mp_enabled():
return super().create_optimizer()
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
return super().create_optimizer()
opt_model = self.model
if self.optimizer is None:
decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]
if self.args.vision2text_lr is not None:
projector_parameters = [name for name, _ in opt_model.named_parameters() if "vision2text_lr" in name]
optimizer_grouped_parameters = [
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
"lr": self.args.vision2text_lr,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
],
"weight_decay": 0.0,
"lr": self.args.vision2text_lr,
},
]
else:
optimizer_grouped_parameters = [
{
"params": [
p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
ic(len(optimizer_grouped_parameters[0]['params']),len(optimizer_grouped_parameters[1]['params']))
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
**optimizer_kwargs,
)
else:
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(module, "weight", {"optim_bits": 32})
logger.debug(f"bitsandbytes: will optimize {module} in fp32")
logger.info(f"skipped: {skipped/2**20}M params")
return self.optimizer
def _save_checkpoint(self, model, trial, metrics=None):
super(MPLUGDocOwlTrainer, self)._save_checkpoint(model, trial, metrics)
def _save(self, output_dir: Optional[str] = None, state_dict=None):
super(MPLUGDocOwlTrainer, self)._save(output_dir, state_dict)
\ No newline at end of file
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
# Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
## if you fail to import mplug_docowl, try adding absolute path of the mPLUG-DocOwl1.5 directory
import os
import copy
from dataclasses import dataclass, field
import json
import jsonlines
import logging
import pathlib
from typing import Dict, Optional, Sequence, List
import torch
import transformers
from transformers.models.clip.image_processing_clip import CLIPImageProcessor
from torch.utils.data import Dataset
from PIL import Image
from icecream import ic
import sys
sys.path.append('/home/wanglch/projects/mPLUG-DocOwl1.5-Omni')
print(sys.path)
# Need to call this before importing transformers.
from mplug_docowl.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
from mplug_docowl.train.mplug_docowl_trainer import MPLUGDocOwlTrainer
from mplug_docowl.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from mplug_docowl import conversation as conversation_lib
from mplug_docowl.model import *
from mplug_docowl.mm_utils import tokenizer_image_token
from mplug_docowl.processor import DocProcessor
local_rank = None
def read_jsonl(filename):
lines = []
with open(filename, 'r', encoding='utf-8') as f:
for line in jsonlines.Reader(f):
lines.append(line)
return lines
def rank0_print(*args):
if local_rank == 0:
print(*args)
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
version: Optional[str] = field(default="v0")
freeze_backbone: bool = field(default=False)
@dataclass
class DataArguments:
data_path: str = field(default=None,
metadata={"help": "Path to the training data."})
lazy_preprocess: bool = False
is_multimodal: bool = False
image_folder: Optional[str] = field(default=None)
image_aspect_ratio: str = 'square'
image_grid_pinpoints: Optional[str] = field(default=None)
image_size: int = 448
crop_anchors: str = 'grid_9'
add_global_img: bool = True
add_textual_crop_indicator: bool = True
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
remove_unused_columns: bool = field(default=False)
tune_vision2text: bool = field(default=True)
freeze_vision_model: bool = field(default=True)
model_max_length: int = field(
default=512,
metadata={
"help":
"Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
double_quant: bool = field(
default=True,
metadata={"help": "Compress the quantization statistics through double quantization."}
)
quant_type: str = field(
default="nf4",
metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
)
bits: int = field(
default=16,
metadata={"help": "How many bits to use."}
)
lora_enable: bool = False
lora_r: int = 64
lora_alpha: int = 16
lora_dropout: float = 0.05
lora_weight_path: str = ""
lora_bias: str = "none"
vision2text_lr: Optional[float] = None
group_by_modality_length: bool = field(default=False)
def maybe_zero_3(param, ignore_status=False, name=None):
from deepspeed import zero
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
if hasattr(param, "ds_id"):
if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if not ignore_status:
logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
with zero.GatheredParameters([param]):
param = param.data.detach().cpu().clone()
else:
param = param.detach().cpu().clone()
return param
# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
if bias == "none":
to_return = {k: t for k, t in named_params if "lora_" in k}
elif bias == "all":
to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = {}
maybe_lora_bias = {}
lora_bias_names = set()
for k, t in named_params:
if "lora_" in k:
to_return[k] = t
bias_name = k.split("lora_")[0] + "bias"
lora_bias_names.add(bias_name)
elif "bias" in k:
maybe_lora_bias[k] = t
for k, t in maybe_lora_bias:
if bias_name in lora_bias_names:
to_return[bias_name] = t
else:
raise NotImplementedError
to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
return to_return
def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
to_return = {k: t for k, t in named_params if "lora_" not in k}
if require_grad_only:
to_return = {k: t for k, t in to_return.items() if t.requires_grad}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
return to_return
def find_all_linear_names(model):
cls = torch.nn.Linear
lora_module_names = set()
multimodal_keywords = ['vision_model', 'vision2text']
for name, module in model.named_modules():
if any(mm_keyword in name for mm_keyword in multimodal_keywords):
continue
if isinstance(module, cls):
lora_module_names.add(name)
if 'lm_head' in lora_module_names: # needed for 16-bit
lora_module_names.remove('lm_head')
return list(lora_module_names)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
output_dir: str):
"""Collects the state dict and dump to disk."""
if trainer.deepspeed:
torch.cuda.synchronize()
trainer.save_model(output_dir)
return
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {
key: value.cpu()
for key, value in state_dict.items()
}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def _tokenize_fn(strings: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
) for text in strings
]
input_ids = labels = [
tokenized.input_ids[0] for tokenized in tokenized_list
]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def _mask_targets(target, tokenized_lens, speakers):
# cur_idx = 0
cur_idx = tokenized_lens[0]
tokenized_lens = tokenized_lens[1:]
target[:cur_idx] = IGNORE_INDEX
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker == "human":
target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
cur_idx += tokenized_len
def _add_speaker_and_signal(header, source, get_conversation=True):
"""Add speaker and start/end signal on each round."""
BEGIN_SIGNAL = "### "
END_SIGNAL = "\n"
conversation = header
for sentence in source:
from_str = sentence["from"]
if from_str.lower() == "human":
from_str = conversation_lib.default_conversation.roles[0]
elif from_str.lower() == "gpt":
from_str = conversation_lib.default_conversation.roles[1]
else:
from_str = 'unknown'
sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
sentence["value"] + END_SIGNAL)
if get_conversation:
conversation += sentence["value"]
conversation += BEGIN_SIGNAL
return conversation
"""def preprocess_multimodal(
sources: Sequence[str],
data_args: DataArguments
) -> Dict:
is_multimodal = data_args.is_multimodal
if not is_multimodal:
return sources
for source in sources:
for sentence in source:
if DEFAULT_IMAGE_TOKEN in sentence['value']:
sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value']
sentence['value'] = sentence['value'].strip()
replace_token = DEFAULT_IMAGE_TOKEN
sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
return sources"""
def docowl_text_preprocess_v1(
source,
tokenizer: transformers.PreTrainedTokenizer,
has_image: bool = False
) -> Dict:
"""
source: list of {'role':'user'/'assistant', 'content':xxxx}
"""
conv = conversation_lib.default_conversation.copy()
# conv.roles: ("USER", "ASSISTANT")
roles = {"user": conv.roles[0], "assistant": conv.roles[1]}
# Apply prompt templates
conversations = []
# Skip the first one if it is not from human
if roles[source[0]["role"]] != conv.roles[0]:
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence["role"]]
assert role == conv.roles[j % 2]
conv.append_message(role, sentence["content"])
# conv.get_prompt(): USER: {content} ASSISTANT: {content}</s>USER: {content} ASSISTANT: {content}</s>...
conversations.append(conv.get_prompt())
# Tokenize conversations
if has_image:
input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
else:
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS
# Mask targets
sep = conv.sep + conv.roles[1] + ": " # ' ASSISTANT: '
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep2) # split by </s>
cur_len = 1
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(rounds):
if rou == "":
break
parts = rou.split(sep) # split each round by ' ASSISTANT: '
if len(parts) != 2:
break
parts[0] += sep # input query, ignore for loss
if has_image:
round_len = len(tokenizer_image_token(rou, tokenizer))
instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
else:
round_len = len(tokenizer(rou).input_ids)
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
cur_len += round_len
target[cur_len:] = IGNORE_INDEX
if cur_len < tokenizer.model_max_length: # ignore padding
if cur_len != total_len:
target[:] = IGNORE_INDEX
print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)"
)
return dict(
input_ids=input_ids,
labels=targets,
)
def preprocess_plain(
sources: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
# add end signal and concatenate together
conversations = []
for source in sources:
assert len(source) == 2
assert DEFAULT_IMAGE_TOKEN in source[0]['value']
source[0]['value'] = DEFAULT_IMAGE_TOKEN
conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep
conversations.append(conversation)
# tokenize conversations
input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
target[:tokenized_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=targets)
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str,
tokenizer: transformers.PreTrainedTokenizer,
data_args: DataArguments):
super(LazySupervisedDataset, self).__init__()
list_data_dict = read_jsonl(data_path)
rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.list_data_dict = list_data_dict
self.data_args = data_args
self.doc_image_processor = DocProcessor(image_size=self.data_args.image_size,
anchors=self.data_args.crop_anchors,
add_global_img=self.data_args.add_global_img,
add_textual_crop_indicator=self.data_args.add_textual_crop_indicator)
def __len__(self):
return len(self.list_data_dict)
"""@property
def lengths(self):
length_list = []
for sample in self.list_data_dict:
img_tokens = 128 if 'image' in sample else 0
length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens)
return length_list"""
"""@property
def modality_lengths(self):
length_list = []
for sample in self.list_data_dict:
cur_len = sum(len(conv['value'].split()) for conv in sample['conversations'])
cur_len = cur_len if 'image' in sample else -cur_len
length_list.append(cur_len)
return length_list"""
def next_rand(self):
import random
return random.randint(0,len(self)-1)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
while True:
source = self.list_data_dict[i]
assert 'image' in source
image_file = source['image'][0]
image_folder = self.data_args.image_folder
from pathlib import Path
if not Path(os.path.join(image_folder, image_file)).exists():
i = self.next_rand()
continue
## crop image and revise query acccording to the cropping shape
image_path=os.path.join(image_folder, image_file)
query = source['messages'][0]['content'] # e.g. <|image|>what is the contact person name mentioned in letter?."
# image: tensor, [N_crops+1, c, h, w]
# patch_positions: tensor, [N_crops+1, 2]
# text: string, add textual crop indicators and repeat the <|image|> N_crops+1 times
image, patch_positions, processed_query = self.doc_image_processor(images=image_path, query=query)
source['messages'][0]['content'] = processed_query
## text tokenization
data_dict = docowl_text_preprocess_v1(
source['messages'],
self.tokenizer,
has_image=('image' in self.list_data_dict[i]))
if isinstance(i, int):
data_dict = dict(input_ids=data_dict["input_ids"][0],
labels=data_dict["labels"][0])
# image exist in the data
data_dict['image'] = image
data_dict['patch_positions'] = patch_positions
return data_dict
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances]
for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids,
batch_first=True,
padding_value=self.tokenizer.pad_token_id)
labels = torch.nn.utils.rnn.pad_sequence(labels,
batch_first=True,
padding_value=IGNORE_INDEX)
input_ids = input_ids[:, :self.tokenizer.model_max_length]
labels = labels[:, :self.tokenizer.model_max_length]
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
if 'image' in instances[0]:
images = [instance['image'] for instance in instances]
batch['images'] = torch.cat(images) # Sum(Crop+1) x c x h x w
if 'patch_positions' in instances[0]:
patch_positions = [instance['patch_positions'] for instance in instances]
batch['patch_positions'] = torch.cat(patch_positions) # Sum(Crop+1) x 2
return batch
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
data_path=data_args.data_path,
data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset,
eval_dataset=None,
data_collator=data_collator)
def train():
global local_rank
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
local_rank = training_args.local_rank
compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
bnb_model_from_pretrained_args = {}
if training_args.bits in [4, 8]:
from transformers import BitsAndBytesConfig
bnb_model_from_pretrained_args.update(dict(
device_map={"": training_args.device},
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
quantization_config=BitsAndBytesConfig(
load_in_4bit=training_args.bits == 4,
load_in_8bit=training_args.bits == 8,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=compute_dtype,
bnb_4bit_use_double_quant=training_args.double_quant,
bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
)
))
#
model = MPLUGDocOwlLlamaForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
**bnb_model_from_pretrained_args
)
model.config.use_cache = False
ic(model_args.freeze_backbone)
if model_args.freeze_backbone:
model.model.requires_grad_(False)
if training_args.bits in [4, 8]:
from peft import prepare_model_for_kbit_training
model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
if training_args.gradient_checkpointing:
if hasattr(model, "enable_input_require_grads"): # default true
model.enable_input_require_grads()
# print('model.enable_input_require_grads()')
else:
def make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
if training_args.lora_enable:
from peft import LoraConfig, get_peft_model
lora_config = LoraConfig(
r=training_args.lora_r,
lora_alpha=training_args.lora_alpha,
target_modules=find_all_linear_names(model),
lora_dropout=training_args.lora_dropout,
bias=training_args.lora_bias,
task_type="CAUSAL_LM",
)
if training_args.bits == 16:
if training_args.bf16:
model.to(torch.bfloat16)
if training_args.fp16:
model.to(torch.float16)
rank0_print("Adding LoRA adapters...")
model = get_peft_model(model, lora_config)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token
if model_args.version in conversation_lib.conv_templates:
conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version]
else:
conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1"]
if not training_args.freeze_vision_model and training_args.bits in [4, 8]:
model.get_model().vision_model.to(dtype=compute_dtype, device=training_args.device)
else:
vision_tower = model.get_model().vision_model
vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
if training_args.tune_vision2text and training_args.bits in [4, 8]:
model.get_model().vision2text.to(dtype=compute_dtype, device=training_args.device)
else:
vision2text = model.get_model().vision2text
vision2text.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device)
data_args.is_multimodal = True
model.config.image_aspect_ratio = data_args.image_aspect_ratio
model.config.image_grid_pinpoints = data_args.image_grid_pinpoints
model.config.tune_vision2text = model_args.tune_vision2text = training_args.tune_vision2text
ic(training_args.tune_vision2text)
# model.requires_grad_(True)
if training_args.tune_vision2text:
# model.requires_grad_(False)
for p in model.get_model().vision2text.parameters():
p.requires_grad = True
model.config.freeze_vision_model = training_args.freeze_vision_model
ic(training_args.freeze_vision_model)
if training_args.freeze_vision_model:
for p in model.get_model().vision_model.parameters():
p.requires_grad = False
model.config.vision2text_lr = training_args.vision2text_lr
if training_args.bits in [4, 8]:
from peft.tuners.lora import LoraLayer
for name, module in model.named_modules():
if isinstance(module, LoraLayer):
if training_args.bf16:
module = module.to(torch.bfloat16)
if 'norm' in name:
module = module.to(torch.float32)
if 'lm_head' in name or 'embed_tokens' in name:
if hasattr(module, 'weight'):
if training_args.bf16 and module.weight.dtype == torch.float32:
module = module.to(torch.bfloat16)
data_module = make_supervised_data_module(tokenizer=tokenizer,
data_args=data_args)
trainer = MPLUGDocOwlTrainer(model=model,
tokenizer=tokenizer,
args=training_args,
**data_module)
trainer.train()
trainer.save_state()
model.config.use_cache = True
if training_args.lora_enable:
state_dict = get_peft_state_maybe_zero_3(
model.named_parameters(), training_args.lora_bias
)
non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3(
model.named_parameters()
)
if training_args.local_rank == 0 or training_args.local_rank == -1:
model.config.save_pretrained(training_args.output_dir)
model.save_pretrained(training_args.output_dir, state_dict=state_dict)
torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin'))
else:
safe_save_model_for_hf_trainer(trainer=trainer,
output_dir=training_args.output_dir)
if __name__ == "__main__":
train()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment