Commit 1d5a34cf authored by wanglch

Initial commit
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Phi-3 model configuration"""
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
PHI3_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'microsoft/Phi-3-mini-4k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/resolve/main/config.json',
'microsoft/Phi-3-mini-128k-instruct': 'https://huggingface.co/microsoft/Phi-3-mini-128k-instruct/resolve/main/config.json',
}
class Phi3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the
[microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct).
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 32064):
Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`Phi3Model`].
hidden_size (`int`, *optional*, defaults to 3072):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 8192):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
`num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
`num_attention_heads`.
resid_pdrop (`float`, *optional*, defaults to 0.0):
Dropout probability for the MLP outputs.
embd_pdrop (`float`, *optional*, defaults to 0.0):
The dropout ratio for the embeddings.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio after computing the attention scores.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used with.
original_max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model was trained with. This is used to determine the size of the
original RoPE embeddings when using long scaling.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon value used for the RMSNorm.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether or not to tie the input and output word embeddings.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`dict`, *optional*):
The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must
contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be either `su` or `yarn` and
the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size
divided by the number of attention heads divided by 2.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 32000):
The id of the "end-of-sequence" token.
pad_token_id (`int`, *optional*, defaults to 32000):
The id of the padding token.
sliding_window (`int`, *optional*):
Sliding window attention window size. If `None`, no sliding window is applied.
Example:
```python
>>> from transformers import Phi3Model, Phi3Config
>>> # Initializing a Phi-3 style configuration
>>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
>>> # Initializing a model from the configuration
>>> model = Phi3Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = 'phi3'
keys_to_ignore_at_inference = ['past_key_values']
def __init__(
self,
vocab_size=32064,
hidden_size=3072,
intermediate_size=8192,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
resid_pdrop=0.0,
embd_pdrop=0.0,
attention_dropout=0.0,
hidden_act='silu',
max_position_embeddings=4096,
original_max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
bos_token_id=1,
eos_token_id=32000,
pad_token_id=32000,
sliding_window=None,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attention_dropout = attention_dropout
self.hidden_act = hidden_act
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
self.sliding_window = sliding_window
super().__init__(
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3:
raise ValueError(
'`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, '
f'got {self.rope_scaling}'
)
rope_scaling_type = self.rope_scaling.get('type', None)
rope_scaling_short_factor = self.rope_scaling.get('short_factor', None)
rope_scaling_long_factor = self.rope_scaling.get('long_factor', None)
if rope_scaling_type is None or rope_scaling_type not in ['su', 'yarn']:
raise ValueError(f"`rope_scaling`'s type field must be one of ['su', 'yarn'], got {rope_scaling_type}")
if not (
isinstance(rope_scaling_short_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
):
raise ValueError(
f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
)
if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
raise ValueError(
f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
)
if not (
isinstance(rope_scaling_long_factor, list)
and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
):
raise ValueError(
f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
)
if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
raise ValueError(
f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
)
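# Illustrative sketch (not part of the original file; the helper name and factor values
# are placeholders): a long-context `rope_scaling` dict as `_rope_scaling_validation`
# expects it. Both factor lists must have length hidden_size // num_attention_heads // 2
# (48 for the default 3072 / 32), otherwise a ValueError is raised.
def _example_long_context_config():
    half_rotary_dim = 3072 // 32 // 2  # 48 entries per factor list
    return Phi3Config(
        max_position_embeddings=131072,
        original_max_position_embeddings=4096,
        rope_scaling={
            'type': 'su',
            'short_factor': [1.0] * half_rotary_dim,
            'long_factor': [1.0] * half_rotary_dim,
        },
    )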
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch Phi-3 model."""
import inspect
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import \
_prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
TokenClassifierOutput)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10, logging,
replace_return_docstrings)
from .configuration_phi3 import Phi3Config
logger = logging.get_logger(__name__)
# Transformers scans dependencies in the modeling file, causing issues on conditional loading. The regex only ignores try/catch blocks, but not if statements
# if is_flash_attn_2_available():
_flash_supports_window_size = False
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import (index_first_axis, pad_input, # noqa
unpad_input)
_flash_supports_window_size = 'window_size' in list(inspect.signature(flash_attn_func).parameters)
except ImportError as error:
logger.warning(
f'`flash-attention` package not found, consider installing for better performance: {error}.'
)
if not _flash_supports_window_size:
logger.warning(
"Current `flash-attention` does not support `window_size`. Either upgrade or use `attn_implementation='eager'`."
)
_CHECKPOINT_FOR_DOC = 'microsoft/Phi-3-mini-4k-instruct'
_CONFIG_FOR_DOC = 'Phi3Config'
PHI3_PRETRAINED_MODEL_ARCHIVE_LIST = [
'microsoft/Phi-3-mini-4k-instruct',
'microsoft/Phi-3-mini-128k-instruct',
# See all Phi-3 models at https://huggingface.co/models?filter=Phi-3
]
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3
class Phi3RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Phi3RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
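# Hypothetical shape sketch (not part of the original file): Phi3RMSNorm rescales the
# activations by the reciprocal root-mean-square of the last dimension (computed in
# float32) and multiplies by a learned weight; unlike LayerNorm there is no mean
# subtraction and no bias, and the output shape matches the input shape.
def _example_phi3_rmsnorm():
    norm = Phi3RMSNorm(hidden_size=8, eps=1e-5)
    return norm(torch.randn(2, 4, 8)).shape  # torch.Size([2, 4, 8])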
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
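# Illustrative sketch (helper name and mask values are hypothetical): for a (2, 4) mask
# whose second row has one padded position, `_get_unpad_data` returns the 7 flat indices
# of the non-pad tokens, cumulative sequence lengths tensor([0, 4, 7]) and a maximum
# per-sample length of 4 -- the quantities consumed by flash_attn_varlen_func below.
def _example_get_unpad_data():
    attention_mask = torch.tensor([[1, 1, 1, 1],
                                   [0, 1, 1, 1]], dtype=torch.int32)
    return _get_unpad_data(attention_mask)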
# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3
class Phi3RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.register_buffer('inv_freq', None, persistent=False)
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if self.inv_freq is None:
self.inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim)
)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
def __init__(self, dim, config, device=None):
super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
self.short_factor = config.rope_scaling['short_factor']
self.long_factor = config.rope_scaling['long_factor']
self.original_max_position_embeddings = config.original_max_position_embeddings
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
seq_len = torch.max(position_ids) + 1
if seq_len > self.original_max_position_embeddings:
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
else:
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
scale = self.max_position_embeddings / self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
cos = emb.cos() * scaling_factor
sin = emb.sin() * scaling_factor
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
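# Worked example (illustrative helper, mirroring the scaling above): the `su` variant
# rescales cos/sin by sqrt(1 + log(scale) / log(original_max_position_embeddings)).
# For a hypothetical 128k-context configuration, scale = 131072 / 4096 = 32 and
# log(32) / log(4096) = 5 / 12, so the factor is sqrt(1 + 5 / 12) ~= 1.19.
def _example_su_scaling_factor(max_pos=131072, original_max_pos=4096):
    scale = max_pos / original_max_pos
    if scale <= 1.0:
        return 1.0
    return math.sqrt(1 + math.log(scale) / math.log(original_max_pos))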
class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding):
def __init__(self, dim, config, device=None):
super().__init__(dim, config.max_position_embeddings, config.rope_theta, device)
self.short_factor = config.rope_scaling['short_factor']
self.long_factor = config.rope_scaling['long_factor']
self.original_max_position_embeddings = config.original_max_position_embeddings
@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
seq_len = torch.max(position_ids) + 1
if seq_len > self.original_max_position_embeddings:
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
else:
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device)
inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim
self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape)
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
with torch.autocast(device_type=device_type, enabled=False):
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
scale = self.max_position_embeddings / self.original_max_position_embeddings
if scale <= 1.0:
scaling_factor = 1.0
else:
scaling_factor = 0.1 * math.log(scale) + 1.0
cos = emb.cos() * scaling_factor
sin = emb.sin() * scaling_factor
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
position_ids (`torch.Tensor`, *optional*):
Deprecated and unused.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
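# Shape sketch (tensor contents are random placeholders): with q of shape
# (batch, num_heads, seq_len, head_dim), k of shape (batch, num_kv_heads, seq_len, head_dim)
# and cos/sin of shape (batch, seq_len, head_dim), unsqueeze_dim=1 inserts the head axis
# so the element-wise products broadcast over any number of heads; outputs keep q/k shapes.
def _example_apply_rotary_pos_emb():
    q = torch.randn(1, 32, 8, 96)
    k = torch.randn(1, 8, 8, 96)
    cos = torch.randn(1, 8, 96)
    sin = torch.randn(1, 8, 96)
    q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
    return q_embed.shape, k_embed.shape  # (1, 32, 8, 96), (1, 8, 8, 96)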
class Phi3MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.activation_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
up_states = self.gate_up_proj(hidden_states)
gate, up_states = up_states.chunk(2, dim=-1)
up_states = up_states * self.activation_fn(gate)
return self.down_proj(up_states)
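# Hypothetical shape walk-through of the fused MLP above: `gate_up_proj` maps
# hidden_size -> 2 * intermediate_size in one matmul, the result is chunked into a gate
# half and an up half, `up * silu(gate)` keeps intermediate_size, and `down_proj`
# returns to hidden_size, so (1, 4, 3072) -> (1, 4, 3072) with the default config.
def _example_phi3_mlp_shapes():
    mlp = Phi3MLP(Phi3Config())
    return mlp(torch.randn(1, 4, 3072)).shape  # torch.Size([1, 4, 3072])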
# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
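# Grouped-query attention sketch (head counts are illustrative): with 32 query heads and
# 8 key/value heads, each KV head is repeated n_rep = 32 // 8 = 4 times so the key/value
# tensors line up with the query heads before the attention matmul.
def _example_repeat_kv():
    kv = torch.randn(1, 8, 16, 96)  # (batch, num_key_value_heads, seq_len, head_dim)
    return repeat_kv(kv, n_rep=4).shape  # torch.Size([1, 32, 16, 96])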
class Phi3Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f'Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will '
'lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` '
'when creating this class.'
)
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.original_max_position_embeddings = config.original_max_position_embeddings
self.rope_theta = config.rope_theta
self.rope_scaling = config.rope_scaling
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f'hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}'
f' and `num_heads`: {self.num_heads}).'
)
op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False)
self._init_rope()
def _init_rope(self):
if self.rope_scaling is None:
self.rotary_emb = Phi3RotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling['type']
if scaling_type == 'su':
self.rotary_emb = Phi3SuScaledRotaryEmbedding(self.head_dim, self.config)
elif scaling_type == 'yarn':
self.rotary_emb = Phi3YarnScaledRotaryEmbedding(self.head_dim, self.config)
else:
raise ValueError(f'Unknown RoPE scaling type {scaling_type}')
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
logger.warning_once('You are not running the flash-attention implementation, expect numerical differences.')
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)
query_pos = self.num_heads * self.head_dim
query_states = qkv[..., :query_pos]
key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
'with a layer index.'
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
f' {attn_weights.size()}'
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
f' {attn_output.size()}'
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
class Phi3FlashAttention2(Phi3Attention):
"""
Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stay
untouched. The only required change is on the forward pass, where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Phi3FlashAttention2 attention does not support output_attentions
if not _flash_supports_window_size:
logger.warning_once(
"The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library."
)
raise ValueError('The current flash attention version does not support sliding window attention.')
output_attentions = False
if 'padding_mask' in kwargs:
warnings.warn(
'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
)
# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop('padding_mask')
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)
query_pos = self.num_heads * self.head_dim
query_states = qkv[..., :query_pos]
key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
# Flash attention requires the input to have the shape
# batch_size x seq_length x head_dim x hidden_dim
# therefore we just need to keep the original shape
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
if self.layer_idx is None:
raise ValueError(
f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
'with a layer index.'
)
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
# Because the input can be padded, the absolute sequence length depends on the max position id.
rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
use_sliding_windows = (
_flash_supports_window_size
and getattr(self.config, 'sliding_window', None) is not None
and kv_seq_len > self.config.sliding_window
)
if past_key_value is not None:
# Activate slicing cache only if the config has a `sliding_window` attribute
cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
if (
getattr(self.config, 'sliding_window', None) is not None
and kv_seq_len > self.config.sliding_window
and cache_has_contents
):
slicing_tokens = 1 - self.config.sliding_window
past_key = past_key_value[self.layer_idx][0]
past_value = past_key_value[self.layer_idx][1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f'past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got'
f' {past_key.shape}'
)
if attention_mask is not None:
attention_mask = attention_mask[:, slicing_tokens:]
attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_dropout = self.attention_dropout if self.training else 0.0
# In PEFT, usually we cast the layer norms in float32 for training stability reasons,
# therefore the input hidden states get silently cast to float32. Hence, we need to
# cast them back to the correct dtype just to be sure everything works as expected.
# This might slow down training & inference, so it is recommended to not cast the LayerNorms
# in fp32.
if query_states.dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, '_pre_quantization_dtype'):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.qkv_proj.weight.dtype
logger.warning_once(
f'The input hidden states seem to have been silently cast to float32; this might be related to'
f' the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to'
f' {target_dtype}.'
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Reshape to the expected shape for Flash Attention
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attn_output = self._flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=attn_dropout,
use_sliding_windows=use_sliding_windows,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._flash_attention_forward
def _flash_attention_forward(
self,
query_states,
key_states,
value_states,
attention_mask,
query_length,
dropout=0.0,
softmax_scale=None,
use_sliding_windows=False,
):
"""
Calls the forward method of Flash Attention. If the input hidden states contain at least one padding token,
the input is first unpadded, the attention scores are computed, and the final attention output is padded back.
Args:
query_states (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key_states (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value_states (`torch.Tensor`):
Input value states to be passed to Flash Attention API
attention_mask (`torch.Tensor`):
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
position of padding tokens and 1 for the position of non-padding tokens.
dropout (`float`):
Attention dropout
softmax_scale (`float`, *optional*):
The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim).
use_sliding_windows (`bool`, *optional*):
Whether to activate sliding window attention.
"""
if not self._flash_attn_uses_top_left_mask:
causal = self.is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
causal = self.is_causal and query_length != 1
# Contains at least one padding token in the sequence
if attention_mask is not None:
batch_size = query_states.shape[0]
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
query_states, key_states, value_states, attention_mask, query_length
)
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
if not use_sliding_windows:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output_unpad = flash_attn_varlen_func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
else:
if not use_sliding_windows:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal,
)
else:
attn_output = flash_attn_func(
query_states,
key_states,
value_states,
dropout,
softmax_scale=softmax_scale,
causal=causal,
window_size=(self.config.sliding_window, self.config.sliding_window),
)
return attn_output
# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2._upad_input
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape
# On the first iteration we need to properly re-create the padding mask
# by slicing it at the right place
if kv_seq_len != attention_mask.shape[-1]:
attention_mask_num_tokens = attention_mask.shape[-1]
attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :]
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k)
if query_length == kv_seq_len:
query_layer = index_first_axis(
query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k
)
cu_seqlens_q = cu_seqlens_k
max_seqlen_in_batch_q = max_seqlen_in_batch_k
indices_q = indices_k
elif query_length == 1:
max_seqlen_in_batch_q = 1
cu_seqlens_q = torch.arange(
batch_size + 1, dtype=torch.int32, device=query_layer.device
) # There is a memcpy here, that is very bad.
indices_q = cu_seqlens_q[:-1]
query_layer = query_layer.squeeze(1)
else:
# The -q_len: slice assumes left padding.
attention_mask = attention_mask[:, -query_length:]
query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
return (
query_layer,
key_layer,
value_layer,
indices_q,
(cu_seqlens_q, cu_seqlens_k),
(max_seqlen_in_batch_q, max_seqlen_in_batch_k),
)
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
# TODO @Arthur no longer copied from LLama after static cache
class Phi3SdpaAttention(Phi3Attention):
"""
Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`Phi3Attention` as the weights of the module stay untouched. The only changes are on the forward pass, to adapt to
the SDPA API.
"""
# Adapted from Phi3Attention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
# TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
logger.warning_once(
'Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, '
'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
return super().forward(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
bsz, q_len, _ = hidden_states.size()
qkv = self.qkv_proj(hidden_states)
query_pos = self.num_heads * self.head_dim
query_states = qkv[..., :query_pos]
key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
cache_kwargs = {'sin': sin, 'cos': cos} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
)
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
if query_states.device.type == 'cuda' and attention_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
is_causal=self.is_causal and attention_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
PHI3_ATTENTION_CLASSES = {
'eager': Phi3Attention,
'flash_attention_2': Phi3FlashAttention2,
'sdpa': Phi3SdpaAttention,
}
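# Illustrative only (helper is not part of the original file): the decoder layer below
# resolves its attention class from this mapping via `config._attn_implementation`,
# which transformers derives from the `attn_implementation` argument of `from_pretrained`
# ("eager", "sdpa" or "flash_attention_2"), falling back to "eager".
def _example_select_attention_class(config: Phi3Config):
    return PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=0)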
class Phi3DecoderLayer(nn.Module):
def __init__(self, config: Phi3Config, layer_idx: int):
super().__init__()
self.config = config
self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx)
self.mlp = Phi3MLP(config)
self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.resid_attn_dropout = nn.Dropout(config.resid_pdrop)
self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop)
self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
**kwargs,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`):
input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
`[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
if 'padding_mask' in kwargs:
warnings.warn(
'Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure to use `attention_mask` instead.'
)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
attn_outputs, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = residual + self.resid_attn_dropout(attn_outputs)
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + self.resid_mlp_dropout(hidden_states)
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
PHI3_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
and behavior.
Parameters:
config ([`Phi3Config`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
'The bare Phi-3 model outputting raw hidden-states without any specific head on top.',
PHI3_START_DOCSTRING,
)
class Phi3PreTrainedModel(PreTrainedModel):
config_class = Phi3Config
base_model_prefix = 'model'
supports_gradient_checkpointing = True
_no_split_modules = ['Phi3DecoderLayer']
_skip_keys_device_placement = 'past_key_values'
_supports_flash_attn_2 = True
_supports_sdpa = False
_supports_cache_class = True
_version = '0.0.5'
def _init_weights(self, module):
std = self.config.initializer_range
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
PHI3_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance;
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@add_start_docstrings(
'The bare Phi-3 model outputting raw hidden-states without any specific head on top.',
PHI3_START_DOCSTRING,
)
class Phi3Model(Phi3PreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
Args:
config: Phi3Config
"""
def __init__(self, config: Phi3Config):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.embed_dropout = nn.Dropout(config.embd_pdrop)
self.layers = nn.ModuleList(
[Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self._attn_implementation = config._attn_implementation
self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError('You have to specify either input_ids or inputs_embeds')
past_key_values_length = 0
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
'`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
)
use_cache = False
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if attention_mask is not None and self._attn_implementation == 'flash_attention_2' and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
raise ValueError(
"You are attempting to perform batched generation with padding_side='right', which may lead to"
" unexpected behaviour for the Flash Attention version of Phi3. Make sure to call"
" `tokenizer.padding_side = 'left'` before tokenizing the input."
)
if self._attn_implementation == 'flash_attention_2':
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
sliding_window=self.config.sliding_window,
)
hidden_states = inputs_embeds
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None
for decoder_layer in self.layers:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
class Phi3ForCausalLM(Phi3PreTrainedModel):
_tied_weights_keys = ['lm_head.weight']
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3
def __init__(self, config):
super().__init__(config)
self.model = Phi3Model(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
def get_input_embeddings(self):
return self.model.embed_tokens
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
def set_input_embeddings(self, value):
self.model.embed_tokens = value
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
def get_output_embeddings(self):
return self.lm_head
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
def set_decoder(self, decoder):
self.model = decoder
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
def get_decoder(self):
return self.model
# Ignore copy
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, Phi3ForCausalLM
>>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
>>> prompt = "This is an example script ."
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = None
if labels is not None:
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
# some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
# input)
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < input_ids.shape[1]:
input_ids = input_ids[:, past_length:]
# 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
# If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
if (
max_cache_length is not None
and attention_mask is not None
and cache_length + input_ids.shape[1] > max_cache_length
):
attention_mask = attention_mask[:, -max_cache_length:]
position_ids = kwargs.get('position_ids', None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {'inputs_embeds': inputs_embeds}
else:
model_inputs = {'input_ids': input_ids}
model_inputs.update(
{
'position_ids': position_ids,
'past_key_values': past_key_values,
'use_cache': kwargs.get('use_cache'),
'attention_mask': attention_mask,
}
)
return model_inputs
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
def _reorder_cache(past_key_values, beam_idx):
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
)
return reordered_past
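# Minimal sketch (illustrative only, not used by the model): how the legacy
# (key, value) cache is reordered along the batch dimension during beam
# search. Here beam 1 is duplicated and beam 0 is dropped.
def _example_reorder_cache():
    import torch

    layer_past = (torch.arange(4.0).view(2, 1, 2, 1), torch.arange(4.0).view(2, 1, 2, 1))
    beam_idx = torch.tensor([1, 1])
    reordered = Phi3ForCausalLM._reorder_cache((layer_past,), beam_idx)
    # both batch rows of the reordered keys now equal the original beam-1 keys
    assert torch.equal(reordered[0][0][0], layer_past[0][1])
    return reordered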
@add_start_docstrings(
"""
The [`Phi3Model`] with a sequence classification head on top (linear layer).
[`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it needs to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
each row of the batch).
""",
PHI3_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs
class Phi3ForSequenceClassification(Phi3PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Phi3Model(config)
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
model_outputs = self.model(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = model_outputs[0]
logits = self.score(hidden_states)
if input_ids is not None:
batch_size = input_ids.shape[0]
else:
batch_size = inputs_embeds.shape[0]
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError('Cannot handle batch sizes > 1 if no padding token is defined.')
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
if input_ids is not None:
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
sequence_lengths = sequence_lengths.to(logits.device)
else:
sequence_lengths = -1
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = 'regression'
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = 'single_label_classification'
else:
self.config.problem_type = 'multi_label_classification'
if self.config.problem_type == 'regression':
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == 'single_label_classification':
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == 'multi_label_classification':
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
if not return_dict:
output = (pooled_logits,) + model_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutputWithPast(
loss=loss,
logits=pooled_logits,
past_key_values=model_outputs.past_key_values,
hidden_states=model_outputs.hidden_states,
attentions=model_outputs.attentions,
)
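# Minimal sketch (illustrative only): how the pooling index used above is
# derived. For right-padded inputs the first pad position minus one is the
# last real token, and the modulo keeps the index valid for rows that contain
# no padding at all (argmax returns 0 there). The pad id below is made up.
def _example_last_non_pad_token_index():
    import torch

    pad_token_id = 0  # assumed pad id for this toy example
    input_ids = torch.tensor([[5, 6, 7, 0, 0],
                              [1, 2, 3, 4, 9]])  # second row has no padding
    sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    sequence_lengths = sequence_lengths % input_ids.shape[-1]
    assert sequence_lengths.tolist() == [2, 4]
    return sequence_lengths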
@add_start_docstrings(
"""
[`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
Named-Entity-Recognition (NER) tasks.
""",
PHI3_START_DOCSTRING,
)
# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs
class Phi3ForTokenClassification(Phi3PreTrainedModel):
def __init__(self, config: Phi3Config):
super().__init__(config)
self.num_labels = config.num_labels
self.model = Phi3Model(config)
if hasattr(config, 'classifier_dropout') and config.classifier_dropout is not None:
classifier_dropout = config.classifier_dropout
elif hasattr(config, 'hidden_dropout') and config.hidden_dropout is not None:
classifier_dropout = config.hidden_dropout
else:
classifier_dropout = 0.1
self.dropout = nn.Dropout(classifier_dropout)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
# Initialize weights and apply final processing
self.post_init()
@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
model_outputs = self.model(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = model_outputs[0]
hidden_states = self.dropout(hidden_states)
logits = self.classifier(hidden_states)
loss = None
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
batch_size, seq_length = labels.shape
loss_fct = CrossEntropyLoss()
loss = loss_fct(
logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
)
if not return_dict:
output = (logits,) + model_outputs[2:]
return ((loss,) + output) if loss is not None else output
return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=model_outputs.hidden_states,
attentions=model_outputs.attentions,
)
from .llama2_flash_attn_monkey_patch import replace_llama2_attn_with_flash_attn
from .llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
from .llama_rmsnorm_monkey_patch import \
replace_llama_rmsnorm_with_fused_rmsnorm
from .pad_data_collator import concat_pad_data_collator, pad_data_collator
from .train_sampler_patch import replace_train_sampler
__all__ = ['replace_llama_attn_with_flash_attn',
'replace_llama_rmsnorm_with_fused_rmsnorm',
'replace_llama2_attn_with_flash_attn',
'replace_train_sampler',
'pad_data_collator',
'concat_pad_data_collator']
"""
This file is copied from: https://github.com/lm-sys/FastChat
"""
import warnings
from typing import Optional, Tuple
import torch
from flash_attn import __version__ as flash_attn_version
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import (flash_attn_func,
flash_attn_varlen_kvpacked_func)
from transformers.models.llama.modeling_llama import (LlamaAttention,
LlamaModel, rotate_half)
def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
gather_indices = gather_indices.repeat(
1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
)
bsz = gather_indices.shape[0]
cos, sin = (
torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
for x in cos_sin
)
q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
return q, k
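# Minimal sketch (illustrative only; assumes the legacy rotary cache layout
# [1, 1, seq_len, head_dim] returned by older transformers releases, and
# q/k laid out as [bsz, seq_len, n_heads, head_dim] as in forward() below).
def _example_apply_rotary_pos_emb():
    bsz, seq_len, n_heads, head_dim = 2, 4, 3, 8
    q = torch.randn(bsz, seq_len, n_heads, head_dim)
    k = torch.randn(bsz, seq_len, n_heads, head_dim)
    # synthetic cos/sin caches in the assumed [1, 1, seq_len, head_dim] layout
    angles = torch.arange(seq_len, dtype=torch.float32)[:, None].repeat(1, head_dim)
    cos_sin = (angles.cos()[None, None], angles.sin()[None, None])
    position_ids = torch.arange(seq_len)[None].repeat(bsz, 1)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
    # the rotation is applied per position and broadcast over the head axis
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    return q_rot, k_rot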
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions:
warnings.warn(
'Output attentions is not supported for patched `LlamaAttention`, returning `None` instead.'
)
bsz, q_len, _ = hidden_states.size()
kv_heads = getattr(self, 'num_key_value_heads', self.num_heads)
q, k, v = (
op(hidden_states).view(bsz, q_len, nh, self.head_dim)
for op, nh in (
(self.q_proj, self.num_heads),
(self.k_proj, kv_heads),
(self.v_proj, kv_heads),
)
)
# shape: (b, s, num_heads, head_dim)
kv_seq_len = k.shape[1]
past_kv_len = 0
if past_key_value is not None:
past_kv_len = past_key_value[0].shape[2]
kv_seq_len += past_kv_len
cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
if past_key_value is not None:
assert (
flash_attn_version >= '2.1.0'
), 'past_key_value support requires flash-attn >= 2.1.0'
# reuse k, v
k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
if attention_mask is None:
output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
bsz, q_len, -1
)
else:
q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
# We could skip the concat and call unpad twice, but calling unpad only once seems better.
kv, _, cu_k_lens, max_k = unpad_input(
torch.stack((k, v), dim=2), attention_mask
)
output_unpad = flash_attn_varlen_kvpacked_func(
q,
kv,
cu_q_lens,
cu_k_lens,
max_s,
max_k,
0.0,
softmax_scale=None,
causal=True,
)
output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
output = pad_input(output_unpad, indices, bsz, q_len)
return self.o_proj(output), None, past_key_value
# Disable the transformation of the attention mask in LlamaModel as flash attention
# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
if past_key_values_length > 0 and attention_mask is not None:
attention_mask = torch.cat(
(
torch.full(
(input_shape[0], past_key_values_length),
True,
dtype=attention_mask.dtype,
device=attention_mask.device,
),
attention_mask,
),
dim=-1,
)
if attention_mask is not None and torch.all(attention_mask):
return None # This uses the faster call when training with full samples
return attention_mask
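# Minimal sketch (illustrative only): with a fully-True mask the hook returns
# None so the dense fast path is used; with 2 cached tokens the mask is
# left-padded with True for those positions.
def _example_prepare_decoder_attention_mask():
    full = torch.ones(2, 3, dtype=torch.bool)
    mask = torch.tensor([[True, True, True], [True, True, False]])
    assert _prepare_decoder_attention_mask(None, full, (2, 3), None, 0) is None
    padded = _prepare_decoder_attention_mask(None, mask, (2, 3), None, 2)
    assert padded.shape == (2, 5)  # 2 cached positions prepended as True
    return padded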
def replace_llama2_attn_with_flash_attn():
cuda_major, cuda_minor = torch.cuda.get_device_capability()
if cuda_major < 8:
warnings.warn(
'Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. '
'ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593'
)
LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
LlamaAttention.forward = forward
def test():
from fastchat.train.llama_flash_attn_monkey_patch import \
forward as fastchat_forward
from transformers.models.llama.configuration_llama import LlamaConfig
config = LlamaConfig(
hidden_size=1024,
intermediate_size=128,
num_hidden_layers=1,
num_attention_heads=8,
max_position_embeddings=16,
)
device = torch.device('cuda')
model = LlamaModel(config)
attn = LlamaAttention(config).to(device).half()
bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
-1, seqlen
)
mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
for i in range(4):
hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
if i:
mask[0, -i:] = False
mask[1, :i] = False
lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
ref, _, _ = attn.forward(
hidden, attention_mask=lmask, position_ids=position_ids
)
fast, _, _ = fastchat_forward(
attn, hidden, attention_mask=mask, position_ids=position_ids
)
lmask = _prepare_decoder_attention_mask(
model, mask, hidden.shape[:2], hidden, 0
)
test, _, _ = forward(
attn, hidden, attention_mask=lmask, position_ids=position_ids
)
print(f'Mean(abs(ref)) = {torch.mean(torch.abs(ref))}')
print(f'Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}')
print(f'Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}')
print(f'Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}')
print(f'allclose(fast, test) = {torch.allclose(fast, test)}')
with torch.no_grad():
# Also check that past_kv is handled properly
hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
part_len = seqlen // 4
assert part_len * 4 == seqlen
mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
mask[0, -2:] = False
lmask = _prepare_decoder_attention_mask(
model, mask, hidden.shape[:2], hidden, 0
)
oneshot, _, _ = forward(
attn, hidden, attention_mask=lmask, position_ids=position_ids
)
parts = []
past_kv, past_kv_len = None, 0
for i in range(4):
start = part_len * i
end = start + part_len
hidden_part = hidden[:, start:end, ...]
lmask = _prepare_decoder_attention_mask(
model,
mask[:, start:end],
hidden_part.shape[:2],
hidden_part,
past_kv_len,
)
part, _, past_kv = forward(
attn,
hidden_part.clone(),
attention_mask=lmask,
position_ids=position_ids[:, start:end],
past_key_value=past_kv,
use_cache=True,
)
parts.append(part)
past_kv_len = past_kv[0].shape[2]
print(
f'allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}'
)
print(
f'allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}'
)
if __name__ == '__main__':
test()
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers
from torch import nn
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
from einops import rearrange
try: # v1
from flash_attn.flash_attn_interface import \
flash_attn_unpadded_qkvpacked_func
except ImportError: # v2
from flash_attn.flash_attn_interface import \
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import pad_input, unpad_input
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, 'past_key_value is not supported'
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
assert not output_attentions, 'output_attentions is not supported'
assert not use_cache, 'use_cache is not supported'
# Flash attention code adapted from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
qkv = torch.stack(
[query_states, key_states, value_states], dim=2
) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel, so the
# attention_mask here is the same as the key_padding_mask
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
max_s = q_len
cu_q_lens = torch.arange(
0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads
)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(
pad_input(
rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices, bsz, q_len
),
'b s (h d) -> b s h d',
h=nheads,
)
return self.o_proj(rearrange(output, 'b s h d -> b s (h d)')), None, None
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# [bsz, seq_len]
return attention_mask
def forward_2(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = (
self.q_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states)
.view(bsz, q_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
assert not output_attentions, 'output_attentions is not supported'
assert not use_cache, 'use_cache is not supported'
assert past_key_value is None, 'past_key_value is not supported'
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
if self.training:
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states, dropout_p=0.0, is_causal=True
)
attn_weights = None
else:
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f'Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is'
f' {attn_weights.size()}'
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f'Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}'
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
)
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f'`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is'
f' {attn_output.size()}'
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
def replace_llama_attn_with_flash_attn():
if hasattr(F, 'scaled_dot_product_attention'):
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward_2
else:
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask
)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
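# Usage sketch (illustrative only): the patch dispatches on the installed
# torch version. With scaled_dot_product_attention available it installs
# forward_2; otherwise it installs the flash-attn varlen path together with
# the pass-through attention-mask hook. Apply it before building the model.
def _example_replace_llama_attn_with_flash_attn():
    replace_llama_attn_with_flash_attn()
    patched = transformers.models.llama.modeling_llama.LlamaAttention.forward
    assert patched in (forward, forward_2)
    return patched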
import transformers
def replace_llama_rmsnorm_with_fused_rmsnorm():
try:
from functools import partial
from apex.normalization import FusedRMSNorm
LlamaRMSNorm = partial(FusedRMSNorm, eps=1e-6) # noqa
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
print('Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm')
except ImportError:
# using the normal LlamaRMSNorm
pass
except Exception:
print('discovered apex but it failed to load, falling back to LlamaRMSNorm')
pass
import numpy as np
import torch
IGNORE_INDEX = -100
def pad_data_collator(features, pad_id=0):
first = features[0]
batch = {}
batch_lens = [feat['input_ids'].shape for feat in features]
max_item_length = max(batch_lens)[0]
for idx in range(len(features)):
feat = features[idx]
temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
feat['input_ids'] = temp_input_ids
temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
temp_labels[:feat['labels'].shape[0]] = feat['labels']
feat['labels'] = temp_labels
feat['attention_mask'] = feat['input_ids'].ne(pad_id)
# Special handling for labels.
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
if 'label' in first and first['label'] is not None:
label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
dtype = torch.long if isinstance(label, int) else torch.float
batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
elif 'label_ids' in first and first['label_ids'] is not None:
if isinstance(first['label_ids'], torch.Tensor):
batch['labels'] = torch.stack([f['label_ids'] for f in features])
else:
dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
# Handling of all other possible keys.
# Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in first.items():
if k not in ('label', 'label_ids') and v is not None and not isinstance(v, str):
if isinstance(v, torch.Tensor):
batch[k] = torch.stack([f[k] for f in features])
elif isinstance(v, np.ndarray):
batch[k] = torch.tensor(np.stack([f[k] for f in features]))
else:
batch[k] = torch.tensor([f[k] for f in features])
return batch
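# Minimal sketch (illustrative only, values made up): two variable-length
# examples are right-padded to a common length; labels are padded with
# IGNORE_INDEX so padded positions are skipped by the loss, and
# attention_mask marks the real tokens.
def _example_pad_data_collator():
    features = [
        {'input_ids': torch.LongTensor([5, 6, 7]), 'labels': torch.LongTensor([5, 6, 7])},
        {'input_ids': torch.LongTensor([8, 9]), 'labels': torch.LongTensor([8, 9])},
    ]
    batch = pad_data_collator(features, pad_id=0)
    assert batch['input_ids'].shape == (2, 3)
    assert batch['labels'][1].tolist() == [8, 9, IGNORE_INDEX]
    assert batch['attention_mask'][1].tolist() == [True, True, False]
    return batch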
def concat_pad_data_collator(features, pad_id=0):
first = features[0]
batch = {}
batch_lens = [feat['input_ids'].shape for feat in features]
max_item_length = max(batch_lens)[0]
for idx in range(len(features)):
feat = features[idx]
temp_input_ids = torch.LongTensor([pad_id] * max_item_length)
temp_input_ids[:feat['input_ids'].shape[0]] = feat['input_ids']
feat['input_ids'] = temp_input_ids
temp_labels = torch.LongTensor([IGNORE_INDEX] * max_item_length)
temp_labels[:feat['labels'].shape[0]] = feat['labels']
feat['labels'] = temp_labels
feat['attention_mask'] = feat['input_ids'].ne(pad_id)
# Special handling for labels.
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
if 'label' in first and first['label'] is not None:
label = first['label'].item() if isinstance(first['label'], torch.Tensor) else first['label']
dtype = torch.long if isinstance(label, int) else torch.float
batch['labels'] = torch.tensor([f['label'] for f in features], dtype=dtype)
elif 'label_ids' in first and first['label_ids'] is not None:
if isinstance(first['label_ids'], torch.Tensor):
batch['labels'] = torch.stack([f['label_ids'] for f in features])
else:
dtype = torch.long if isinstance(first['label_ids'][0], int) else torch.float
batch['labels'] = torch.tensor([f['label_ids'] for f in features], dtype=dtype)
# Handling of all other possible keys.
# Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in first.items():
if k not in ('label', 'label_ids', 'pixel_values', 'image_flags') and \
v is not None and not isinstance(v, str):
if isinstance(v, torch.Tensor):
batch[k] = torch.stack([f[k] for f in features])
elif isinstance(v, np.ndarray):
batch[k] = torch.tensor(np.stack([f[k] for f in features]))
else:
batch[k] = torch.tensor([f[k] for f in features])
if k in ('pixel_values', 'image_flags'):
if isinstance(v, torch.Tensor):
batch[k] = torch.concat([f[k] for f in features])
elif isinstance(v, np.ndarray):
batch[k] = torch.from_numpy(np.concatenate([f[k] for f in features]))
else:
batch[k] = torch.concat([f[k] for f in features])
return batch
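# Minimal sketch (illustrative only, shapes made up): unlike the stacking path
# used for text fields, per-sample 'pixel_values' of shape
# (num_tiles_i, C, H, W) are concatenated along dim 0, so samples may
# contribute different numbers of image tiles to one batch.
def _example_concat_pad_data_collator():
    features = [
        {'input_ids': torch.LongTensor([1, 2, 3]), 'labels': torch.LongTensor([1, 2, 3]),
         'pixel_values': torch.zeros(2, 3, 4, 4), 'image_flags': torch.ones(2, dtype=torch.long)},
        {'input_ids': torch.LongTensor([4, 5]), 'labels': torch.LongTensor([4, 5]),
         'pixel_values': torch.zeros(1, 3, 4, 4), 'image_flags': torch.ones(1, dtype=torch.long)},
    ]
    batch = concat_pad_data_collator(features, pad_id=0)
    assert batch['pixel_values'].shape == (3, 3, 4, 4)  # 2 + 1 tiles
    assert batch['input_ids'].shape == (2, 3)
    return batch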
from typing import List, Optional
import torch
import transformers
from torch.utils.data import Dataset, Sampler
from transformers.tokenization_utils_base import BatchEncoding
from transformers.trainer import (LengthGroupedSampler, RandomSampler,
has_length)
from transformers.trainer_pt_utils import logger
# copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L38
def split_to_even_chunks(indices, lengths, num_chunks):
"""
Split a list of indices into `num_chunks` chunks of roughly equal total length.
"""
if len(indices) % num_chunks != 0:
return [indices[i::num_chunks] for i in range(num_chunks)]
num_indices_per_chunk = len(indices) // num_chunks
chunks = [[] for _ in range(num_chunks)]
chunks_lengths = [0 for _ in range(num_chunks)]
for index in indices:
shortest_chunk = chunks_lengths.index(min(chunks_lengths))
chunks[shortest_chunk].append(index)
chunks_lengths[shortest_chunk] += lengths[index]
if len(chunks[shortest_chunk]) == num_indices_per_chunk:
chunks_lengths[shortest_chunk] = float('inf')
return chunks
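# Minimal sketch (illustrative only): six indices with uneven lengths are
# split into two chunks whose total lengths stay close, instead of simply
# halving the list in order.
def _example_split_to_even_chunks():
    lengths = [10, 1, 1, 1, 1, 10]
    chunks = split_to_even_chunks(list(range(6)), lengths, num_chunks=2)
    assert len(chunks) == 2
    assert sorted(i for chunk in chunks for i in chunk) == list(range(6))
    return chunks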
# copy from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L88
def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
indices = torch.randperm(len(lengths), generator=generator)
megabatch_size = world_size * batch_size
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
return [i for megabatch in megabatches for batch in megabatch for i in batch]
# modified from https://github.com/haotian-liu/LLaVA/blob/main/llava/train/llava_trainer.py#L99
class LengthGroupedSampler(Sampler):
r"""
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
keeping a bit of randomness.
"""
def __init__(
self,
batch_size: int,
world_size: int,
dataset: Optional[Dataset] = None,
lengths: Optional[List[int]] = None,
model_input_name: Optional[str] = None,
generator=None,
):
if dataset is None and lengths is None:
raise ValueError('One of dataset and lengths must be provided.')
self.batch_size = batch_size
if lengths is None:
model_input_name = model_input_name if model_input_name is not None else 'input_ids'
if (
not (isinstance(dataset[0], dict) or isinstance(dataset[0], BatchEncoding))
or model_input_name not in dataset[0]
):
raise ValueError(
'Can only automatically infer lengths for datasets whose items are dictionaries with an '
f"'{model_input_name}' key."
)
lengths = [len(feature[model_input_name]) for feature in dataset]
elif isinstance(lengths, torch.Tensor):
logger.info(
'If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]...'
)
lengths = lengths.tolist()
self.world_size = world_size
self.lengths = lengths
self.generator = generator
def __len__(self):
return len(self.lengths)
def __iter__(self):
indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
return iter(indices)
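# Usage sketch (illustrative only; the lengths are made up): with a per-device
# batch size of 2 and 2 processes, each mega-batch of world_size * batch_size
# = 4 indices is sorted by length and split into per-rank chunks of comparable
# total length, yielding a permutation of all dataset indices.
def _example_length_grouped_sampler():
    lengths = [3, 15, 7, 1, 12, 5, 9, 2]
    sampler = LengthGroupedSampler(batch_size=2, world_size=2, lengths=lengths)
    order = list(iter(sampler))
    assert sorted(order) == list(range(len(lengths)))
    return order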
# patch trainer
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.train_dataset is None or not has_length(self.train_dataset):
return None
# Build the sampler.
if self.args.group_by_length:
lengths = []
for dataset in self.train_dataset.datasets:
lengths = lengths + dataset.length
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
return LengthGroupedSampler(
self.args.train_batch_size,
world_size=self.args.world_size * self.args.gradient_accumulation_steps,
# self.args.train_batch_size * self.args.gradient_accumulation_steps,
dataset=self.train_dataset,
lengths=lengths,
model_input_name=model_input_name,
)
else:
return RandomSampler(self.train_dataset)
def replace_train_sampler():
transformers.Trainer._get_train_sampler = _get_train_sampler
print('Replace train sampler!!')
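# Usage sketch (illustrative only): after the patch, any Trainer created with
# group_by_length=True builds batches with the LengthGroupedSampler above,
# reading per-sample lengths from train_dataset.datasets[*].length, an
# attribute the patched sampler assumes the concatenated datasets expose.
def _example_replace_train_sampler():
    replace_train_sampler()
    assert transformers.Trainer._get_train_sampler is _get_train_sampler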
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
QUAD_START_TOKEN = '<quad>'
QUAD_END_TOKEN = '</quad>'
REF_START_TOKEN = '<ref>'
REF_END_TOKEN = '</ref>'
BOX_START_TOKEN = '<box>'
BOX_END_TOKEN = '</box>'
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
CLIP_MEAN = (0.4814546, 0.4578275, 0.40821073)
CLIP_STD = (0.2686295, 0.2613025, 0.2757711)
SIGLIP_MEAN = (0.5, 0.5, 0.5)
SIGLIP_STD = (0.5, 0.5, 0.5)
import io
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
import os
import random
from typing import Dict
import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as T
import transformers
from decord import VideoReader
from internvl.conversation import get_conv_template
from PIL import Image
from torch.utils.data import ConcatDataset, WeightedRandomSampler
from torchvision.transforms.functional import InterpolationMode
from .constants import (CLIP_MEAN, CLIP_STD, IMAGENET_MEAN, IMAGENET_STD,
IMG_CONTEXT_TOKEN, IMG_END_TOKEN, IMG_START_TOKEN,
SIGLIP_MEAN, SIGLIP_STD)
try:
from petrel_client.client import Client
from petrel_client.common.config import Config
except ImportError as E:
print('petrel_client is not installed. If you read data locally instead of from ceph, ignore it.')
import sys
def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None, input_fps=1, max_num_frames=-1):
if sample in ['rand', 'middle']: # uniform sampling
acc_samples = min(num_frames, vlen)
# split the video into `acc_samples` intervals, and sample from each interval.
intervals = np.linspace(start=0, stop=vlen, num=acc_samples + 1).astype(int)
ranges = []
for idx, interv in enumerate(intervals[:-1]):
ranges.append((interv, intervals[idx + 1] - 1))
if sample == 'rand':
try:
frame_indices = [random.choice(range(x[0], x[1])) for x in ranges]
except IndexError: # random.choice fails when an interval is empty (vlen == acc_samples)
frame_indices = np.random.permutation(vlen)[:acc_samples]
frame_indices.sort()
frame_indices = list(frame_indices)
elif fix_start is not None:
frame_indices = [x[0] + fix_start for x in ranges]
elif sample == 'middle':
frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
else:
raise NotImplementedError
if len(frame_indices) < num_frames: # padded with last frame
padded_frame_indices = [frame_indices[-1]] * num_frames
padded_frame_indices[:len(frame_indices)] = frame_indices
frame_indices = padded_frame_indices
elif 'fps' in sample: # fps0.5, sequentially sample frames at 0.5 fps
output_fps = float(sample[3:])
duration = float(vlen) / input_fps
delta = 1 / output_fps # gap between frames; this is also the clip length each frame represents
frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
frame_indices = np.around(frame_seconds * input_fps).astype(int)
frame_indices = [e for e in frame_indices if e < vlen]
if max_num_frames > 0 and len(frame_indices) > max_num_frames:
frame_indices = frame_indices[:max_num_frames]
# frame_indices = np.linspace(0 + delta / 2, duration + delta / 2, endpoint=False, num=max_num_frames)
else:
raise ValueError
return frame_indices
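# Minimal sketch (illustrative only): 'middle' sampling takes the centre of 4
# equal intervals from a 40-frame clip, while 'fps1' walks a 10 fps clip at
# one frame per second.
def _example_get_frame_indices():
    middle = get_frame_indices(num_frames=4, vlen=40, sample='middle')
    assert list(middle) == [4, 14, 24, 34]
    fps_based = get_frame_indices(num_frames=4, vlen=40, sample='fps1', input_fps=10)
    assert list(fps_based) == [5, 15, 25, 35]
    return middle, fps_based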
def read_frames_gif(
video_path, num_frames, sample='rand', fix_start=None,
client=None, min_num_frames=4
):
if 's3://' in video_path:
video_bytes = client.get(video_path)
gif = imageio.get_reader(io.BytesIO(video_bytes))
else:
gif = imageio.get_reader(video_path)
vlen = len(gif)
t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
frame_indices = get_frame_indices(
t_num_frames, vlen, sample=sample, fix_start=fix_start
)
frames = []
for index, frame in enumerate(gif):
if index in frame_indices:
frame = cv2.cvtColor(frame, cv2.COLOR_RGBA2RGB).astype(np.uint8)
frame = Image.fromarray(frame)
frames.append(frame)
return frames
def read_frames_decord(
video_path, num_frames, sample='rand', fix_start=None,
client=None, clip=None, min_num_frames=4
):
if 's3://' in video_path:
video_bytes = client.get(video_path)
video_reader = VideoReader(io.BytesIO(video_bytes), num_threads=1)
else:
video_reader = VideoReader(video_path, num_threads=1)
vlen = len(video_reader)
fps = video_reader.get_avg_fps()
duration = vlen / float(fps)
if clip:
start, end = clip
duration = end - start
vlen = int(duration * fps)
start_index = int(start * fps)
# t_num_frames = min(max(int(duration * sample_fps), min_num_frames), num_frames)
t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
frame_indices = get_frame_indices(
t_num_frames, vlen, sample=sample, fix_start=fix_start,
input_fps=fps
)
if clip:
frame_indices = [f + start_index for f in frame_indices]
frames = video_reader.get_batch(frame_indices).asnumpy() # (T, H, W, C), np.uint8
frames = [Image.fromarray(frames[i]) for i in range(frames.shape[0])]
return frames
def read_frames_folder(
video_path, num_frames, sample='rand', fix_start=None,
client=None, clip=None, min_num_frames=4
):
if 's3://' in video_path:
image_list = client.list(video_path)
frames = []
for image in image_list:
fp = os.path.join(video_path, image)
frame = Image.open(io.BytesIO(client.get(fp)))
frames.append(frame)
else:
image_list = sorted(list(os.listdir(video_path)))
frames = []
for image in image_list:
fp = os.path.join(video_path, image)
frame = Image.open(fp).convert('RGB')
frames.append(frame)
vlen = len(frames)
t_num_frames = np.random.randint(min_num_frames, num_frames + 1)
if vlen > t_num_frames:
frame_indices = get_frame_indices(
t_num_frames, vlen, sample=sample, fix_start=fix_start
)
frames = [frames[i] for i in frame_indices]
return frames
class WeightedConcatDataset(ConcatDataset):
def __init__(self, datasets, weights):
super().__init__(datasets)
self.weights = torch.DoubleTensor(weights)
self.total_size = sum(len(d) for d in datasets)
self.sampler = WeightedRandomSampler(weights=self.weights, num_samples=self.total_size, replacement=True)
def __iter__(self):
return iter(self.sampler)
def __len__(self):
return self.total_size
def pil_loader(img_str):
buff = io.BytesIO(img_str)
img = Image.open(buff)
return img.convert('RGB')
class TCSLoader(object):
def __init__(self, conf_path, sc_config_key='sensecore'):
print(f'[TCSLoader] config_path: {conf_path}')
print('--> before Client(conf_path)')
self.client = Client(conf_path)
self.sc_config_key = sc_config_key
print('--> after Client(conf_path)')
def __call__(self, fn, image_type='image', max_num_frames=-1, min_num_frames=4, sample='rand', clip=None):
if image_type == 'image':
img_value_str = self.client.get(fn)
img = pil_loader(img_value_str)
return img
elif image_type == 'video':
if fn.endswith('/'):
frames = read_frames_folder(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
client=self.client, sample=sample)
elif fn.endswith('.gif'):
frames = read_frames_gif(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
client=self.client, sample=sample)
else:
frames = read_frames_decord(fn, num_frames=max_num_frames, min_num_frames=min_num_frames,
client=self.client, sample=sample, clip=clip)
return frames
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
result = Image.new(pil_img.mode, (width, width), background_color)
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
result = Image.new(pil_img.mode, (height, height), background_color)
result.paste(pil_img, ((height - width) // 2, 0))
return result
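# Minimal sketch (illustrative only): a 40x20 image is pasted onto a 40x40
# canvas filled with the background colour, vertically centred; square inputs
# are returned unchanged.
def _example_expand2square():
    img = Image.new('RGB', (40, 20), (255, 0, 0))
    squared = expand2square(img, (0, 0, 0))
    assert squared.size == (40, 40)
    assert expand2square(squared, (0, 0, 0)) is squared
    return squared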
def simulate_jpeg_degradation(quality):
def jpeg_degrade(img):
with io.BytesIO() as output:
img.convert('RGB').save(output, format='JPEG', quality=quality)
output.seek(0) # Move the reading cursor to the start of the stream
img_jpeg = Image.open(output).copy() # Use .copy() to make sure the image is loaded in memory
return img_jpeg
return jpeg_degrade
# Define the JPEG compression quality range and pre-create all JPEG compression functions
qualities = list(range(75, 101))
jpeg_degrade_functions = {quality: simulate_jpeg_degradation(quality) for quality in qualities}
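# Minimal sketch (illustrative only): the augmentation re-encodes a PIL image
# as JPEG at a quality drawn from 75-100; the pixel content changes slightly
# but size and mode are preserved.
def _example_jpeg_degrade():
    img = Image.new('RGB', (32, 32), (120, 200, 60))
    degraded = jpeg_degrade_functions[80](img)
    assert degraded.size == img.size and degraded.mode == 'RGB'
    return degraded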
def build_transform(is_train, input_size, pad2square=False, normalize_type='imagenet'):
if normalize_type == 'imagenet':
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
elif normalize_type == 'clip':
MEAN, STD = CLIP_MEAN, CLIP_STD
elif normalize_type == 'siglip':
MEAN, STD = SIGLIP_MEAN, SIGLIP_STD
else:
raise NotImplementedError
if is_train: # use data augmentation
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.RandomChoice([T.Lambda(jpeg_degrade_functions[quality]) for quality in qualities]),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
else:
if pad2square is False: # now we use this transform function by default
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
else:
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in MEAN))),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
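# Usage sketch (illustrative only; the 448px input size mirrors the value used
# elsewhere in this repo): the eval-time transform resizes any image to a
# square, normalized tensor.
def _example_build_transform():
    transform = build_transform(is_train=False, input_size=448, normalize_type='imagenet')
    img = Image.new('RGB', (640, 480), (128, 128, 128))
    pixel_values = transform(img)
    assert pixel_values.shape == (3, 448, 448)
    return pixel_values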
def preprocess(
template_name,
sources,
tokenizer: transformers.PreTrainedTokenizer,
num_image_token_list: list,
text_only: bool = False,
group_by_length: bool = False,
use_packed_ds: bool = False,
ds_name: str = None,
num_image: int = 1
) -> Dict:
conv = get_conv_template(template_name)
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]['from']] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence['from']]
assert role == conv.roles[j % 2], f'{i}'
conv.append_message(role, sentence['value'])
conversations.append(conv.get_prompt())
if not text_only:
new_conversations = []
for conversation in conversations:
for i in range(num_image):
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
conversation = conversation.replace('<image>', image_tokens, 1)
new_conversations.append(conversation)
conversations = new_conversations
# Tokenize conversations
input_ids = tokenizer(
conversations,
return_tensors='pt',
padding=False if group_by_length or use_packed_ds else 'max_length',
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
# assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
# Mask targets. Only compute loss on the assistant outputs.
sep = conv.sep + conv.roles[1] + ': '
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
turns = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_TOKEN_ID
for i, turn in enumerate(turns):
if turn == '':
break
turn_len = len(tokenizer(turn).input_ids)
parts = turn.split(sep)
if len(parts) != 2:
break
parts[0] += sep
# "-2" is hardcoded for the Llama tokenizer to make the offset correct.
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
if i != 0 and not tokenizer.legacy:
# The legacy and non-legacy modes handle special tokens differently
instruction_len -= 1
# Ignore the user instructions
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
cur_len += turn_len
if i != 0 and not tokenizer.legacy:
# The legacy and non-legacy modes handle special tokens differently
cur_len -= 1
target[cur_len:] = IGNORE_TOKEN_ID
if False: # Inspect and check the correctness of masking
z = target.clone()
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
logger.info(tokenizer.decode(z))
exit()
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
print(
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
)
sys.stdout.flush()
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
def preprocess_mpt(
template_name,
sources,
tokenizer: transformers.PreTrainedTokenizer,
num_image_token_list: list,
text_only: bool = False,
group_by_length: bool = False,
use_packed_ds: bool = False,
ds_name: str = None,
num_image: int = 1
) -> Dict:
conv = get_conv_template(template_name)
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]['from']] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence['from']]
assert role == conv.roles[j % 2], f'{i}'
conv.append_message(role, sentence['value'])
conversations.append(conv.get_prompt())
if not text_only:
new_conversations = []
for conversation in conversations:
for i in range(num_image):
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
conversation = conversation.replace('<image>', image_tokens, 1)
new_conversations.append(conversation)
conversations = new_conversations
# Tokenize conversations
input_ids = tokenizer(
conversations,
return_tensors='pt',
padding=False if group_by_length or use_packed_ds else 'max_length',
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
# Mask targets. Only compute loss on the assistant outputs.
sep = conv.sep + conv.roles[1] # <|im_end|><|im_start|>assistant\n
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
turns = conversation.split(conv.sep)
re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
for conv_idx in range(3, len(turns), 2):
re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
cur_len = 0
target[:cur_len] = IGNORE_TOKEN_ID
for i, turn in enumerate(re_turns):
if turn == '':
break
turn_len = len(tokenizer(turn).input_ids) + 1
parts = turn.split(sep)
if len(parts) != 2:
break
parts[0] += sep
instruction_len = len(tokenizer(parts[0]).input_ids)
# Ignore the user instructions
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
# print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
# print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
# print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
cur_len += turn_len
target[cur_len:] = IGNORE_TOKEN_ID
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
print(
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
)
sys.stdout.flush()
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
def preprocess_phi3(
template_name,
sources,
tokenizer: transformers.PreTrainedTokenizer,
num_image_token_list: list,
text_only: bool = False,
group_by_length: bool = False,
use_packed_ds: bool = False,
ds_name: str = None,
num_image: int = 1
) -> Dict:
conv = get_conv_template(template_name)
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]['from']] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence['from']]
assert role == conv.roles[j % 2], f'{i}'
conv.append_message(role, sentence['value'])
conversations.append(conv.get_prompt())
if not text_only:
new_conversations = []
for conversation in conversations:
for i in range(num_image):
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
conversation = conversation.replace('<image>', image_tokens, 1)
new_conversations.append(conversation)
conversations = new_conversations
# Tokenize conversations
tokenizer.padding_side = 'right'
input_ids = tokenizer(
conversations,
return_tensors='pt',
padding=False if group_by_length or use_packed_ds else 'max_length',
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
# Mask targets. Only compute loss on the assistant outputs.
sep = conv.sep + conv.roles[1] # <|end|>\n<|assistant|>
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(int(tokenizer.pad_token_id)).sum())
turns = conversation.split(conv.sep)
re_turns = [conv.sep.join(turns[:3])] # system + user + gpt
for conv_idx in range(3, len(turns), 2):
re_turns.append(conv.sep.join(turns[conv_idx:conv_idx + 2])) # user + gpt
cur_len = 1
target[:cur_len] = IGNORE_TOKEN_ID
endoftext_id = tokenizer.convert_tokens_to_ids('<|endoftext|>')
target[target == endoftext_id] = IGNORE_TOKEN_ID
for i, turn in enumerate(re_turns):
if turn == '':
break
if i == 0:
turn_len = len(tokenizer(turn).input_ids)
else:
turn_len = len(tokenizer(turn).input_ids) - 1
parts = turn.split(sep)
if len(parts) != 2:
break
parts[0] += sep
if i == 0:
instruction_len = len(tokenizer(parts[0]).input_ids) - 1
else:
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
# Ignore the user instructions
target[cur_len: cur_len + instruction_len] = IGNORE_TOKEN_ID
# print(f'[question {i}]', tokenizer.decode(input_ids[:, cur_len: cur_len + instruction_len][0]))
# print(f'[answer {i}]', tokenizer.decode(input_ids[:, cur_len + instruction_len: cur_len + turn_len][0]))
# print(f'[label {i}]', target[cur_len + instruction_len: cur_len + turn_len])
cur_len += turn_len
target[cur_len:] = IGNORE_TOKEN_ID
if False: # Inspect and check the correctness of masking
z = target.clone()
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
print(repr(tokenizer.decode(z)))
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
print(
f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}.'
f' #turn = {len(turns) - 1}. (ignored). This dataset is {ds_name}.'
)
sys.stdout.flush()
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
def preprocess_internlm(
template_name,
sources,
tokenizer: transformers.PreTrainedTokenizer,
num_image_token_list: list,
text_only: bool = False,
group_by_length: bool = False,
use_packed_ds: bool = False,
ds_name: str = None,
num_image: int = 1
) -> Dict:
conv = get_conv_template(template_name)
roles = {'human': conv.roles[0], 'gpt': conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if roles[source[0]['from']] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
for j, sentence in enumerate(source):
role = roles[sentence['from']]
assert role == conv.roles[j % 2], f'{i}'
sentence['value'] = sentence['value'].strip()
conv.append_message(role, sentence['value'])
conversations.append(conv.get_prompt())
if not text_only:
new_conversations = []
for conversation in conversations:
for i in range(num_image):
image_tokens = f'{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * num_image_token_list[i]}{IMG_END_TOKEN}'
conversation = conversation.replace('<image>', image_tokens, 1)
new_conversations.append(conversation)
conversations = new_conversations
# Tokenize conversations
input_ids = tokenizer(
conversations,
return_tensors='pt',
padding=False if group_by_length or use_packed_ds else 'max_length',
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum()) # in InternLM, pad_token_id == eos_token_id
cur_len = 1
target[:cur_len] = IGNORE_TOKEN_ID # <s>
parts = conversation.split(conv.roles[1]) # [UNUSED_TOKEN_146]assistant\n
info = parts[0] + conv.roles[1]
temp_len = len(tokenizer(info).input_ids) - 1 # drop the <s> added by the tokenizer
target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
cur_len = cur_len + temp_len
for index in range(1, len(parts) - 1):
info = parts[index]
part1, part2 = info.split(conv.roles[0])
temp_len = len(tokenizer(part1).input_ids) - 1
cur_len = cur_len + temp_len
part = conv.roles[0] + part2 + conv.roles[1]
temp_len = len(tokenizer(part).input_ids) - 1
target[cur_len: cur_len + temp_len] = IGNORE_TOKEN_ID
cur_len = cur_len + temp_len
last_info = parts[-1]
temp_len = len(tokenizer(last_info).input_ids) - 1
cur_len = cur_len + temp_len
target[cur_len:] = IGNORE_TOKEN_ID
if False: # Inspect and check the correctness of masking
z = target.clone()
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
print(repr(tokenizer.decode(z)))
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
print(f'WARNING: tokenization mismatch: {cur_len} vs. {total_len}. This dataset is {ds_name}.')
sys.stdout.flush()
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
# print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
# calculate the existing image aspect ratio
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
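# A rough worked example of dynamic_preprocess (values follow from the code above):
# an 800x600 image with image_size=448 and max_num=6 has aspect ratio ~1.33, whose closest
# admissible grid is (3, 2); the image is resized to 1344x896 and cropped into 6 tiles of
# 448x448, plus one 448x448 thumbnail of the full image when use_thumbnail=True.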
import gc
import json
import logging
import math
import os
import random
import sys
import traceback
import warnings
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Dict, Optional
import numpy as np
import torch
import torch.distributed as dist
import transformers
from internvl.dist_utils import init_dist
from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from internvl.model.internvl_chat import (InternVisionConfig,
InternVisionModel,
InternVLChatConfig,
InternVLChatModel)
from internvl.patch import (concat_pad_data_collator,
replace_llama_rmsnorm_with_fused_rmsnorm,
replace_train_sampler)
from internvl.train.constants import (BOX_END_TOKEN, BOX_START_TOKEN,
IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
IMG_START_TOKEN, QUAD_END_TOKEN,
QUAD_START_TOKEN, REF_END_TOKEN,
REF_START_TOKEN)
from internvl.train.dataset import (ConcatDataset, TCSLoader,
WeightedConcatDataset, build_transform,
dynamic_preprocess, preprocess,
preprocess_internlm, preprocess_mpt,
preprocess_phi3)
from internvl.train.trainer_monkey_patch import replace_create_optimizer
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError
from torch.utils.data import Dataset
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
HfArgumentParser, Trainer, TrainingArguments,
set_seed)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.logging import (enable_default_handler,
enable_explicit_format, set_verbosity)
# Apply necessary patches for the transformers library
replace_llama_rmsnorm_with_fused_rmsnorm()
replace_train_sampler()
# Try to import petrel_client for image loading, fallback to PIL if unavailable
try:
from petrel_client.client import Client
from petrel_client.common.config import Config
has_tcs_loader = True
except ImportError:
print('petrel_client is not installed. Using PIL to load images.')
has_tcs_loader = False
# Set constants for image processing and logging
IGNORE_INDEX = -100
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
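# Raise PIL's PNG text-chunk limit (to ~1 GiB here) so images with very large metadata
# chunks do not fail with "Decompressed Data Too Large" while loading.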
warnings.filterwarnings('ignore')
logger = logging.getLogger(__name__)
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
@dataclass
class ModelArguments:
"""
Arguments for specifying model, tokenizer, and configurations.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
)
vision_path: Optional[str] = field(
default=None,
metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
)
llm_path: Optional[str] = field(
default=None,
metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
)
mlp_path: Optional[str] = field(
default=None,
metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
)
freeze_llm: bool = field(
default=False,
metadata={'help': 'Set to True to freeze the LLM decoder.'},
)
freeze_backbone: bool = field(
default=False,
metadata={'help': 'Set to True to freeze the vision backbone of the model.'},
)
freeze_mlp: bool = field(
default=False,
metadata={'help': 'Set to True to freeze the MLP layers of the model.'},
)
unfreeze_vit_layers: int = field(
default=0,
metadata={'help': 'Specify the number of ViT layers to unfreeze. Default is 0.'},
)
vision_select_layer: int = field(
default=-1,
metadata={'help': 'Specify the layer of ViT feature map to use. Default is last layer.'},
)
use_backbone_lora: int = field(
default=0,
metadata={'help': 'Set the LoRA adapter rank for the backbone model. Default is 0.'}
)
use_llm_lora: int = field(
default=0,
metadata={'help': 'Set the LoRA adapter rank for the LLM. Default is 0.'}
)
unfreeze_lm_head: bool = field(
default=False,
metadata={'help': "Set to True to unfreeze the language model's head."},
)
use_custom_trainer: bool = field(
default=False,
metadata={'help': 'Set to True to enable the use of a custom trainer.'},
)
grad_checkpoint: Optional[bool] = field(
default=False,
metadata={'help': 'Set to True to use gradient checkpointing.'},
)
drop_path_rate: float = field(
default=0.0,
metadata={'help': 'Set the drop path rate for the ViT model. Default is 0.'},
)
ps_version: str = field(
default='v2',
        metadata={'help': 'Specify the version of the pixel shuffle implementation. Default is `v2`, '
                          'which fixes the transposed-image bug present in `v1`.'}
)
@dataclass
class DataTrainingArguments:
"""
Arguments for specifying data input for training and evaluation.
"""
max_seq_length: Optional[int] = field(
default=2048,
metadata={
'help': (
'The maximum total input sequence length after tokenization. Sequences longer '
'than this will be truncated, sequences shorter will be padded.'
)
},
)
force_image_size: Optional[int] = field(
default=448,
        metadata={'help': 'Set the desired size for the image. Default is 448.'},
)
down_sample_ratio: Optional[float] = field(
default=0.5,
        metadata={'help': 'Set the desired down-sampling ratio for the image. Default is 0.5.'},
)
pad2square: Optional[bool] = field(
default=False,
metadata={'help': 'Pad the image to a square shape if set to True.'},
)
conv_style: Optional[str] = field(
default='internlm2-chat', metadata={'help': 'Prompt style for a conversation.'}
)
meta_path: Optional[str] = field(
default=None,
metadata={'help': 'The path of the meta file of datasets.'},
)
use_data_resampling: Optional[bool] = field(
default=False,
metadata={'help': 'Set to True to use data resampling.'},
)
dynamic_image_size: Optional[bool] = field(
default=False,
metadata={'help': 'Set to True to use dynamic image size.'},
)
use_thumbnail: Optional[bool] = field(
default=False,
metadata={'help': 'Set to True to add a thumbnail image.'},
)
min_dynamic_patch: Optional[int] = field(
default=1,
metadata={'help': 'The minimum number of dynamic patches. Default is 1.'},
)
max_dynamic_patch: Optional[int] = field(
default=12,
        metadata={'help': 'The maximum number of dynamic patches. Default is 12.'},
)
normalize_type: Optional[str] = field(
default='imagenet',
metadata={'help': 'The normalize type for the image. Default is imagenet.'},
)
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
self,
template_name,
meta,
tokenizer,
tcs_loader,
ds_name,
num_image_token,
image_size=224,
is_train=True,
pad2square=False,
group_by_length=False,
dynamic_image_size=False,
use_thumbnail=False,
min_dynamic_patch=1,
max_dynamic_patch=6,
min_num_frame=4, # for video data
max_num_frame=12, # for video data
sampling_method='rand', # for video data
repeat_time=1,
normalize_type='imagenet',
random_seed=0,
):
super(LazySupervisedDataset, self).__init__()
self.ds_name = ds_name
self.tokenizer = tokenizer
self.template_name = template_name
self.num_image_token = num_image_token
logger.info(f'[Dataset] num_image_token: {num_image_token}')
logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}')
logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')
self.image_size = image_size
self.is_train = is_train
self.pad2square = pad2square
self.max_num_frame = max_num_frame
self.min_num_frame = min_num_frame
self.sampling_method = sampling_method
logger.info('Formatting inputs...Skip in lazy mode')
assert meta['annotation'].endswith('jsonl'), f'annotation must be jsonl, but got {meta["annotation"]}'
with open(meta['annotation'], 'r') as f:
self.raw_data = f.readlines()
if repeat_time < 1:
# If repeat_time is less than 1, select a portion of the data
self.raw_data = self.raw_data[:int(len(self.raw_data) * repeat_time)]
if repeat_time > 1:
assert isinstance(repeat_time, int)
# Repeat the list if repeat_time is greater than 1
self.raw_data = self.raw_data * repeat_time
self.rng = np.random.default_rng(seed=random_seed)
self.rng.shuffle(self.raw_data)
gc.collect()
self.root = meta['root']
self.cached_data_dict = {}
self.tcs_loader = tcs_loader
self.group_by_length = group_by_length
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
self.normalize_type = normalize_type
# If the precomputed length does not exist, roughly estimate the length of
# each sample to improve the efficiency of group_by_length.
if self.group_by_length:
self.conv2length = {} # Using a dictionary to speed up token length calculation
self.length = []
for data_item in self.raw_data:
data_item = json.loads(data_item)
if 'length' in data_item:
token_length = data_item['length'] # Use precomputed length if available
else:
# Compute token length using the tokenizer
conversations = '\n'.join([temp['value'] for temp in data_item['conversations']])
str_length = len(conversations)
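                    # Cache the tokenized length keyed by the raw string length so samples of the
                    # same size are not re-tokenized; the estimate also reserves room for the worst
                    # case of image tokens (max_dynamic_patch tiles plus the optional thumbnail).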
if str_length not in self.conv2length:
token_length = tokenizer(
conversations, return_tensors='pt', padding=False, truncation=False,
).input_ids.size(1)
self.conv2length[str_length] = token_length + num_image_token * (
max_dynamic_patch + use_thumbnail)
else:
token_length = self.conv2length[str_length]
self.length.append(token_length)
gc.collect()
def __len__(self):
return len(self.raw_data)
def get_preprocess_function(self):
# Select the appropriate preprocessing function based on the template name
if self.template_name == 'Hermes-2':
preprocess_function = preprocess_mpt
elif self.template_name == 'internlm2-chat':
preprocess_function = preprocess_internlm
elif self.template_name == 'phi3-chat':
preprocess_function = preprocess_phi3
else:
preprocess_function = preprocess
return preprocess_function
def load_image(self, image_path):
# Load the image using tcs_loader if available, otherwise use PIL
if self.tcs_loader is not None and 's3://' in image_path:
return self.tcs_loader(image_path)
return Image.open(image_path).convert('RGB')
def get_image_path(self, image_path):
if image_path.startswith('s3://'): # for ceph
image_path = self.root + image_path
else: # for local image
image_path = os.path.join(self.root, image_path)
return image_path
def get_transform(self):
# Build transformation function
transform = build_transform(is_train=self.is_train, input_size=self.image_size,
pad2square=self.pad2square, normalize_type=self.normalize_type)
return transform
def multi_modal_get_item(self, data_item):
# Build transformation function
transform = self.get_transform()
# Ensure the first conversation contains an image placeholder
if '<image>' not in data_item['conversations'][0]['value']:
data_item['conversations'][0]['value'] = '<image>\n' + data_item['conversations'][0]['value']
# Merge the image path
image_path = self.get_image_path(data_item['image'])
# Load the image using tcs_loader if available, otherwise use PIL
image = self.load_image(image_path)
if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
image_size=self.image_size, use_thumbnail=self.use_thumbnail)
else: # Otherwise, use the original image as a single patch
images = [image]
# Apply the transformation to each image and stack the results into a tensor
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
# Ensure that there is only one patch if dynamic image size is not enabled
num_patches = pixel_values.size(0)
if not self.dynamic_image_size:
assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
# Select the appropriate preprocessing function based on the template name
preprocess_function = self.get_preprocess_function()
# Preprocess the conversations and generate the return dictionary
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
self.tokenizer, [self.num_image_token * num_patches],
group_by_length=self.group_by_length, ds_name=self.ds_name)
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
return ret
def multi_modal_multi_image_get_item(self, data_item):
# Build transformation function
transform = self.get_transform()
images, num_tiles = [], []
num_image = len(data_item['image'])
for image_path in data_item['image']:
# Merge the image path
image_path = self.get_image_path(image_path)
# Load the image using tcs_loader if available, otherwise use PIL
image = self.load_image(image_path)
if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
image = dynamic_preprocess(image, min_num=self.min_dynamic_patch,
max_num=self.max_dynamic_patch // num_image,
image_size=self.image_size, use_thumbnail=self.use_thumbnail)
images += image
num_tiles.append(len(image))
else: # Otherwise, use the original image as a single patch
images.append(image)
num_tiles.append(1)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
num_patches = pixel_values.size(0)
# Select the appropriate preprocessing function based on the template name
preprocess_function = self.get_preprocess_function()
# Preprocess the conversations and generate the return dictionary
num_image_tokens = [self.num_image_token * num_tile for num_tile in num_tiles]
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
self.tokenizer, num_image_tokens, group_by_length=self.group_by_length,
ds_name=self.ds_name, num_image=num_image)
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
return ret
def video_get_item(self, data_item):
# Build transformation function
transform = self.get_transform()
# Ensure the first conversation contains a video placeholder
if '<video>' not in data_item['conversations'][0]['value']:
data_item['conversations'][0]['value'] = '<video>\n' + data_item['conversations'][0]['value']
# Get the video file path
video_file = data_item['video']
video_path = os.path.join(self.root, video_file)
# Load the video frames using tcs_loader
# TODO: Load videos without using tcsloader.
image_list = self.tcs_loader(
video_path,
image_type='video',
max_num_frames=self.max_num_frame,
min_num_frames=self.min_num_frame,
sample=self.sampling_method,
clip=data_item.get('clip', None))
# Generate special tokens for each video frame
special_tokens = '\n'.join(['Frame{}: <image>'.format(i + 1) for i in range(len(image_list))])
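        # e.g. for a 4-frame clip this produces:
        # 'Frame1: <image>\nFrame2: <image>\nFrame3: <image>\nFrame4: <image>'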
data_item['conversations'][0]['value'] = data_item['conversations'][0]['value'].replace(
'<video>\n', special_tokens)
# Transform each frame image and stack them into a tensor
pixel_values = [transform(image) for image in image_list]
pixel_values = torch.stack(pixel_values)
num_patches = pixel_values.size(0)
# Select the appropriate preprocessing function based on the template name
preprocess_function = self.get_preprocess_function()
# Preprocess the conversations and generate the return dictionary
num_image_tokens = [self.num_image_token] * num_patches
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
self.tokenizer, num_image_tokens, group_by_length=self.group_by_length,
ds_name=self.ds_name, num_image=num_patches)
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
return ret
def pure_text_get_item(self, data_item):
# Build transformation function
transform = self.get_transform()
# Create a blank white image
image = Image.new('RGB', (224, 224), (255, 255, 255))
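        # Text-only samples still push a dummy white image through the pipeline so every batch
        # has the same structure; image_flags=0 below marks these visual features to be ignored.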
# Dynamically preprocess the image to generate patches
images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=1,
image_size=self.image_size, use_thumbnail=self.use_thumbnail)
# Apply the transformation to each image patch and stack them into a tensor
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
num_patches = pixel_values.size(0)
# Ensure there is only one patch
assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
# Select the appropriate preprocessing function based on the template name
preprocess_function = self.get_preprocess_function()
# Preprocess the conversations and generate the return dictionary
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
self.tokenizer, [self.num_image_token * num_patches], text_only=True,
group_by_length=self.group_by_length, ds_name=self.ds_name)
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
pixel_values=pixel_values,
image_flags=torch.tensor([0] * num_patches, dtype=torch.long)
)
return ret
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
i = i % len(self.raw_data)
while True:
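            # On any failure (e.g. a corrupted or missing image), log the error and retry with a
            # randomly chosen sample so a single bad record cannot stall training.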
try:
data_item = json.loads(self.raw_data[i])
if 'image' in data_item and len(data_item['image']) != 0:
if type(data_item['image']) == list:
ret = self.multi_modal_multi_image_get_item(data_item)
else:
ret = self.multi_modal_get_item(data_item)
elif 'video' in data_item and data_item['video'] is not None and data_item['video'] != '':
ret = self.video_get_item(data_item)
else:
ret = self.pure_text_get_item(data_item)
break
except Exception as e:
print(e, self.ds_name, flush=True)
if not isinstance(e, UnidentifiedImageError):
traceback.print_exc()
data_item = json.loads(self.raw_data[i])
if 'image' in data_item:
if type(data_item['image']) == list:
images = [self.root + item for item in data_item['image']]
print(f'Failed to load image: {images}, the dataset is: {self.ds_name}')
else:
if data_item['image'].startswith('s3://'):
data_path = self.root + data_item['image']
else:
data_path = os.path.join(self.root, data_item['image'])
print(f'Failed to load image: {data_path}, the dataset is: {self.ds_name}')
elif 'video' in data_item:
data_path = os.path.join(self.root, data_item['video'])
print(f'Failed to load video: {data_path}, the dataset is: {self.ds_name}')
i = random.randint(0, len(self.raw_data) - 1)
return ret
def build_datasets(
data_args,
tokenizer,
tcs_loader,
model,
group_by_length=False,
dynamic_image_size=False,
use_thumbnail=False,
min_dynamic_patch=1,
max_dynamic_patch=12,
normalize_type='imagenet',
):
datasets = []
lengths = []
    with open(data_args.meta_path) as f:
        ds_collections = json.load(f)
for ds_idx, ds_name in enumerate(ds_collections.keys()):
repeat_time = ds_collections[ds_name]['repeat_time']
if 'max_dynamic_patch' in ds_collections[ds_name]:
max_num = ds_collections[ds_name]['max_dynamic_patch']
logger.info(f'max_dynamic_patch is set to {max_num} according to the meta file')
else:
max_num = max_dynamic_patch
dataset = LazySupervisedDataset(
data_args.conv_style, ds_collections[ds_name],
tokenizer,
tcs_loader,
ds_name=ds_name,
num_image_token=model.num_image_token,
image_size=data_args.force_image_size,
is_train=ds_collections[ds_name]['data_augment'],
pad2square=data_args.pad2square,
group_by_length=group_by_length,
dynamic_image_size=dynamic_image_size,
use_thumbnail=use_thumbnail,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_num,
repeat_time=repeat_time,
normalize_type=normalize_type,
random_seed=ds_idx,
)
logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}')
datasets.append(dataset)
if data_args.use_data_resampling:
lengths.append(math.sqrt(len(dataset)))
else:
lengths.append(len(dataset))
if data_args.use_data_resampling:
total_length = sum(lengths)
weights = [l / total_length for l in lengths]
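        # Resampling weights are proportional to sqrt(len(dataset)) (collected above), which
        # up-samples small datasets relative to a purely size-proportional mixture.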
train_dataset = WeightedConcatDataset(datasets, weights)
else:
train_dataset = ConcatDataset(datasets)
return train_dataset
def main():
# Parse input arguments
# See all possible arguments in src/transformers/training_args.py
    # When using DeepSpeed ZeRO-3, init_dist must be called before HfArgumentParser
launcher = os.environ.get('LAUNCHER', 'slurm')
init_dist(launcher=launcher, backend='nccl')
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
# If we pass only one argument to the script, and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
# send_example_telemetry('InternV-Chat', model_args, data_args)
# Setup logging
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
handlers=[logging.StreamHandler(sys.stdout)],
)
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
set_verbosity(log_level)
enable_default_handler()
enable_explicit_format()
# Log on each process the small summary:
logger.warning(
        f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, '
        + f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
)
logger.info(f'Training/evaluation parameters {training_args}')
# Detecting last checkpoint and eventually continue from last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f'Output directory ({training_args.output_dir}) already exists and is not empty. '
'Use --overwrite_output_dir to overcome.'
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change '
'the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
)
# Set seed before initializing model.
set_seed(training_args.seed)
# Load pretrained model, tokenizer, and image processor
tokenizer_path = model_args.model_name_or_path or model_args.llm_path
logger.info(f'Loading Tokenizer: {tokenizer_path}')
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, add_eos_token=False, trust_remote_code=True, use_fast=False)
tokenizer.tokenizer_path = tokenizer_path
tokenizer.model_max_length = data_args.max_seq_length
token_list = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
QUAD_START_TOKEN, QUAD_END_TOKEN, REF_START_TOKEN,
REF_END_TOKEN, BOX_START_TOKEN, BOX_END_TOKEN]
num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
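    # add_tokens returns how many of these tokens were actually new to the vocabulary; it is 0
    # when the checkpoint's tokenizer already contains them, in which case the embedding resize
    # below is skipped.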
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
tcs_loader = TCSLoader('~/petreloss.conf') if has_tcs_loader else None
if model_args.model_name_or_path is not None:
logger.info('Loading InternVLChatModel...')
config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
config.vision_config.drop_path_rate = model_args.drop_path_rate
if config.llm_config.model_type == 'internlm2':
config.llm_config.attn_implementation = 'flash_attention_2' # for InternLM
logger.info('Using flash_attention_2 for InternLM')
else:
config.llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
logger.info('Using flash_attention_2 for LLaMA')
config.template = data_args.conv_style
config.select_layer = model_args.vision_select_layer
config.dynamic_image_size = data_args.dynamic_image_size
config.use_thumbnail = data_args.use_thumbnail
config.ps_version = model_args.ps_version
config.min_dynamic_patch = data_args.min_dynamic_patch
config.max_dynamic_patch = data_args.max_dynamic_patch
model = InternVLChatModel.from_pretrained(
model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config)
else:
logger.info('Loading ViT-6B...')
vision_config = InternVisionConfig.from_pretrained(model_args.vision_path)
vision_config.drop_path_rate = model_args.drop_path_rate
vision_model = InternVisionModel.from_pretrained(
model_args.vision_path, torch_dtype=torch.bfloat16, config=vision_config)
logger.info('Loading LLaMA...')
llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
if llm_config.model_type == 'internlm2':
model_type = InternLM2ForCausalLM
llm_config.attn_implementation = 'flash_attention_2' # for InternLM
logger.info('Using flash_attention_2 for InternLM')
else:
model_type = AutoModelForCausalLM
llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
logger.info('Using flash_attention_2 for LLaMA')
llm = model_type.from_pretrained(
model_args.llm_path, torch_dtype=torch.bfloat16,
config=llm_config, trust_remote_code=True)
logger.info('Building InternVLChatConfig...')
internvl_chat_config = InternVLChatConfig(
vision_config.to_dict(), llm_config.to_dict(), downsample_ratio=data_args.down_sample_ratio,
pad2square=data_args.pad2square, template=data_args.conv_style,
select_layer=model_args.vision_select_layer, dynamic_image_size=data_args.dynamic_image_size,
use_thumbnail=data_args.use_thumbnail, ps_version=model_args.ps_version,
min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch)
internvl_chat_config.force_image_size = data_args.force_image_size
logger.info('Building InternVLChatModel...')
model = InternVLChatModel(internvl_chat_config, vision_model, llm)
model.img_context_token_id = img_context_token_id
assert model.config.downsample_ratio == data_args.down_sample_ratio
if model_args.mlp_path is not None:
logger.info('Loading pretrained MLP projector...')
state_dict = torch.load(model_args.mlp_path, map_location='cpu')
message = model.mlp1.load_state_dict(state_dict)
logger.info(message)
logger.info('Finished')
patch_size = model.config.vision_config.patch_size
logger.info(f'model.config.force_image_size: {model.config.force_image_size}')
logger.info(f'data_args.force_image_size: {data_args.force_image_size}')
logger.info(f'model.config.vision_config.image_size: {model.config.vision_config.image_size}')
if model.config.vision_config.image_size != data_args.force_image_size:
logger.info(f'Resizing position embedding from '
f'{model.config.vision_config.image_size} '
f'to {data_args.force_image_size}...')
model.vision_model.resize_pos_embeddings(old_size=model.config.vision_config.image_size,
new_size=data_args.force_image_size,
patch_size=patch_size)
model.config.vision_config.image_size = data_args.force_image_size
model.config.force_image_size = data_args.force_image_size
model.num_image_token = int((data_args.force_image_size // patch_size) ** 2 * (data_args.down_sample_ratio ** 2))
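    # For example, with force_image_size=448, a ViT patch_size of 14 (assumed here) and
    # down_sample_ratio=0.5, this gives (448 // 14) ** 2 * 0.25 = 256 image tokens per tile.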
if num_new_tokens > 0:
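        # Rows for the newly added special tokens are initialized to the mean of the existing
        # output-embedding rows, a common heuristic for initializing new tokens.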
model.language_model.resize_token_embeddings(len(tokenizer))
output_embeddings = model.language_model.get_output_embeddings().weight.data
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings[-num_new_tokens:] = output_embeddings_avg
model.config.llm_config.vocab_size = len(tokenizer)
model.language_model.config.vocab_size = len(tokenizer)
model.language_model.config.use_cache = False
model.vision_model.gradient_checkpointing = True
model.vision_model.encoder.gradient_checkpointing = True
if model_args.grad_checkpoint:
model.language_model._set_gradient_checkpointing()
train_dataset = build_datasets(
data_args, tokenizer, tcs_loader, model, group_by_length=training_args.group_by_length,
dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail,
min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
normalize_type=data_args.normalize_type)
def _freeze_params(module):
for param in module.parameters():
param.requires_grad = False
if model_args.freeze_backbone:
# model.vision_model = model.vision_model.eval()
_freeze_params(model.vision_model)
if model_args.freeze_llm:
model.language_model = model.language_model.eval()
_freeze_params(model.language_model)
if model_args.unfreeze_lm_head:
model.language_model.lm_head.requires_grad = True
if model_args.use_backbone_lora:
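        # lora_alpha is set to twice the LoRA rank here (and for the LLM below),
        # a commonly used scaling heuristic.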
model.wrap_backbone_lora(r=model_args.use_backbone_lora, lora_alpha=2 * model_args.use_backbone_lora)
model.config.use_backbone_lora = model_args.use_backbone_lora
if model_args.use_llm_lora:
model.wrap_llm_lora(r=model_args.use_llm_lora, lora_alpha=2 * model_args.use_llm_lora)
model.config.use_llm_lora = model_args.use_llm_lora
if model_args.freeze_mlp:
_freeze_params(model.mlp1)
if model_args.unfreeze_vit_layers != 0:
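        # The value is used as a slice start: a negative value (e.g. -4) unfreezes the last
        # 4 encoder layers, while a positive value unfreezes every layer from that index onward.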
layers = model.vision_model.encoder.layers[model_args.unfreeze_vit_layers:]
for k, v in layers.named_parameters():
logger.info(f'Unfreezing ViT layer: {k}')
v.requires_grad = True
# print trainable parameters
if dist.get_rank() == 0:
for name, param in model.named_parameters():
if param.requires_grad:
logger.info(name)
# set seed for torch dataloaders
set_seed(training_args.seed)
# Initialize our Trainer
if model_args.use_custom_trainer:
replace_create_optimizer()
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=None,
tokenizer=tokenizer,
data_collator=concat_pad_data_collator
)
# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model() # Saves the tokenizer too for easy upload
metrics = train_result.metrics
try:
metrics['train_samples'] = len(train_dataset)
        except Exception:
metrics['train_samples'] = -1
trainer.log_metrics('train', metrics)
trainer.save_metrics('train', metrics)
trainer.save_state()
if __name__ == '__main__':
main()
import gc
import json
import logging
import math
import os
import random
import sys
import traceback
import warnings
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Dict, Optional
import numpy as np
import torch
import torch.distributed as dist
import transformers
from internvl.dist_utils import init_dist
from internvl.model.internlm2.modeling_internlm2 import InternLM2ForCausalLM
from internvl.model.internvl_chat import (InternVisionConfig,
InternVisionModel,
InternVLChatConfig,
InternVLChatModel)
from internvl.patch import (concat_pad_data_collator,
replace_llama_rmsnorm_with_fused_rmsnorm,
replace_train_sampler)
from internvl.train.constants import (BOX_END_TOKEN, BOX_START_TOKEN,
IMG_CONTEXT_TOKEN, IMG_END_TOKEN,
IMG_START_TOKEN, QUAD_END_TOKEN,
QUAD_START_TOKEN, REF_END_TOKEN,
REF_START_TOKEN)
from internvl.train.dataset import (ConcatDataset, TCSLoader,
WeightedConcatDataset, build_transform,
dynamic_preprocess, preprocess,
preprocess_internlm, preprocess_mpt,
preprocess_phi3)
from internvl.train.trainer_monkey_patch import replace_create_optimizer
from PIL import Image, ImageFile, PngImagePlugin, UnidentifiedImageError
from torch.utils.data import Dataset
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
HfArgumentParser, Trainer, TrainingArguments,
set_seed)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils.logging import (enable_default_handler,
enable_explicit_format, set_verbosity)
# Apply necessary patches for the transformers library
replace_llama_rmsnorm_with_fused_rmsnorm()
replace_train_sampler()
# Try to import petrel_client for image loading, fallback to PIL if unavailable
try:
from petrel_client.client import Client
from petrel_client.common.config import Config
has_tcs_loader = True
except ImportError:
print('petrel_client is not installed. Using PIL to load images.')
has_tcs_loader = False
# Set constants for image processing and logging
IGNORE_INDEX = -100
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
MaximumDecompressedSize = 1024
MegaByte = 2 ** 20
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
warnings.filterwarnings('ignore')
logger = logging.getLogger(__name__)
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
@dataclass
class ModelArguments:
"""
Arguments for specifying model, tokenizer, and configurations.
"""
model_name_or_path: Optional[str] = field(
default=None,
metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
)
vision_path: Optional[str] = field(
default=None,
metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
)
llm_path: Optional[str] = field(
default=None,
metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
)
mlp_path: Optional[str] = field(
default=None,
metadata={'help': 'Path to pretrained model or model identifier from huggingface.co/models'}
)
freeze_llm: bool = field(
default=False,
metadata={'help': 'Set to True to freeze the LLM decoder.'},
)
freeze_backbone: bool = field(
default=False,
metadata={'help': 'Set to True to freeze the vision backbone of the model.'},
)
freeze_mlp: bool = field(
default=False,
metadata={'help': 'Set to True to freeze the MLP layers of the model.'},
)
unfreeze_vit_layers: int = field(
default=0,
metadata={'help': 'Specify the number of ViT layers to unfreeze. Default is 0.'},
)
vision_select_layer: int = field(
default=-1,
metadata={'help': 'Specify the layer of ViT feature map to use. Default is last layer.'},
)
use_backbone_lora: int = field(
default=0,
metadata={'help': 'Set the LoRA adapter rank for the backbone model. Default is 0.'}
)
use_llm_lora: int = field(
default=0,
metadata={'help': 'Set the LoRA adapter rank for the LLM. Default is 0.'}
)
unfreeze_lm_head: bool = field(
default=False,
metadata={'help': "Set to True to unfreeze the language model's head."},
)
use_custom_trainer: bool = field(
default=False,
metadata={'help': 'Set to True to enable the use of a custom trainer.'},
)
grad_checkpoint: Optional[bool] = field(
default=False,
metadata={'help': 'Set to True to use gradient checkpointing.'},
)
drop_path_rate: float = field(
default=0.0,
metadata={'help': 'Set the drop path rate for the ViT model. Default is 0.'},
)
ps_version: str = field(
default='v2',
        metadata={'help': 'Specify the version of the pixel shuffle implementation. Default is `v2`, '
                          'which fixes the transposed-image bug present in `v1`.'}
)
@dataclass
class DataTrainingArguments:
"""
Arguments for specifying data input for training and evaluation.
"""
max_seq_length: Optional[int] = field(
default=2048,
metadata={
'help': (
'The maximum total input sequence length after tokenization. Sequences longer '
'than this will be truncated, sequences shorter will be padded.'
)
},
)
force_image_size: Optional[int] = field(
default=448,
        metadata={'help': 'Set the desired size for the image. Default is 448.'},
)
down_sample_ratio: Optional[float] = field(
default=0.5,
        metadata={'help': 'Set the desired down-sampling ratio for the image. Default is 0.5.'},
)
pad2square: Optional[bool] = field(
default=False,
metadata={'help': 'Pad the image to a square shape if set to True.'},
)
conv_style: Optional[str] = field(
default='internlm2-chat', metadata={'help': 'Prompt style for a conversation.'}
)
meta_path: Optional[str] = field(
default=None,
metadata={'help': 'The path of the meta file of datasets.'},
)
use_data_resampling: Optional[bool] = field(
default=False,
metadata={'help': 'Set to True to use data resampling.'},
)
dynamic_image_size: Optional[bool] = field(
default=False,
metadata={'help': 'Set to True to use dynamic image size.'},
)
use_thumbnail: Optional[bool] = field(
default=False,
metadata={'help': 'Set to True to add a thumbnail image.'},
)
min_dynamic_patch: Optional[int] = field(
default=1,
metadata={'help': 'The minimum number of dynamic patches. Default is 1.'},
)
max_dynamic_patch: Optional[int] = field(
default=12,
        metadata={'help': 'The maximum number of dynamic patches. Default is 12.'},
)
normalize_type: Optional[str] = field(
default='imagenet',
metadata={'help': 'The normalize type for the image. Default is imagenet.'},
)
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(
self,
template_name,
meta,
tokenizer,
tcs_loader,
ds_name,
num_image_token,
image_size=224,
is_train=True,
pad2square=False,
group_by_length=False,
dynamic_image_size=False,
use_thumbnail=False,
min_dynamic_patch=1,
max_dynamic_patch=6,
min_num_frame=4, # for video data
max_num_frame=12, # for video data
sampling_method='rand', # for video data
repeat_time=1,
normalize_type='imagenet',
random_seed=0,
):
super(LazySupervisedDataset, self).__init__()
self.ds_name = ds_name
self.tokenizer = tokenizer
self.template_name = template_name
self.num_image_token = num_image_token
logger.info(f'[Dataset] num_image_token: {num_image_token}')
logger.info(f'[Dataset] dynamic_image_size: {dynamic_image_size}')
logger.info(f'[Dataset] use_thumbnail: {use_thumbnail}')
logger.info(f'[Dataset] min_dynamic_patch: {min_dynamic_patch}, max_dynamic_patch: {max_dynamic_patch}')
self.image_size = image_size
self.is_train = is_train
self.pad2square = pad2square
self.max_num_frame = max_num_frame
self.min_num_frame = min_num_frame
self.sampling_method = sampling_method
logger.info('Formatting inputs...Skip in lazy mode')
assert meta['annotation'].endswith('jsonl'), f'annotation must be jsonl, but got {meta["annotation"]}'
total_ranks = torch.distributed.get_world_size()
current_rank = torch.distributed.get_rank()
"""
This section of the code is used to read hundreds of millions of data entries.
By using caching and splitting the data according to rank, it ensures fast reading
speed and prevents out-of-memory.
"""
# Create a cache directory path
basename = os.path.basename(meta['annotation']).replace('.jsonl', '')
data_dir = os.path.join(os.path.dirname(meta['annotation']), f'{basename}_temp')
os.makedirs(data_dir, exist_ok=True) # Create the cache directory if it does not exist
# Create a temporary path for the current rank
temp_path = os.path.join(data_dir, f'{basename}_{current_rank}_of_{total_ranks}.jsonl')
# Check if the temporary file for the current rank already exists
if os.path.exists(temp_path):
# If it exists, read the raw data from the file
with open(temp_path, 'r') as f:
self.raw_data = f.readlines()
else:
# If it does not exist, read the raw data from the original annotation file
with open(meta['annotation'], 'r') as f:
self.raw_data = f.readlines()
# Adjust the raw data based on the repeat_time parameter
if repeat_time < 1:
self.raw_data = self.raw_data[:int(len(self.raw_data) * repeat_time)]
else:
self.raw_data = self.raw_data * int(repeat_time)
# Calculate the total number of lines and distribute lines to each rank
total_lines = len(self.raw_data)
logger.info(f'total_ranks: {total_ranks}, current_rank: {current_rank}, total_lines: {total_lines}')
lines_per_rank = total_lines // total_ranks # Number of lines each rank should process
lines_per_rank = max(1, lines_per_rank)
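            # Example: with 8 ranks and 1,000,000 annotation lines, each rank keeps its own shard of
            # 125,000 consecutive lines (any remainder lines are dropped); __len__ later multiplies by
            # world_size so the reported dataset size still matches the full corpus.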
# Calculate the start and end line numbers for the current rank
start_line = lines_per_rank * current_rank # Starting line for the current rank
end_line = start_line + lines_per_rank # Ending line for the current rank
# Assign the appropriate lines to the current rank
self.raw_data = self.raw_data[start_line:end_line]
# Write the raw data for the current rank to the temporary file
with open(temp_path, 'w') as f:
f.writelines(self.raw_data)
self.rng = np.random.default_rng(seed=random_seed)
self.rng.shuffle(self.raw_data)
gc.collect()
self.root = meta['root']
self.cached_data_dict = {}
self.tcs_loader = tcs_loader
self.group_by_length = group_by_length
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch
self.normalize_type = normalize_type
# If the precomputed length does not exist, roughly estimate the length of
# each sample to improve the efficiency of group_by_length.
if self.group_by_length:
self.conv2length = {} # Using a dictionary to speed up token length calculation
self.length = []
for data_item in self.raw_data:
data_item = json.loads(data_item)
if 'length' in data_item:
token_length = data_item['length'] # Use precomputed length if available
else:
# Compute token length using the tokenizer
conversations = '\n'.join([temp['value'] for temp in data_item['conversations']])
str_length = len(conversations)
if str_length not in self.conv2length:
token_length = tokenizer(
conversations, return_tensors='pt', padding=False, truncation=False,
).input_ids.size(1)
self.conv2length[str_length] = token_length + num_image_token * (
max_dynamic_patch + use_thumbnail)
else:
token_length = self.conv2length[str_length]
self.length.append(token_length)
gc.collect()
def __len__(self):
return len(self.raw_data) * torch.distributed.get_world_size()
def get_preprocess_function(self):
# Select the appropriate preprocessing function based on the template name
if self.template_name == 'Hermes-2':
preprocess_function = preprocess_mpt
elif self.template_name == 'internlm2-chat':
preprocess_function = preprocess_internlm
elif self.template_name == 'phi3-chat':
preprocess_function = preprocess_phi3
else:
preprocess_function = preprocess
return preprocess_function
def load_image(self, image_path):
# Load the image using tcs_loader if available, otherwise use PIL
if self.tcs_loader is not None and 's3://' in image_path:
return self.tcs_loader(image_path)
return Image.open(image_path).convert('RGB')
def get_image_path(self, image_path):
if image_path.startswith('s3://'): # for ceph
image_path = self.root + image_path
else: # for local image
image_path = os.path.join(self.root, image_path)
return image_path
def get_transform(self):
# Build transformation function
transform = build_transform(is_train=self.is_train, input_size=self.image_size,
pad2square=self.pad2square, normalize_type=self.normalize_type)
return transform
def multi_modal_get_item(self, data_item):
# Build transformation function
transform = self.get_transform()
# Ensure the first conversation contains an image placeholder
if '<image>' not in data_item['conversations'][0]['value']:
data_item['conversations'][0]['value'] = '<image>\n' + data_item['conversations'][0]['value']
# Merge the image path
image_path = self.get_image_path(data_item['image'])
# Load the image using tcs_loader if available, otherwise use PIL
image = self.load_image(image_path)
if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=self.max_dynamic_patch,
image_size=self.image_size, use_thumbnail=self.use_thumbnail)
else: # Otherwise, use the original image as a single patch
images = [image]
# Apply the transformation to each image and stack the results into a tensor
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
# Ensure that there is only one patch if dynamic image size is not enabled
num_patches = pixel_values.size(0)
if not self.dynamic_image_size:
assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
# Select the appropriate preprocessing function based on the template name
preprocess_function = self.get_preprocess_function()
# Preprocess the conversations and generate the return dictionary
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
self.tokenizer, [self.num_image_token * num_patches],
group_by_length=self.group_by_length, ds_name=self.ds_name)
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
return ret
def multi_modal_multi_image_get_item(self, data_item):
# Build transformation function
transform = self.get_transform()
images, num_tiles = [], []
num_image = len(data_item['image'])
for image_path in data_item['image']:
# Merge the image path
image_path = self.get_image_path(image_path)
# Load the image using tcs_loader if available, otherwise use PIL
image = self.load_image(image_path)
if self.dynamic_image_size: # If dynamic image size is enabled, preprocess the image dynamically
image = dynamic_preprocess(image, min_num=self.min_dynamic_patch,
max_num=self.max_dynamic_patch // num_image,
image_size=self.image_size, use_thumbnail=self.use_thumbnail)
images += image
num_tiles.append(len(image))
else: # Otherwise, use the original image as a single patch
images.append(image)
num_tiles.append(1)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
num_patches = pixel_values.size(0)
# Select the appropriate preprocessing function based on the template name
preprocess_function = self.get_preprocess_function()
# Preprocess the conversations and generate the return dictionary
num_image_tokens = [self.num_image_token * num_tile for num_tile in num_tiles]
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
self.tokenizer, num_image_tokens, group_by_length=self.group_by_length,
ds_name=self.ds_name, num_image=num_image)
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
return ret
def video_get_item(self, data_item):
# Build transformation function
transform = self.get_transform()
# Ensure the first conversation contains a video placeholder
if '<video>' not in data_item['conversations'][0]['value']:
data_item['conversations'][0]['value'] = '<video>\n' + data_item['conversations'][0]['value']
# Get the video file path
video_file = data_item['video']
video_path = os.path.join(self.root, video_file)
# Load the video frames using tcs_loader
# TODO: Load videos without using tcsloader.
image_list = self.tcs_loader(
video_path,
image_type='video',
max_num_frames=self.max_num_frame,
min_num_frames=self.min_num_frame,
sample=self.sampling_method,
clip=data_item.get('clip', None))
# Generate special tokens for each video frame
special_tokens = '\n'.join(['Frame{}: <image>'.format(i + 1) for i in range(len(image_list))])
data_item['conversations'][0]['value'] = data_item['conversations'][0]['value'].replace(
'<video>\n', special_tokens)
# Transform each frame image and stack them into a tensor
pixel_values = [transform(image) for image in image_list]
pixel_values = torch.stack(pixel_values)
num_patches = pixel_values.size(0)
# Select the appropriate preprocessing function based on the template name
preprocess_function = self.get_preprocess_function()
# Preprocess the conversations and generate the return dictionary
num_image_tokens = [self.num_image_token] * num_patches
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
self.tokenizer, num_image_tokens, group_by_length=self.group_by_length,
ds_name=self.ds_name, num_image=num_patches)
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
pixel_values=pixel_values,
image_flags=torch.tensor([1] * num_patches, dtype=torch.long)
)
return ret
def pure_text_get_item(self, data_item):
# Build transformation function
transform = self.get_transform()
# Create a blank white image
image = Image.new('RGB', (224, 224), (255, 255, 255))
# Dynamically preprocess the image to generate patches
images = dynamic_preprocess(image, min_num=self.min_dynamic_patch, max_num=1,
image_size=self.image_size, use_thumbnail=self.use_thumbnail)
# Apply the transformation to each image patch and stack them into a tensor
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
num_patches = pixel_values.size(0)
# Ensure there is only one patch
assert num_patches == 1, f'The number of patches should be 1, but got {num_patches}.'
# Select the appropriate preprocessing function based on the template name
preprocess_function = self.get_preprocess_function()
# Preprocess the conversations and generate the return dictionary
ret = preprocess_function(self.template_name, [deepcopy(data_item['conversations'])],
self.tokenizer, [self.num_image_token * num_patches], text_only=True,
group_by_length=self.group_by_length, ds_name=self.ds_name)
# Create the final return dictionary
ret = dict(
input_ids=ret['input_ids'][0],
labels=ret['labels'][0],
attention_mask=ret['attention_mask'][0],
pixel_values=pixel_values,
image_flags=torch.tensor([0] * num_patches, dtype=torch.long)
)
return ret
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
i = i % len(self.raw_data)
while True:
try:
data_item = json.loads(self.raw_data[i])
if 'image' in data_item and len(data_item['image']) != 0:
if type(data_item['image']) == list:
ret = self.multi_modal_multi_image_get_item(data_item)
else:
ret = self.multi_modal_get_item(data_item)
elif 'video' in data_item and data_item['video'] is not None and data_item['video'] != '':
ret = self.video_get_item(data_item)
else:
ret = self.pure_text_get_item(data_item)
break
except Exception as e:
print(e, self.ds_name, flush=True)
if not isinstance(e, UnidentifiedImageError):
traceback.print_exc()
data_item = json.loads(self.raw_data[i])
if 'image' in data_item:
if type(data_item['image']) == list:
images = [self.root + item for item in data_item['image']]
print(f'Failed to load image: {images}, the dataset is: {self.ds_name}')
else:
if data_item['image'].startswith('s3://'):
data_path = self.root + data_item['image']
else:
data_path = os.path.join(self.root, data_item['image'])
print(f'Failed to load image: {data_path}, the dataset is: {self.ds_name}')
elif 'video' in data_item:
data_path = os.path.join(self.root, data_item['video'])
print(f'Failed to load video: {data_path}, the dataset is: {self.ds_name}')
i = random.randint(0, len(self.raw_data) - 1)
return ret
def build_datasets(
data_args,
tokenizer,
tcs_loader,
model,
group_by_length=False,
dynamic_image_size=False,
use_thumbnail=False,
min_dynamic_patch=1,
max_dynamic_patch=12,
normalize_type='imagenet',
):
datasets = []
lengths = []
    with open(data_args.meta_path) as f:
        ds_collections = json.load(f)
for ds_idx, ds_name in enumerate(ds_collections.keys()):
repeat_time = ds_collections[ds_name]['repeat_time']
if 'max_dynamic_patch' in ds_collections[ds_name]:
max_num = ds_collections[ds_name]['max_dynamic_patch']
logger.info(f'max_dynamic_patch is set to {max_num} according to the meta file')
else:
max_num = max_dynamic_patch
dataset = LazySupervisedDataset(
data_args.conv_style, ds_collections[ds_name],
tokenizer,
tcs_loader,
ds_name=ds_name,
num_image_token=model.num_image_token,
image_size=data_args.force_image_size,
is_train=ds_collections[ds_name]['data_augment'],
pad2square=data_args.pad2square,
group_by_length=group_by_length,
dynamic_image_size=dynamic_image_size,
use_thumbnail=use_thumbnail,
min_dynamic_patch=min_dynamic_patch,
max_dynamic_patch=max_num,
repeat_time=repeat_time,
normalize_type=normalize_type,
random_seed=ds_idx,
)
logger.info(f'Add dataset: {ds_name} with length: {len(dataset)}')
datasets.append(dataset)
if data_args.use_data_resampling:
lengths.append(math.sqrt(len(dataset)))
else:
lengths.append(len(dataset))
if data_args.use_data_resampling:
total_length = sum(lengths)
weights = [l / total_length for l in lengths]
train_dataset = WeightedConcatDataset(datasets, weights)
else:
train_dataset = ConcatDataset(datasets)
return train_dataset
def main():
# Parse input arguments
# See all possible arguments in src/transformers/training_args.py
    # When using DeepSpeed ZeRO-3, init_dist must be called before HfArgumentParser
launcher = os.environ.get('LAUNCHER', 'slurm')
init_dist(launcher=launcher, backend='nccl')
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
if len(sys.argv) == 2 and sys.argv[1].endswith('.json'):
# If we pass only one argument to the script, and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
else:
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
# information sent is the one passed as arguments along with your Python/PyTorch versions.
# send_example_telemetry('InternV-Chat', model_args, data_args)
# Setup logging
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
handlers=[logging.StreamHandler(sys.stdout)],
)
if training_args.should_log:
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
transformers.utils.logging.set_verbosity_info()
log_level = training_args.get_process_log_level()
logger.setLevel(log_level)
set_verbosity(log_level)
enable_default_handler()
enable_explicit_format()
# Log on each process the small summary:
logger.warning(
        f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, '
        + f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
)
logger.info(f'Training/evaluation parameters {training_args}')
# Detecting last checkpoint and eventually continue from last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f'Output directory ({training_args.output_dir}) already exists and is not empty. '
'Use --overwrite_output_dir to overcome.'
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change '
'the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
)
# Set seed before initializing model.
set_seed(training_args.seed)
# Load pretrained model, tokenizer, and image processor
tokenizer_path = model_args.model_name_or_path or model_args.llm_path
logger.info(f'Loading Tokenizer: {tokenizer_path}')
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_path, add_eos_token=False, trust_remote_code=True, use_fast=False)
tokenizer.tokenizer_path = tokenizer_path
tokenizer.model_max_length = data_args.max_seq_length
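    # Register the multimodal special tokens (image delimiters and context slots,
    # plus quad/ref/box grounding markers); any newly added tokens require resizing
    # the LM embeddings further below.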
token_list = [IMG_START_TOKEN, IMG_END_TOKEN, IMG_CONTEXT_TOKEN,
QUAD_START_TOKEN, QUAD_END_TOKEN, REF_START_TOKEN,
REF_END_TOKEN, BOX_START_TOKEN, BOX_END_TOKEN]
num_new_tokens = tokenizer.add_tokens(token_list, special_tokens=True)
img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
tcs_loader = TCSLoader('~/petreloss.conf') if has_tcs_loader else None
if model_args.model_name_or_path is not None:
logger.info('Loading InternVLChatModel...')
config = InternVLChatConfig.from_pretrained(model_args.model_name_or_path)
config.vision_config.drop_path_rate = model_args.drop_path_rate
if config.llm_config.model_type == 'internlm2':
config.llm_config.attn_implementation = 'flash_attention_2' # for InternLM
logger.info('Using flash_attention_2 for InternLM')
else:
config.llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
logger.info('Using flash_attention_2 for LLaMA')
config.template = data_args.conv_style
config.select_layer = model_args.vision_select_layer
config.dynamic_image_size = data_args.dynamic_image_size
config.use_thumbnail = data_args.use_thumbnail
config.ps_version = model_args.ps_version
config.min_dynamic_patch = data_args.min_dynamic_patch
config.max_dynamic_patch = data_args.max_dynamic_patch
model = InternVLChatModel.from_pretrained(
model_args.model_name_or_path, torch_dtype=torch.bfloat16, config=config)
else:
logger.info('Loading ViT-6B...')
vision_config = InternVisionConfig.from_pretrained(model_args.vision_path)
vision_config.drop_path_rate = model_args.drop_path_rate
vision_model = InternVisionModel.from_pretrained(
model_args.vision_path, torch_dtype=torch.bfloat16, config=vision_config)
logger.info('Loading LLaMA...')
llm_config = AutoConfig.from_pretrained(model_args.llm_path, trust_remote_code=True)
if llm_config.model_type == 'internlm2':
model_type = InternLM2ForCausalLM
llm_config.attn_implementation = 'flash_attention_2' # for InternLM
logger.info('Using flash_attention_2 for InternLM')
else:
model_type = AutoModelForCausalLM
llm_config._attn_implementation = 'flash_attention_2' # for LLaMA
logger.info('Using flash_attention_2 for LLaMA')
llm = model_type.from_pretrained(
model_args.llm_path, torch_dtype=torch.bfloat16,
config=llm_config, trust_remote_code=True)
logger.info('Building InternVLChatConfig...')
internvl_chat_config = InternVLChatConfig(
vision_config.to_dict(), llm_config.to_dict(), downsample_ratio=data_args.down_sample_ratio,
pad2square=data_args.pad2square, template=data_args.conv_style,
select_layer=model_args.vision_select_layer, dynamic_image_size=data_args.dynamic_image_size,
use_thumbnail=data_args.use_thumbnail, ps_version=model_args.ps_version,
min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch)
internvl_chat_config.force_image_size = data_args.force_image_size
logger.info('Building InternVLChatModel...')
model = InternVLChatModel(internvl_chat_config, vision_model, llm)
model.img_context_token_id = img_context_token_id
assert model.config.downsample_ratio == data_args.down_sample_ratio
if model_args.mlp_path is not None:
logger.info('Loading pretrained MLP projector...')
state_dict = torch.load(model_args.mlp_path, map_location='cpu')
message = model.mlp1.load_state_dict(state_dict)
logger.info(message)
logger.info('Finished')
patch_size = model.config.vision_config.patch_size
logger.info(f'model.config.force_image_size: {model.config.force_image_size}')
logger.info(f'data_args.force_image_size: {data_args.force_image_size}')
logger.info(f'model.config.vision_config.image_size: {model.config.vision_config.image_size}')
if model.config.vision_config.image_size != data_args.force_image_size:
logger.info(f'Resizing position embedding from '
f'{model.config.vision_config.image_size} '
f'to {data_args.force_image_size}...')
model.vision_model.resize_pos_embeddings(old_size=model.config.vision_config.image_size,
new_size=data_args.force_image_size,
patch_size=patch_size)
model.config.vision_config.image_size = data_args.force_image_size
model.config.force_image_size = data_args.force_image_size
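    # Visual tokens per tile: (image_size / patch_size)^2 ViT patches, of which only a
    # down_sample_ratio^2 fraction is kept after pixel-shuffle down-sampling.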
model.num_image_token = int((data_args.force_image_size // patch_size) ** 2 * (data_args.down_sample_ratio ** 2))
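    # For newly added tokens, grow the embedding matrix and initialize their output
    # embeddings with the mean of the pre-existing rows so the output distribution
    # is not disturbed at the start of training.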
if num_new_tokens > 0:
model.language_model.resize_token_embeddings(len(tokenizer))
output_embeddings = model.language_model.get_output_embeddings().weight.data
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings[-num_new_tokens:] = output_embeddings_avg
model.config.llm_config.vocab_size = len(tokenizer)
model.language_model.config.vocab_size = len(tokenizer)
model.language_model.config.use_cache = False
model.vision_model.gradient_checkpointing = True
model.vision_model.encoder.gradient_checkpointing = True
if model_args.grad_checkpoint:
model.language_model._set_gradient_checkpointing()
train_dataset = build_datasets(
data_args, tokenizer, tcs_loader, model, group_by_length=training_args.group_by_length,
dynamic_image_size=data_args.dynamic_image_size, use_thumbnail=data_args.use_thumbnail,
min_dynamic_patch=data_args.min_dynamic_patch, max_dynamic_patch=data_args.max_dynamic_patch,
normalize_type=data_args.normalize_type)
def _freeze_params(module):
for param in module.parameters():
param.requires_grad = False
if model_args.freeze_backbone:
# model.vision_model = model.vision_model.eval()
_freeze_params(model.vision_model)
if model_args.freeze_llm:
model.language_model = model.language_model.eval()
_freeze_params(model.language_model)
if model_args.unfreeze_lm_head:
model.language_model.lm_head.requires_grad = True
if model_args.use_backbone_lora:
model.wrap_backbone_lora(r=model_args.use_backbone_lora, lora_alpha=2 * model_args.use_backbone_lora)
model.config.use_backbone_lora = model_args.use_backbone_lora
if model_args.use_llm_lora:
model.wrap_llm_lora(r=model_args.use_llm_lora, lora_alpha=2 * model_args.use_llm_lora)
model.config.use_llm_lora = model_args.use_llm_lora
if model_args.freeze_mlp:
_freeze_params(model.mlp1)
if model_args.unfreeze_vit_layers != 0:
layers = model.vision_model.encoder.layers[model_args.unfreeze_vit_layers:]
for k, v in layers.named_parameters():
logger.info(f'Unfreezing ViT layer: {k}')
v.requires_grad = True
# print trainable parameters
if dist.get_rank() == 0:
for name, param in model.named_parameters():
if param.requires_grad:
logger.info(name)
# set seed for torch dataloaders
set_seed(training_args.seed)
# Initialize our Trainer
if model_args.use_custom_trainer:
replace_create_optimizer()
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=None,
tokenizer=tokenizer,
data_collator=concat_pad_data_collator
)
# Training
if training_args.do_train:
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.save_model() # Saves the tokenizer too for easy upload
metrics = train_result.metrics
try:
metrics['train_samples'] = len(train_dataset)
        except Exception:
metrics['train_samples'] = -1
trainer.log_metrics('train', metrics)
trainer.save_metrics('train', metrics)
trainer.save_state()
if __name__ == '__main__':
main()
import json
import os
import torch
import torch.nn as nn
import transformers
from transformers import Trainer, logging
from transformers.trainer import is_sagemaker_mp_enabled
logger = logging.get_logger(__name__)
def get_num_layer_for_vit_and_qllama(var_name, vit_num_max_layer, llama_num_max_layer):
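    # Map a parameter name to a pseudo layer index used for layer-wise LR decay:
    # embeddings -> 0, transformer blocks -> block index + 1, and projection/head
    # parameters -> the maximum layer index of their branch.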
if var_name.startswith('internvl.'):
var_name = var_name[len('internvl.'):]
if var_name in ('query_tokens', 'logit_scale',):
return 0
if var_name.startswith('clip_projector.'):
return vit_num_max_layer
if var_name.startswith('clip_projector2.') or var_name.startswith('itm_head.') or \
var_name == 'text_projection':
return llama_num_max_layer
if var_name.startswith('vision_model.'):
if 'embeddings.' in var_name:
return 0
if 'layers.' in var_name:
var_name = var_name.split('layers.')[-1]
layer_id = int(var_name.split('.')[0])
return layer_id + 1
if var_name.startswith('qllama.'):
if 'embed_tokens' in var_name:
return 0
if 'layers.' in var_name:
var_name = var_name.split('layers.')[-1]
layer_id = int(var_name.split('.')[0])
return layer_id + 1
else:
return llama_num_max_layer
return 0
def param_classification(name):
if name.startswith('internvl.'):
name = name[len('internvl.'):]
if name in ['query_tokens', 'text_projection', 'logit_scale']:
return 'qllama'
elif name.startswith('vision_model.'):
return 'vit'
elif name.startswith('qllama.'):
return 'qllama'
elif name.startswith('clip_projector.'):
return 'vit'
elif name.startswith('clip_projector2.'):
return 'qllama'
elif name.startswith('itm_head.'):
return 'qllama'
else:
return 'other'
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
parameter_groups = {}
try: # for stage2 model
vit_num_layers = opt_model.config.vision_config.num_hidden_layers + 2
qllama_num_layers = opt_model.config.qllama_config.num_hidden_layers + 2
    except Exception: # for stage3 model
vit_num_layers = opt_model.internvl.config.vision_config.num_hidden_layers + 2
qllama_num_layers = opt_model.internvl.config.qllama_config.num_hidden_layers + 2
print('vit_num_layers:', vit_num_layers)
print('qllama_num_layers:', qllama_num_layers)
vit_layer_decay_rate = float(os.getenv('VIT_LAYER_DECAY_RATE', 1.0))
qllama_layer_decay_rate = float(os.getenv('QLLAMA_LAYER_DECAY_RATE', 1.0))
qllama_lr_scale = float(os.getenv('QLLAMA_LR_SCALE', 1.0))
print('vit_layer_decay_rate:', vit_layer_decay_rate)
print('qllama_layer_decay_rate:', qllama_layer_decay_rate)
print('qllama_lr_scale:', qllama_lr_scale)
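    # Build one parameter group per (component, layer, decay/no-decay) triple. Each
    # group's LR is scaled by decay_rate ** (num_layers - layer_id - 1), optionally
    # multiplied by QLLAMA_LR_SCALE for the language branch, and capped at 1.0.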
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue # frozen weights
if len(param.shape) == 1 or name.endswith('.bias'):
group_name = 'no_decay'
this_weight_decay = 0.
else:
group_name = 'decay'
this_weight_decay = self.args.weight_decay
cls = param_classification(name)
layer_id = get_num_layer_for_vit_and_qllama(name, vit_num_layers, qllama_num_layers)
group_name = '%s_layer_%d_%s' % (cls, layer_id, group_name)
if group_name not in parameter_groups:
if cls == 'vit':
scale = vit_layer_decay_rate ** (vit_num_layers - layer_id - 1)
elif cls == 'qllama':
scale = qllama_layer_decay_rate ** (qllama_num_layers - layer_id - 1)
scale = scale * qllama_lr_scale
else:
scale = 1.0
scale = min(1.0, scale)
parameter_groups[group_name] = {
'weight_decay': this_weight_decay,
'params': [],
'param_names': [],
'lr_scale': scale,
'group_name': group_name,
'lr': scale * self.args.learning_rate,
}
parameter_groups[group_name]['params'].append(param)
parameter_groups[group_name]['param_names'].append(name)
rank = torch.distributed.get_rank()
if rank == 0:
to_display = {}
for key in parameter_groups:
to_display[key] = {
'param_names': parameter_groups[key]['param_names'],
'lr_scale': parameter_groups[key]['lr_scale'],
'lr': parameter_groups[key]['lr'],
'weight_decay': parameter_groups[key]['weight_decay'],
}
print('Param groups = %s' % json.dumps(to_display, indent=2))
optimizer_grouped_parameters = list(parameter_groups.values())
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
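    # When bitsandbytes' 8-bit Adam is selected, keep embedding weights in 32-bit
    # optimizer state (register_module_override) for numerical stability.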
if optimizer_cls.__name__ == 'Adam8bit':
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
logger.info(f'skipped {module}: {skipped / 2 ** 20}M params')
manager.register_module_override(module, 'weight', {'optim_bits': 32})
logger.debug(f'bitsandbytes: will optimize {module} in fp32')
logger.info(f'skipped: {skipped / 2 ** 20}M params')
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
self.optimizer = smp.DistributedOptimizer(self.optimizer)
return self.optimizer
def replace_create_optimizer():
print('Replace original create_optimizer with custom create_optimizer')
transformers.Trainer.create_optimizer = create_optimizer
import numpy as np
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import os
import math
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
def build_transform(input_size):
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
return transform
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
best_ratio_diff = float('inf')
best_ratio = (1, 1)
area = width * height
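    # Pick the candidate grid whose aspect ratio is closest to the image's; on ties,
    # prefer the grid with more tiles only if the image is large enough to fill at
    # least half of that grid's pixel budget.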
for ratio in target_ratios:
target_aspect_ratio = ratio[0] / ratio[1]
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
if ratio_diff < best_ratio_diff:
best_ratio_diff = ratio_diff
best_ratio = ratio
elif ratio_diff == best_ratio_diff:
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
best_ratio = ratio
return best_ratio
def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
    # enumerate the candidate tiling grids (i x j) with min_num <= i * j <= max_num
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
# find the closest aspect ratio to the target
target_aspect_ratio = find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
# calculate the target width and height
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
# resize the image
resized_img = image.resize((target_width, target_height))
processed_images = []
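    # Crop the resized image into image_size x image_size tiles in row-major order:
    # tile i sits at column i % cols and row i // cols of the grid.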
for i in range(blocks):
box = (
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
)
# split the image
split_img = resized_img.crop(box)
processed_images.append(split_img)
assert len(processed_images) == blocks
if use_thumbnail and len(processed_images) != 1:
thumbnail_img = image.resize((image_size, image_size))
processed_images.append(thumbnail_img)
return processed_images
def load_image(image_file, input_size=448, max_num=6):
image = Image.open(image_file).convert('RGB')
transform = build_transform(input_size=input_size)
images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
pixel_values = [transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return pixel_values
def split_model(model_name):
device_map = {}
world_size = torch.cuda.device_count()
num_layers = {
'InternVL2-1B': 24, 'InternVL2-2B': 24, 'InternVL2-4B': 32, 'InternVL2-8B': 32,
'InternVL2-26B': 48, 'InternVL2-40B': 60, 'InternVL2-Llama3-76B': 80}[model_name]
# Since the first GPU will be used for ViT, treat it as half a GPU.
num_layers_per_gpu = math.ceil(num_layers / (world_size - 0.5))
num_layers_per_gpu = [num_layers_per_gpu] * world_size
num_layers_per_gpu[0] = math.ceil(num_layers_per_gpu[0] * 0.5)
layer_cnt = 0
for i, num_layer in enumerate(num_layers_per_gpu):
for j in range(num_layer):
device_map[f'language_model.model.layers.{layer_cnt}'] = i
layer_cnt += 1
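    # Keep the vision tower, MLP projector, token embeddings, final norm, output
    # head and the last decoder layer on GPU 0 so the multimodal prefix and the
    # logits are produced on a single device.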
device_map['vision_model'] = 0
device_map['mlp1'] = 0
device_map['language_model.model.tok_embeddings'] = 0
device_map['language_model.model.embed_tokens'] = 0
device_map['language_model.output'] = 0
device_map['language_model.model.norm'] = 0
device_map['language_model.lm_head'] = 0
device_map[f'language_model.model.layers.{num_layers - 1}'] = 0
return device_map
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
path = '/home/wanglch/projects/InternVL/InternVL2-40B'
device_map = split_model('InternVL2-40B')
model = AutoModel.from_pretrained(
path,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
trust_remote_code=True,
device_map=device_map).eval()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)
# set the max number of tiles in `max_num`
pixel_values = load_image('/home/wanglch/projects/InternVL/internvl_chat/examples/image1.jpg', max_num=12).to(torch.float16).cuda()
generation_config = dict(max_new_tokens=1024, do_sample=True)
# pure-text conversation (纯文本对话)
question = 'Hello, who are you?'
response, history = model.chat(tokenizer, None, question, generation_config, history=None, return_history=True)
print(f'User: {question}')
print(f'Assistant: {response}')
question = 'Can you tell me a story?'
response, history = model.chat(tokenizer, None, question, generation_config, history=history, return_history=True)
print(f'User: {question}')
print(f'Assistant: {response}')
# single-image single-round conversation (单图单轮对话)
question = '<image>\nPlease describe the image shortly.'
response = model.chat(tokenizer, pixel_values, question, generation_config)
print(f'User: {question}')
print(f'Assistant: {response}')
# single-image multi-round conversation (单图多轮对话)
question = '<image>\nPlease describe the image in detail.'
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
print(f'User: {question}')
print(f'Assistant: {response}')
question = 'Please write a poem according to the image.'
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=history, return_history=True)
print(f'User: {question}')
print(f'Assistant: {response}')
# InternVL-Chat
This folder contains the implementation of InternVL-Chat, including the training, inference, and evaluation scripts.
## 🛠️ Installation
See [INSTALLATION.md](../INSTALLATION.md)
In addition, using this codebase requires executing the following steps:
- Install other requirements:
```bash
pip install --upgrade pip # enable PEP 660 support
pip install -e .
```
## 📖 Documents
- InternVL 2.0
- Introduction [\[link\]](https://internvl.readthedocs.io/en/latest/internvl2.0/introduction.html)
- Quick Start [\[link\]](https://internvl.readthedocs.io/en/latest/internvl2.0/quick_start.html)
- Finetune [\[link\]](https://internvl.readthedocs.io/en/latest/internvl2.0/finetune.html)
- Evaluation [\[link\]](https://internvl.readthedocs.io/en/latest/internvl2.0/evaluation.html)
- Deployment [\[link\]](https://internvl.readthedocs.io/en/latest/internvl2.0/deployment.html)
- InternVL 1.5
- Introduction [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.5/introduction.html)
- Quick Start [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.5/quick_start.html)
- Finetune [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.5/finetune.html)
- Evaluation [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.5/evaluation.html)
- Deployment [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.5/deployment.html)
- InternVL 1.2
- Introduction [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.2/introduction.html)
- Quick Start [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.2/quick_start.html)
- Reproduce [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.2/reproduce.html)
- Finetune [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.2/finetune.html)
- Evaluation [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.2/evaluation.html)
- InternVL 1.1
- Introduction [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.1/introduction.html)
- Quick Start [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.1/quick_start.html)
- Evaluation [\[link\]](https://internvl.readthedocs.io/en/latest/internvl1.1/evaluation.html)
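
For a quick sanity check after installation, the minimal sketch below mirrors the example chat script included in this commit; the checkpoint id is a placeholder assumption, and image inputs additionally need the tiling helpers (`build_transform` / `dynamic_preprocess`) from that script, so treat this as a sketch rather than the documented quick start.

```python
import torch
from transformers import AutoModel, AutoTokenizer

path = 'OpenGVLab/InternVL2-8B'  # placeholder: any InternVL2 chat checkpoint
model = AutoModel.from_pretrained(
    path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True,
    trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)

# Pure-text round trip; pass `pixel_values` instead of None for image inputs.
generation_config = dict(max_new_tokens=1024, do_sample=False)
question = 'Hello, who are you?'
response = model.chat(tokenizer, None, question, generation_config)
print(response)
```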
import argparse
import itertools
import json
import os
import random
import time
from functools import partial
import torch
from internvl.model.internvl_chat import InternVLChatModel
from internvl.train.dataset import build_transform, dynamic_preprocess
from PIL import Image
from pycocoevalcap.eval import COCOEvalCap
from pycocotools.coco import COCO
from tqdm import tqdm
from transformers import AutoTokenizer
ds_collections = {
'flickr30k': {
'root': 'data/flickr30k/',
'annotation': 'data/flickr30k/flickr30k_test_karpathy.json',
'max_new_tokens': 30,
'min_new_tokens': 8,
},
'coco': {
'root': 'data/coco/',
'annotation': ['data/coco/annotations/coco_karpathy_test.json',
'data/coco/annotations/coco_karpathy_test_gt.json'],
'max_new_tokens': 30,
'min_new_tokens': 8,
},
'nocaps': {
'root': 'data/nocaps/images',
'annotation': 'data/nocaps/nocaps_val_4500_captions.json',
'max_new_tokens': 30,
'min_new_tokens': 8,
},
}
class CaptionDataset(torch.utils.data.Dataset):
def __init__(self, name, root, annotation, prompt, input_size=224, dynamic_image_size=False,
use_thumbnail=False, max_num=6):
if name == 'coco':
self.images = json.load(open(annotation))
else:
self.images = json.load(open(annotation))['images']
self.name = name
self.prompt = prompt
self.root = root
self.input_size = input_size
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.max_num = max_num
self.transform = build_transform(is_train=False, input_size=input_size)
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
if self.name == 'coco':
filename = self.images[idx]['image']
image_id = int(filename.split('_')[-1].replace('.jpg', ''))
image_path = os.path.join(self.root, filename)
else:
image_id = self.images[idx]['id']
if 'file_name' in self.images[idx]:
image_path = os.path.join(self.root, self.images[idx]['file_name'])
else:
image_path = os.path.join(self.root, self.images[idx]['image'])
image = Image.open(image_path)
if self.dynamic_image_size:
images = dynamic_preprocess(image, image_size=self.input_size,
use_thumbnail=self.use_thumbnail,
max_num=self.max_num)
else:
images = [image]
pixel_values = [self.transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return {
'image_id': image_id,
'input_text': self.prompt,
'pixel_values': pixel_values
}
def collate_fn(inputs, tokenizer):
pixel_values = torch.cat([_['pixel_values'] for _ in inputs], dim=0)
image_ids = [_['image_id'] for _ in inputs]
input_texts = [_['input_text'] for _ in inputs]
input_tokens = tokenizer(input_texts, return_tensors='pt')
return pixel_values, image_ids, input_tokens.input_ids, input_tokens.attention_mask
class InferenceSampler(torch.utils.data.sampler.Sampler):
def __init__(self, size):
self._size = int(size)
assert size > 0
self._rank = torch.distributed.get_rank()
self._world_size = torch.distributed.get_world_size()
self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
@staticmethod
def _get_local_indices(total_size, world_size, rank):
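        # Split the dataset into contiguous per-rank shards; the first
        # (total_size % world_size) ranks each receive one extra sample.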
shard_size = total_size // world_size
left = total_size % world_size
shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
begin = sum(shard_sizes[:rank])
end = min(sum(shard_sizes[:rank + 1]), total_size)
return range(begin, end)
def __iter__(self):
yield from self._local_indices
def __len__(self):
return len(self._local_indices)
def evaluate_chat_model():
prompt = 'Provide a one-sentence caption for the provided image.'
print('prompt:', prompt)
random.seed(args.seed)
summaries = []
for ds_name in args.datasets:
annotation = ds_collections[ds_name]['annotation']
        if isinstance(annotation, list):
annotation = annotation[0]
dataset = CaptionDataset(
name=ds_name,
root=ds_collections[ds_name]['root'],
annotation=annotation,
prompt=prompt,
input_size=image_size,
dynamic_image_size=args.dynamic,
use_thumbnail=use_thumbnail,
max_num=args.max_num
)
dataloader = torch.utils.data.DataLoader(
dataset=dataset,
sampler=InferenceSampler(len(dataset)),
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=True,
drop_last=False,
collate_fn=partial(collate_fn, tokenizer=tokenizer),
)
image_ids, captions = [], []
for _, (pixel_values, ids, _, _) in tqdm(enumerate(dataloader)):
pixel_values = pixel_values.to(torch.bfloat16).cuda()
generation_config = dict(
num_beams=args.num_beams,
max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
)
pred = model.chat(
tokenizer=tokenizer,
pixel_values=pixel_values,
question=prompt,
generation_config=generation_config,
verbose=True
)
image_ids.extend(ids)
captions.extend([pred])
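        # Every rank has processed its shard: synchronize, gather the per-rank id and
        # caption lists on all ranks, then flatten them; rank 0 writes the results
        # file and runs the COCO caption metrics.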
torch.distributed.barrier()
world_size = torch.distributed.get_world_size()
merged_ids = [None for _ in range(world_size)]
merged_captions = [None for _ in range(world_size)]
torch.distributed.all_gather_object(merged_ids, image_ids)
torch.distributed.all_gather_object(merged_captions, captions)
merged_ids = [_ for _ in itertools.chain.from_iterable(merged_ids)]
merged_captions = [_ for _ in itertools.chain.from_iterable(merged_captions)]
average_length = sum(len(x.split()) for x in merged_captions) / len(merged_captions)
print(f'Average caption length: {average_length}')
if torch.distributed.get_rank() == 0:
print(f'Evaluating {ds_name} ...')
results = []
for image_id, caption in zip(merged_ids, merged_captions):
results.append({
'image_id': int(image_id),
'caption': caption,
})
time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
results_file = f'{ds_name}_{time_prefix}.json'
results_file = os.path.join(args.out_dir, results_file)
json.dump(results, open(results_file, 'w'))
annotation = ds_collections[ds_name]['annotation']
            if isinstance(annotation, list):
annotation = annotation[-1]
coco = COCO(annotation)
coco_result = coco.loadRes(results_file)
coco_eval = COCOEvalCap(coco, coco_result)
coco_eval.evaluate()
summary = coco_eval.eval.items()
print(summary)
summaries.append([args.checkpoint, ds_name, average_length, summary])
torch.distributed.barrier()
out_path = '_'.join(args.checkpoint.split('/')[-2:])
writer = open(os.path.join(args.out_dir, f'{out_path}.txt'), 'a')
print(f"write results to file {os.path.join(args.out_dir, f'{out_path}.txt')}")
for summary in summaries:
print(summary)
writer.write(f'{summary}\n')
writer.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='')
parser.add_argument('--datasets', type=str, default='coco,flickr30k,nocaps')
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--num-beams', type=int, default=5)
parser.add_argument('--temperature', type=float, default=0.0)
parser.add_argument('--out-dir', type=str, default='results')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--dynamic', action='store_true')
parser.add_argument('--max-num', type=int, default=6)
parser.add_argument('--load-in-8bit', action='store_true')
parser.add_argument('--load-in-4bit', action='store_true')
parser.add_argument('--auto', action='store_true')
args = parser.parse_args()
if not os.path.exists(args.out_dir):
os.makedirs(args.out_dir)
args.datasets = args.datasets.split(',')
print('datasets:', args.datasets)
assert args.batch_size == 1, 'Only batch size 1 is supported'
torch.distributed.init_process_group(
backend='nccl',
world_size=int(os.getenv('WORLD_SIZE', '1')),
rank=int(os.getenv('RANK', '0')),
)
torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
if args.auto:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
kwargs = {'device_map': 'auto'} if args.auto else {}
tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
model = InternVLChatModel.from_pretrained(
args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit, **kwargs).eval()
if not args.load_in_8bit and not args.load_in_4bit and not args.auto:
model = model.cuda()
image_size = model.config.force_image_size or model.config.vision_config.image_size
use_thumbnail = model.config.use_thumbnail
total_params = sum(p.numel() for p in model.parameters()) / 1e9
if total_params > 20 or args.dynamic:
args.num_beams = 1
print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
else:
print(f'[test] total_params: {total_params}B')
print(f'[test] image_size: {image_size}')
print(f'[test] template: {model.config.template}')
print(f'[test] dynamic_image_size: {args.dynamic}')
print(f'[test] use_thumbnail: {use_thumbnail}')
print(f'[test] max_num: {args.max_num}')
evaluate_chat_model()
import argparse
import json
import os
import random
import torch
from internvl.model.internvl_chat import InternVLChatModel
from internvl.train.dataset import build_transform, dynamic_preprocess
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer
ds_collections = {
'art_and_design': {
'root': 'data/',
'annotation': 'data/cmmmu-data/llava_art_and_design.jsonl',
'max_new_tokens': 999,
'min_new_tokens': 1,
},
'business': {
'root': 'data/',
'annotation': 'data/cmmmu-data/llava_business.jsonl',
'max_new_tokens': 999,
'min_new_tokens': 1,
},
'health_and_medicine': {
'root': 'data/',
'annotation': 'data/cmmmu-data/llava_health_and_medicine.jsonl',
'max_new_tokens': 999,
'min_new_tokens': 1,
},
'humanities_and_social_sciences': {
'root': 'data/',
'annotation': 'data/cmmmu-data/llava_humanities_and_social_sciences.jsonl',
'max_new_tokens': 999,
'min_new_tokens': 1,
},
'science': {
'root': 'data/',
'annotation': 'data/cmmmu-data/llava_science.jsonl',
'max_new_tokens': 999,
'min_new_tokens': 1,
},
'technology_and_engineering': {
'root': 'data/',
'annotation': 'data/cmmmu-data/llava_technology_and_engineering.jsonl',
'max_new_tokens': 999,
'min_new_tokens': 1,
},
}
class VQADataset(torch.utils.data.Dataset):
def __init__(self, root, annotation, input_size=224, dynamic_image_size=False,
use_thumbnail=False, max_num=6):
self.root = root
self.items = []
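        # The annotation file is JSONL: one sample per line with at least the
        # 'image', 'text' and 'question_id' fields used below.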
f = open(annotation)
data = f.readlines()
for data_line in data:
data_line = json.loads(data_line)
self.items.append(data_line)
self.input_size = input_size
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.max_num = max_num
self.transform = build_transform(is_train=False, input_size=input_size)
def __len__(self):
return len(self.items)
def __getitem__(self, idx):
item = self.items[idx]
image_path, question = item['image'], item['text']
image_path = os.path.join(self.root, image_path)
image = Image.open(image_path).convert('RGB')
if self.dynamic_image_size:
images = dynamic_preprocess(image, image_size=self.input_size,
use_thumbnail=self.use_thumbnail,
max_num=self.max_num)
else:
images = [image]
pixel_values = [self.transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
return {
'question': question,
'pixel_values': pixel_values,
'item': item,
}
def evaluate_chat_model():
random.seed(args.seed)
for ds_name in args.datasets:
dataset = VQADataset(
root=ds_collections[ds_name]['root'],
annotation=ds_collections[ds_name]['annotation'],
input_size=image_size,
dynamic_image_size=args.dynamic,
use_thumbnail=use_thumbnail,
max_num=args.max_num
)
print(f'Evaluating {ds_name} ...')
results_file = f'{model_id}_{ds_name}.jsonl'
results_file = os.path.join(args.out_dir, results_file)
writer = open(results_file, 'w')
for _, data in tqdm(enumerate(dataset)):
pixel_value = data['pixel_values']
question = data['question']
item = data['item']
pixel_value = pixel_value.to(torch.bfloat16).cuda()
generation_config = dict(
num_beams=args.num_beams,
max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
)
pred = model.chat(
tokenizer=tokenizer,
pixel_values=pixel_value,
question=question,
generation_config=generation_config,
verbose=True
)
question_id = item['question_id']
text = item['text']
output = {
'question_id': question_id,
'prompt': text,
'text': pred,
'model_id': model_id,
'metadata': {}
}
writer.write(json.dumps(output, ensure_ascii=False) + '\n')
writer.flush()
print('Results saved to {}'.format(results_file))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='')
parser.add_argument('--datasets', type=str, default='')
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--num-beams', type=int, default=1)
parser.add_argument('--temperature', type=float, default=0.0)
parser.add_argument('--out-dir', type=str, default='results')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--dynamic', action='store_true')
parser.add_argument('--max-num', type=int, default=6)
parser.add_argument('--load-in-8bit', action='store_true')
parser.add_argument('--load-in-4bit', action='store_true')
parser.add_argument('--auto', action='store_true')
args = parser.parse_args()
if not os.path.exists(args.out_dir):
os.makedirs(args.out_dir)
args.datasets = args.datasets.split(',')
print('datasets:', args.datasets)
assert args.batch_size == 1, 'Only batch size 1 is supported'
if args.auto:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
kwargs = {'device_map': 'auto'} if args.auto else {}
tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
model = InternVLChatModel.from_pretrained(
args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit, **kwargs).eval()
if not args.load_in_8bit and not args.load_in_4bit and not args.auto:
model = model.cuda()
image_size = model.config.force_image_size or model.config.vision_config.image_size
use_thumbnail = model.config.use_thumbnail
total_params = sum(p.numel() for p in model.parameters()) / 1e9
if total_params > 20 or args.dynamic:
args.num_beams = 1
print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
else:
print(f'[test] total_params: {total_params}B')
print(f'[test] image_size: {image_size}')
print(f'[test] template: {model.config.template}')
print(f'[test] dynamic_image_size: {args.dynamic}')
print(f'[test] use_thumbnail: {use_thumbnail}')
print(f'[test] max_num: {args.max_num}')
model_id = '_'.join(args.checkpoint.split('/')[-2:])
evaluate_chat_model()
import argparse
import json
import os
import time
import openai
NUM_SECONDS_TO_SLEEP = 0.5
def get_eval(content: str, max_tokens: int):
while True:
try:
completion = openai.chat.completions.create(
model='gpt-4-0613',
messages=[{
'role': 'system',
'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
}, {
'role': 'user',
'content': content,
}],
temperature=0.2, # TODO: figure out which temperature is best for evaluation
max_tokens=max_tokens,
)
break
except Exception as e:
print(e)
time.sleep(NUM_SECONDS_TO_SLEEP)
return completion.choices[0].message.content
def parse_score(review):
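    # The review is expected to start with the two numeric scores on its first line
    # (e.g. "8 7"); anything else is treated as a parse failure and scored [-1, -1].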
try:
score_pair = review.split('\n')[0]
score_pair = score_pair.replace(',', ' ')
sp = score_pair.split(' ')
if len(sp) == 2:
return [float(sp[0]), float(sp[1])]
else:
print('error', review)
return [-1, -1]
except Exception as e:
print(e)
print('error', review)
return [-1, -1]
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
parser.add_argument('-q', '--question')
parser.add_argument('-c', '--context')
parser.add_argument('-a', '--answer-list', nargs='+', default=[])
parser.add_argument('-r', '--rule')
parser.add_argument('-o', '--output')
parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
args = parser.parse_args()
f_q = open(os.path.expanduser(args.question))
f_ans1 = open(os.path.expanduser(args.answer_list[0]))
f_ans2 = open(os.path.expanduser(args.answer_list[1]))
rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
if os.path.isfile(os.path.expanduser(args.output)):
cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
else:
cur_reviews = []
review_file = open(f'{args.output}', 'a')
context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
image_to_context = {context['image']: context for context in context_list}
handles = []
idx = 0
for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
ques = json.loads(ques_js)
ans1 = json.loads(ans1_js)
ans2 = json.loads(ans2_js)
print(ques, ans1, ans2)
inst = image_to_context[ques['image']]
if isinstance(inst['caption'], list):
cap_str = '\n'.join(inst['caption'])
else:
cap_str = inst['caption']
category = 'llava_bench_' + json.loads(ques_js)['category']
if category in rule_dict:
rule = rule_dict[category]
else:
assert False, f'Visual QA category not found in rule file: {category}.'
prompt = rule['prompt']
role = rule['role']
content = (f'[Context]\n{cap_str}\n\n'
f'[Question]\n{ques["text"]}\n\n'
f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
f'[System]\n{prompt}\n\n')
cur_js = {
'id': idx + 1,
'question_id': ques['question_id'],
'answer1_id': ans1.get('answer_id', ans1['question_id']),
            'answer2_id': ans2.get('answer_id', ans2['question_id']),
'category': category
}
if idx >= len(cur_reviews):
review = get_eval(content, args.max_tokens)
scores = parse_score(review)
cur_js['content'] = review
cur_js['tuple'] = scores
review_file.write(json.dumps(cur_js) + '\n')
review_file.flush()
else:
print(f'Skipping {idx} as we already have it.')
idx += 1
print(idx)
review_file.close()
import argparse
import json
import os
import random
import torch
from internvl.model.internvl_chat import InternVLChatModel
from internvl.train.dataset import build_transform, dynamic_preprocess
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer
ds_collections = {
'llava_bench': {
'root': 'data/llava-bench-in-the-wild/images',
'question': 'data/llava-bench-in-the-wild/questions.jsonl',
'max_new_tokens': 1000,
'min_new_tokens': 1,
},
}
class VQADataset(torch.utils.data.Dataset):
def __init__(self, root, data, prompt, input_size=224, dynamic_image_size=False,
use_thumbnail=False, max_num=6):
self.root = root
self.data = open(data).readlines()
self.prompt = prompt
self.input_size = input_size
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.max_num = max_num
self.transform = build_transform(is_train=False, input_size=input_size)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
data = json.loads(self.data[idx].strip())
        image, question, question_id, annotation = (
            data['image'], data['text'], data['question_id'], data.get('answer', None))
image = os.path.join(self.root, image)
image = Image.open(image).convert('RGB')
if self.dynamic_image_size:
images = dynamic_preprocess(image, image_size=self.input_size,
use_thumbnail=self.use_thumbnail,
max_num=self.max_num)
else:
images = [image]
pixel_values = [self.transform(image) for image in images]
pixel_values = torch.stack(pixel_values)
question = question + self.prompt
return question_id, question, pixel_values, annotation
def evaluate_chat_model():
random.seed(args.seed)
for ds_name in args.datasets:
dataset = VQADataset(
root=ds_collections[ds_name]['root'],
data=ds_collections[ds_name]['question'],
prompt=' Please give a detailed answer.',
input_size=image_size,
dynamic_image_size=args.dynamic,
use_thumbnail=use_thumbnail,
max_num=args.max_num
)
outputs = []
for _, (question_id, question, pixel_values, annotations) in tqdm(enumerate(dataset)):
pixel_values = pixel_values.to(torch.bfloat16).cuda()
generation_config = dict(
num_beams=args.num_beams,
max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
do_sample=True if args.temperature > 0 else False,
temperature=args.temperature,
)
pred = model.chat(
tokenizer=tokenizer,
pixel_values=pixel_values,
question=question,
generation_config=generation_config,
verbose=True
)
outputs.append({
'question_id': question_id,
'text': pred,
'model_id': model_id,
'metadata': {}
})
print(f'Evaluating {ds_name} ...')
results_file = 'llava_bench_results.jsonl'
results_file = os.path.join(args.out_dir, results_file)
writer = open(results_file, 'w')
for item in outputs:
writer.write(json.dumps(item) + '\n')
writer.close()
print('Results saved to {}'.format(results_file))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint', type=str, default='')
parser.add_argument('--datasets', type=str, default='llava_bench')
parser.add_argument('--batch-size', type=int, default=1)
parser.add_argument('--num-workers', type=int, default=1)
parser.add_argument('--num-beams', type=int, default=5)
parser.add_argument('--temperature', type=float, default=0.0)
parser.add_argument('--out-dir', type=str, default='results')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--dynamic', action='store_true')
parser.add_argument('--max-num', type=int, default=6)
parser.add_argument('--load-in-8bit', action='store_true')
parser.add_argument('--load-in-4bit', action='store_true')
parser.add_argument('--auto', action='store_true')
args = parser.parse_args()
if not os.path.exists(args.out_dir):
os.makedirs(args.out_dir)
args.datasets = args.datasets.split(',')
print('datasets:', args.datasets)
assert args.batch_size == 1, 'Only batch size 1 is supported'
if args.auto:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
kwargs = {'device_map': 'auto'} if args.auto else {}
tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
model = InternVLChatModel.from_pretrained(
args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit, **kwargs).eval()
if not args.load_in_8bit and not args.load_in_4bit and not args.auto:
model = model.cuda()
image_size = model.config.force_image_size or model.config.vision_config.image_size
use_thumbnail = model.config.use_thumbnail
total_params = sum(p.numel() for p in model.parameters()) / 1e9
if total_params > 20 or args.dynamic:
args.num_beams = 1
print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
else:
print(f'[test] total_params: {total_params}B')
print(f'[test] image_size: {image_size}')
print(f'[test] template: {model.config.template}')
print(f'[test] dynamic_image_size: {args.dynamic}')
print(f'[test] use_thumbnail: {use_thumbnail}')
print(f'[test] max_num: {args.max_num}')
model_id = '_'.join(args.checkpoint.split('/')[-2:])
evaluate_chat_model()