Commit 26e59280 authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
Pipeline #2674 failed with stages
in 0 seconds
import json
from open_clip.tokenizer import _tokenizer
from pycocoevalcap.eval import COCOEvalCap
from tqdm.auto import tqdm
def evaluate(model, dataloader, batch_size, device, transform, train_dataloader=None, num_workers=None, amp=True,
verbose=False):
coco = dataloader.dataset.coco
indexer = dataloader.dataset.ids
results = []
for idx, (img, _) in enumerate(tqdm(dataloader)):
n_samples = img.shape[0] # for last batch
idxs = [indexer[idx * batch_size + id] for id in range(n_samples)]
out = model.generate(img.to(device))
decoded = [_tokenizer.decode(i).split('<end_of_text>')[0].replace('<start_of_text>', '').strip() for i in
out.cpu().numpy()]
for image_id, caption in zip(idxs, decoded):
results.append({'image_id': image_id, 'caption': caption})
temp_res_file = 'temp_results.json'
with open(temp_res_file, 'w') as jf:
json.dump(results, jf)
coco_result = coco.loadRes(temp_res_file)
coco_eval = COCOEvalCap(coco, coco_result)
coco_eval.evaluate()
metrics = coco_eval.eval
# print output evaluation scores
for metric, score in metrics.items():
print(f'{metric}: {score:.3f}')
return metrics
"""
Code adapated from https://github.com/mlfoundations/open_clip/blob/main/src/training/zero_shot.py
Thanks to the authors of OpenCLIP
"""
from contextlib import suppress
import torch
import torch.nn.functional as F
from sklearn.metrics import balanced_accuracy_score, classification_report
from tqdm import tqdm
def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=True, cupl=False):
"""
This function returns zero-shot vectors for each class in order
to use it for zero-shot classification.
model:
CLIP-like model with `encode_text`
tokenizer:
text tokenizer, i.e. convert list of strings to torch.Tensor of integers
classnames: list of str
name of classes
templates: list of str
templates to use.
Returns
-------
torch.Tensor of shape (N,C) where N is the number
of templates, and C is the number of classes.
"""
autocast = torch.cuda.amp.autocast if amp else suppress
with torch.no_grad(), autocast():
zeroshot_weights = []
for classname in tqdm(classnames):
if cupl:
texts = templates[classname]
else:
texts = [template.format(c=classname) for template in templates]
texts = tokenizer(texts).to(device) # tokenize
class_embeddings = model.encode_text(texts)
class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
class_embedding /= class_embedding.norm()
zeroshot_weights.append(class_embedding)
zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
return zeroshot_weights
def accuracy(output, target, topk=(1,)):
"""
Compute top-k accuracy
output: torch.Tensor
shape (N, C) where N is the number of examples, C the number of classes.
these are the logits.
target: torch.Tensor
shape (N,) where N is the number of examples. Groundtruth class id of each example.
topk: tuple
which topk to compute, e.g., topk=(1,5) will compute top-1 and top-5 accuracies
Returns
-------
list of top-k accuracies in the same order as `topk`
"""
pred = output.topk(max(topk), 1, True, True)[1].t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
n = len(target)
return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) / n for k in topk]
def run_classification(model, classifier, dataloader, device, amp=True):
"""
Run zero-shot classifcation
model: torch.nn.Module
CLIP-like model with `encode_image` and `encode_text`
classifier: torch.Tensor
obtained from the function `zero_shot_classifier`
dataloader: torch.utils.data.Dataloader
Returns
-------
(pred, true) where
- pred (N, C) are the logits
- true (N,) are the actual classes
"""
autocast = torch.cuda.amp.autocast if amp else suppress
pred = []
true = []
nb = 0
with torch.no_grad():
for images, target in tqdm(dataloader):
images = images.to(device)
target = target.to(device)
with autocast():
# predict
image_features = model.encode_image(images)
image_features = F.normalize(image_features, dim=-1)
logits = 100. * image_features @ classifier
true.append(target.cpu())
pred.append(logits.float().cpu())
pred = torch.cat(pred)
true = torch.cat(true)
return pred, true
def average_precision_per_class(scores, targets):
"""
Compute average precision for each class
this metric is used for multi-label classification
see explanations here https://fangdahan.medium.com/calculate-mean-average-precision-map-for-multi-label-classification-b082679d31be
Code is adapted from https://github.com/pytorch/tnt/blob/master/torchnet/meter/meter.py, thanks to the authors of `tnt`.
Parameters
----------
scores: torch.Tensor
logits, of shape (N,C) where N is the number of examples, C the number of classes
targets: torch.Tensor
one-hot vectors of groundtruth targets (N, C), where N is the number of examples, C is the
number of classes
Returns
-------
torch.Tensor of shape (C,) of avereage precision for each class, where C is
the number of classes.
"""
ap = torch.zeros(scores.size(1))
rg = torch.arange(1, scores.size(0) + 1).float()
# compute average precision for each class
for k in range(scores.size(1)):
# sort scores
scores_k = scores[:, k]
targets_k = targets[:, k]
_, sortind = torch.sort(scores_k, 0, True)
truth = targets_k[sortind]
tp = truth.float().cumsum(0)
# compute precision curve
precision = tp.div(rg)
# compute average precision
ap[k] = precision[truth.bool()].sum() / max(float(truth.sum()), 1)
return ap
def evaluate(model, dataloader, tokenizer, classnames, templates, device, amp=True, verbose=False, cupl=False,
save_clf=None, load_clfs=[]):
"""
Run zero-shot classification and evaluate the metrics
Parameters
----------
model: torch.nn.Module
CLIP-like model with `encode_image` and `encode_text`
dataloader: torch.utils.data.Dataloader
tokenizer: text tokenizer
classnames: list of str
class names
templates: list of str
templates to use for zero-shot classification
device: cpu/cuda
amp: whether to use automatic mixed precision
verbose: whether to use verbose model
Returns
-------
dict of classification metrics
"""
if len(load_clfs) > 0:
n = len(load_clfs)
classifier = torch.load(load_clfs[0], map_location='cpu') / n
for i in range(1, n):
classifier = classifier + torch.load(load_clfs[i], map_location='cpu') / n
classifier = classifier.to(device)
else:
classifier = zero_shot_classifier(model, tokenizer, classnames, templates, device, cupl=cupl)
if save_clf is not None:
torch.save(classifier, save_clf)
# exit() - not sure if we want to exit here or not.
logits, target = run_classification(model, classifier, dataloader, device, amp=amp)
is_multilabel = (len(target.shape) == 2)
if is_multilabel:
if verbose:
print('Detected a multi-label classification dataset')
# Multiple labels per image, multiple classes on the dataset
ap_per_class = average_precision_per_class(logits, target)
if verbose:
for class_name, ap in zip(dataloader.dataset.classes, ap_per_class.tolist()):
print(f'Class: {class_name}, AveragePrecision: {ap}')
return {'mean_average_precision': ap_per_class.mean().item()}
else:
# Single label per image, multiple classes on the dataset
# just compute accuracy and mean_per_class_recall
pred = logits.argmax(axis=1)
# measure accuracy
if len(dataloader.dataset.classes) >= 5:
acc1, acc5 = accuracy(logits, target, topk=(1, 5))
else:
acc1, = accuracy(logits, target, topk=(1,))
acc5 = float('nan')
mean_per_class_recall = balanced_accuracy_score(target, pred)
if verbose:
print(classification_report(target, pred, digits=3))
return {'acc1': acc1, 'acc5': acc5, 'mean_per_class_recall': mean_per_class_recall}
from contextlib import suppress
import torch
import torch.nn.functional as F
from tqdm import tqdm
def evaluate(model, dataloader, tokenizer, device, amp=True, recall_k_list=[5]):
"""
Evaluate the model on the given dataset
Parameters
----------
model: torch.nn,Module
CLIP-like model with `encode_image` and `encode_text`
dataloader: torch.utils.data.Dataloader
dataloader to use for evaluation
tokenizer:
text tokenizer, i.e. convert list of strings to torch.Tensor of integers
device: cpu/cuda
amp: whether to use automatic mixed precision
recall_k_list: list of int
recall@k k's to use
Returns
-------
dict of retrieval metrics
"""
# list of batch of images embedding
batch_images_emb_list = []
# list of batch of text embedding
batch_texts_emb_list = []
# for each text, we collect the corresponding image index, as each image can have multiple corresponding texts
texts_image_index = []
dataloader = dataloader_with_indices(dataloader)
autocast = torch.cuda.amp.autocast if amp else suppress
for batch_images, batch_texts, inds in tqdm(dataloader):
batch_images = batch_images.to(device)
# tokenize all texts in the batch
batch_texts_tok = tokenizer([text for i, texts in enumerate(batch_texts) for text in texts]).to(device)
# store the index of image for each text
batch_texts_image_index = [ind for ind, texts in zip(inds, batch_texts) for text in texts]
# compute the embedding of images and texts
with torch.no_grad(), autocast():
batch_images_emb = F.normalize(model.encode_image(batch_images), dim=-1)
batch_texts_emb = F.normalize(model.encode_text(batch_texts_tok), dim=-1)
batch_images_emb_list.append(batch_images_emb.cpu())
batch_texts_emb_list.append(batch_texts_emb.cpu())
texts_image_index.extend(batch_texts_image_index)
batch_size = len(batch_images_emb_list[0])
# concatenate all embeddings
images_emb = torch.cat(batch_images_emb_list)
texts_emb = torch.cat(batch_texts_emb_list)
# get the score for each text and image pair
scores = texts_emb @ images_emb.t()
# construct a the positive pair matrix, which tells whether each text-image pair is a positive or not
positive_pairs = torch.zeros_like(scores, dtype=bool)
positive_pairs[torch.arange(len(scores)), texts_image_index] = True
metrics = {}
for recall_k in recall_k_list:
# Note that recall_at_k computes **actual** recall i.e. nb_true_positive/nb_positives, where the number
# of true positives, e.g. for text retrieval, is, for each image, the number of retrieved texts matching that image among the top-k.
# Also, the number of positives are the total number of texts matching the image in the dataset, as we have a set of captions
# for each image, that number will be greater than 1 for text retrieval.
# However, image/text retrieval recall@k, the way it is done in CLIP-like papers, is a bit different.
# recall@k, in CLIP-like papers, is, for each image, either 1 or 0. It is 1 if atleast one text matches the image among the top-k.
# so we can easily compute that using the actual recall, by checking whether there is at least one true positive,
# which would be the case if the recall is greater than 0. One we compute the recal for each image (or text), we average
# it over the dataset.
metrics[f'image_retrieval_recall@{recall_k}'] = (
batchify(recall_at_k, scores, positive_pairs, batch_size, device,
k=recall_k) > 0).float().mean().item()
metrics[f'text_retrieval_recall@{recall_k}'] = (
batchify(recall_at_k, scores.T, positive_pairs.T, batch_size, device,
k=recall_k) > 0).float().mean().item()
return metrics
def dataloader_with_indices(dataloader):
start = 0
for x, y in dataloader:
end = start + len(x)
inds = torch.arange(start, end)
yield x, y, inds
start = end
def recall_at_k(scores, positive_pairs, k):
"""
Compute the recall at k for each sample
:param scores: compability score between text and image embeddings (nb texts, nb images)
:param k: number of images to consider per text, for retrieval
:param positive_pairs: boolean matrix of positive pairs (nb texts, nb images)
:return: recall at k averaged over all texts
"""
nb_texts, nb_images = scores.shape
# for each text, sort according to image scores in decreasing order
topk_indices = torch.topk(scores, k, dim=1)[1]
# compute number of positives for each text
nb_positive = positive_pairs.sum(dim=1)
# nb_texts, k, nb_images
topk_indices_onehot = torch.nn.functional.one_hot(topk_indices, num_classes=nb_images)
# compute number of true positives
positive_pairs_reshaped = positive_pairs.view(nb_texts, 1, nb_images)
# a true positive means a positive among the topk
nb_true_positive = (topk_indices_onehot * positive_pairs_reshaped).sum(dim=(1, 2))
# compute recall at k
recall_at_k = (nb_true_positive / nb_positive)
return recall_at_k
def batchify(func, X, Y, batch_size, device, *args, **kwargs):
results = []
for start in range(0, len(X), batch_size):
end = start + batch_size
x = X[start:end].to(device)
y = Y[start:end].to(device)
result = func(x, y, *args, **kwargs).cpu()
results.append(result)
return torch.cat(results)
import open_clip
def get_model_collection_from_file(path):
return [l.strip().split(',') for l in open(path).readlines()]
model_collection = {
'openclip_base': [
('ViT-B-32-quickgelu', 'laion400m_e32'),
('ViT-B-32', 'laion2b_e16'),
('ViT-B-32', 'laion2b_s34b_b79k'),
('ViT-B-16', 'laion400m_e32'),
('ViT-B-16-plus-240', 'laion400m_e32'),
('ViT-L-14', 'laion400m_e32'),
('ViT-L-14', 'laion2b_s32b_b82k'),
('ViT-H-14', 'laion2b_s32b_b79k'),
('ViT-g-14', 'laion2b_s12b_b42k'),
],
'openclip_multilingual': [
('xlm-roberta-base-ViT-B-32', 'laion5b_s13b_b90k'),
('xlm-roberta-large-ViT-H-14', 'frozen_laion5b_s13b_b90k'),
],
'openclip_all': open_clip.list_pretrained(),
'openai': [
('ViT-B-32', 'openai'),
('ViT-B-16', 'openai'),
('ViT-L-14', 'openai'),
('ViT-L-14-336', 'openai'),
]
}
from typing import Union
import torch
from .internvl import load_internvl
from .japanese_clip import load_japanese_clip
from .open_clip import load_open_clip
# loading function must return (model, transform, tokenizer)
TYPE2FUNC = {
'open_clip': load_open_clip,
'ja_clip': load_japanese_clip,
'internvl': load_internvl,
}
MODEL_TYPES = list(TYPE2FUNC.keys())
def load_clip(
model_type: str,
model_name: str,
pretrained: str,
cache_dir: str,
device: Union[str, torch.device] = 'cuda'
):
assert model_type in MODEL_TYPES, f'model_type={model_type} is invalid!'
load_func = TYPE2FUNC[model_type]
return load_func(model_name=model_name, pretrained=pretrained, cache_dir=cache_dir, device=device)
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import os
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class InternVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
instantiate a vision encoder according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_channels (`int`, *optional*, defaults to 3):
Number of color channels in the input images (e.g., 3 for RGB).
patch_size (`int`, *optional*, defaults to 14):
The size (resolution) of each patch.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
qkv_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the queries and values in the self-attention layers.
hidden_size (`int`, *optional*, defaults to 3200):
Dimensionality of the encoder layers and the pooler layer.
num_attention_heads (`int`, *optional*, defaults to 25):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 12800):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
qk_normalization (`bool`, *optional*, defaults to `True`):
Whether to normalize the queries and keys in the self-attention layers.
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of hidden layers in the Transformer encoder.
use_flash_attn (`bool`, *optional*, defaults to `True`):
Whether to use flash attention mechanism.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the layer normalization layers.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
drop_path_rate (`float`, *optional*, defaults to 0.0):
Dropout rate for stochastic depth.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 0.1):
A factor for layer scale.
"""
model_type = 'intern_vit_6b'
def __init__(
self,
num_channels=3,
patch_size=14,
image_size=224,
qkv_bias=False,
hidden_size=3200,
num_attention_heads=25,
intermediate_size=12800,
qk_normalization=True,
num_hidden_layers=48,
use_flash_attn=True,
hidden_act='gelu',
layer_norm_eps=1e-6,
dropout=0.0,
drop_path_rate=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=0.1,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.dropout = dropout
self.drop_path_rate = drop_path_rate
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.qkv_bias = qkv_bias
self.qk_normalization = qk_normalization
self.use_flash_attn = use_flash_attn
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if 'vision_config' in config_dict:
config_dict = config_dict['vision_config']
if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
)
return cls.from_dict(config_dict, **kwargs)
import torch
import torch.nn as nn
from einops import rearrange
try: # v1
from flash_attn.flash_attn_interface import \
flash_attn_unpadded_qkvpacked_func
except: # v2
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import pad_input, unpad_input
class FlashAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
super().__init__()
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
max_s=None, need_weights=False):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
if unpadded: (nnz, 3, h, d)
key_padding_mask: a bool tensor of shape (B, S)
"""
assert not need_weights
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.is_cuda
if cu_seqlens is None:
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
if key_padding_mask is None:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads)
else:
assert max_s is not None
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
return output, None
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from einops import rearrange
from timm.models.layers import DropPath
from torch import nn
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (BaseModelOutput,
BaseModelOutputWithPooling)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from .configuration_intern_vit import InternVisionConfig
try:
from .flash_attention import FlashAttention
has_flash_attn = True
except:
print('FlashAttention is not installed.')
has_flash_attn = False
logger = logging.get_logger(__name__)
class InternRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
try:
from apex.normalization import FusedRMSNorm
InternRMSNorm = FusedRMSNorm # noqa
logger.info('Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm')
except ImportError:
# using the normal InternRMSNorm
pass
except Exception:
logger.warning('discovered apex but it failed to load, falling back to InternRMSNorm')
pass
class InternVisionEmbeddings(nn.Module):
def __init__(self, config: InternVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(
torch.randn(1, 1, self.embed_dim),
)
self.patch_embedding = nn.Conv2d(
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding.to(target_dtype)
return embeddings
class InternAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: InternVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.use_flash_attn = config.use_flash_attn and has_flash_attn
if config.use_flash_attn and not has_flash_attn:
print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f'embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:'
f' {self.num_heads}).'
)
self.scale = self.head_dim ** -0.5
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
self.attn_drop = nn.Dropout(config.attention_dropout)
self.proj_drop = nn.Dropout(config.dropout)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
if self.use_flash_attn:
self.inner_attn = FlashAttention(attention_dropout=config.attention_dropout)
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
def _naive_attn(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
if self.qk_normalization:
B_, H_, N_, D_ = q.shape
q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
attn = ((q * self.scale) @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
qkv = self.qkv(x)
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
if self.qk_normalization:
q, k, v = qkv.unbind(2)
q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
qkv = torch.stack([q, k, v], dim=2)
context, _ = self.inner_attn(
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=False
)
outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
outs = self.proj_drop(outs)
return outs
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
x = self._naive_attn(hidden_states) if not self.use_flash_attn else self._flash_attn(hidden_states)
return x
class InternMLP(nn.Module):
def __init__(self, config: InternVisionConfig):
super().__init__()
self.config = config
self.act = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
class InternVisionEncoderLayer(nn.Module):
def __init__(self, config: InternVisionConfig, drop_path_rate: float):
super().__init__()
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.attn = InternAttention(config)
self.mlp = InternMLP(config)
self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
self.drop_path1 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.drop_path2 = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
def forward(
self,
hidden_states: torch.Tensor,
) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor], Optional[Tuple[torch.FloatTensor]]]:
"""
Args:
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
"""
hidden_states = hidden_states + self.drop_path1(self.attn(self.norm1(hidden_states)) * self.ls1)
hidden_states = hidden_states + self.drop_path2(self.mlp(self.norm2(hidden_states)) * self.ls2)
return hidden_states
class InternVisionEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`InternEncoderLayer`].
Args:
config (`InternConfig`):
The corresponding vision configuration for the `InternEncoder`.
"""
def __init__(self, config: InternVisionConfig):
super().__init__()
self.config = config
# stochastic depth decay rule
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
self.layers = nn.ModuleList([
InternVisionEncoderLayer(config, dpr[idx]) for idx in range(config.num_hidden_layers)])
self.gradient_checkpointing = True
def forward(
self,
inputs_embeds,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embedded representation of the inputs. Should be float, not int tokens.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
hidden_states = inputs_embeds
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = torch.utils.checkpoint.checkpoint(
encoder_layer,
hidden_states)
else:
layer_outputs = encoder_layer(
hidden_states,
)
hidden_states = layer_outputs
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states
)
class InternVisionModel(PreTrainedModel):
main_input_name = 'pixel_values'
config_class = InternVisionConfig
def __init__(self, config: InternVisionConfig):
super().__init__(config)
self.config = config
self.embeddings = InternVisionEmbeddings(config)
self.encoder = InternVisionEncoder(config)
def resize_pos_embeddings(self, old_size, new_size, patch_size):
pos_emb = self.embeddings.position_embedding
_, num_positions, embed_dim = pos_emb.shape
cls_emb = pos_emb[:, :1, :]
pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size, old_size // patch_size, -1).permute(0, 3, 1, 2)
pos_emb = F.interpolate(pos_emb.float(), size=new_size // patch_size, mode='bicubic', align_corners=False)
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
self.embeddings.position_embedding = nn.Parameter(pos_emb)
logger.info('Resized position embeddings from {} to {}'.format(old_size, new_size))
def get_input_embeddings(self):
return self.embeddings
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
pixel_embeds: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None and pixel_embeds is None:
raise ValueError('You have to specify pixel_values or pixel_embeds')
if pixel_embeds is not None:
hidden_states = pixel_embeds
else:
if len(pixel_values.shape) == 4:
hidden_states = self.embeddings(pixel_values)
else:
raise ValueError(f'wrong pixel_values size: {pixel_values.shape}')
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs.last_hidden_state
pooled_output = last_hidden_state[:, 0, :]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from .internvl_c_pytorch import load_internvl_c_pytorch
from .internvl_huggingface import (load_internvl_c_huggingface,
load_internvl_g_huggingface)
def load_internvl(model_name, pretrained, cache_dir, device):
if model_name == 'internvl_c_classification':
return load_internvl_c_pytorch(pretrained, device, 'classification')
elif model_name == 'internvl_c_retrieval':
return load_internvl_c_pytorch(pretrained, device, 'retrieval')
elif model_name == 'internvl_c_classification_hf':
return load_internvl_c_huggingface(pretrained, device, 'classification')
elif model_name == 'internvl_c_retrieval_hf':
return load_internvl_c_huggingface(pretrained, device, 'retrieval')
elif model_name == 'internvl_g_classification_hf':
return load_internvl_g_huggingface(pretrained, device, 'classification')
elif model_name == 'internvl_g_retrieval_hf':
return load_internvl_g_huggingface(pretrained, device, 'retrieval')
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import os
import torch
import torchvision.transforms as T
from torch import nn
from torchvision.transforms import InterpolationMode
from transformers import LlamaTokenizer
from .internvl_c import InternVL_C
try:
from .flash_attention import FlashAttention
except:
print('FlashAttention is not installed.')
class InternVLTokenizer(nn.Module):
def __init__(self, model_path):
super(InternVLTokenizer, self).__init__()
self.tokenizer = LlamaTokenizer.from_pretrained(model_path)
self.tokenizer.pad_token = ' ' # allow padding
self.tokenizer.add_eos_token = True
def forward(self, text, prefix='summarize:'):
if type(text) == str:
text = prefix + text
elif type(text) == list:
text = [prefix + item for item in text]
text = self.tokenizer(text, return_tensors='pt', max_length=80, truncation=True, padding=True).input_ids
return text
def build_transform(task, image_size=224, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
if task == 'retrieval':
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=mean, std=std)])
else:
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
T.CenterCrop(image_size),
T.ToTensor(),
T.Normalize(mean=mean, std=std)])
return transform
def get_model_and_transform(task, image_size, device):
llm_path = os.path.split(os.path.realpath(__file__))[0]
llm_path = os.path.join(llm_path, 'chinese_alpaca_lora_7b')
model = InternVL_C(img_size=image_size, layerscale_force_fp32=True, llm_path=llm_path)
model = model.to(torch.float16).to(device)
transform = build_transform(task, image_size)
return model, transform
def load_internvl_c_pytorch(ckpt_path, device, task, image_size=224):
llm_path = os.path.split(os.path.realpath(__file__))[0]
llm_path = os.path.join(llm_path, 'chinese_alpaca_lora_7b')
tokenizer = InternVLTokenizer(llm_path)
model, transform = get_model_and_transform(task=task, image_size=image_size, device=device)
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt, strict=False)
return model, transform, tokenizer
{
"architectures": [
"LlamaForCausalLM"
],
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 11008,
"max_position_embeddings": 2048,
"max_sequence_length": 2048,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"pad_token_id": 0,
"rms_norm_eps": 1e-06,
"tie_word_embeddings": false,
"torch_dtype": "float16",
"transformers_version": "4.28.0.dev0",
"use_cache": true,
"vocab_size": 49954
}
{
"_from_model_config": true,
"bos_token_id": 1,
"eos_token_id": 2,
"pad_token_id": 0,
"transformers_version": "4.28.0.dev0"
}
{
"metadata": {
"total_size": 13770997760
},
"weight_map": {
"lm_head.weight": "pytorch_model-00002-of-00002.bin",
"model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.0.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.0.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.0.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.0.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.0.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.0.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.0.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.0.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.0.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.0.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.1.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.1.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.1.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.1.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.1.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.1.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.1.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.1.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.1.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.1.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.10.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.10.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.10.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.10.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.10.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.10.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.10.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.10.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.10.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.10.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.11.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.11.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.11.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.11.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.11.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.11.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.11.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.11.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.11.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.11.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.12.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.12.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.12.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.12.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.12.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.12.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.12.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.12.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.12.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.12.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.13.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.13.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.13.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.13.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.13.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.13.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.13.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.13.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.13.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.13.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.14.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.14.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.14.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.14.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.14.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.14.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.14.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.14.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.14.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.14.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.15.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.15.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.15.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.15.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.15.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.15.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.15.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.15.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.15.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.15.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.16.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.16.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.16.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.16.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.16.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.16.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.16.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.16.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.16.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.16.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.17.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.17.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.17.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.17.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.17.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.17.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.17.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.17.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.17.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.17.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.18.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.18.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.18.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.18.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.18.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.18.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.18.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.18.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.18.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.18.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.19.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.19.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.19.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.19.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.19.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.19.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.19.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.19.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.19.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.19.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.2.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.2.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.2.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.2.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.2.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.2.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.2.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.2.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.2.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.2.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.20.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.20.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.20.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.20.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.20.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.20.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.20.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.20.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.20.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.20.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.21.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.21.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.21.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.21.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.21.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.21.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.21.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.21.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.21.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.21.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.22.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.22.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.22.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.22.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.22.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.22.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.22.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.22.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.22.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.22.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.23.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.23.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.23.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.23.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.23.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.23.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.23.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.23.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.23.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.23.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.24.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.24.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.24.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.24.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.24.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.24.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.24.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.24.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.24.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
"model.layers.24.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.25.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.25.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.25.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.25.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.25.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.25.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.25.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.25.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.25.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
"model.layers.25.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.26.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.26.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.26.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.26.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.26.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.26.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.26.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.26.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.26.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
"model.layers.26.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.27.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.27.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.27.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.27.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.27.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.27.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.27.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.27.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.27.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
"model.layers.27.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.28.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.28.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.28.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.28.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.28.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.28.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.28.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.28.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.28.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
"model.layers.28.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.29.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.29.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.29.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.29.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.29.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.29.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.29.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.29.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.29.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
"model.layers.29.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.3.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.3.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.3.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.3.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.3.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.3.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.3.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.3.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.3.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.3.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.30.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.30.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.30.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.30.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.30.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.30.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.30.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.30.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.30.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
"model.layers.30.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.31.input_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.31.mlp.down_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.31.mlp.gate_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.31.mlp.up_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.31.post_attention_layernorm.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.31.self_attn.k_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.31.self_attn.o_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.31.self_attn.q_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.31.self_attn.rotary_emb.inv_freq": "pytorch_model-00002-of-00002.bin",
"model.layers.31.self_attn.v_proj.weight": "pytorch_model-00002-of-00002.bin",
"model.layers.4.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.4.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.4.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.4.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.4.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.4.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.4.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.4.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.4.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.4.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.5.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.5.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.5.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.5.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.5.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.5.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.5.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.5.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.5.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.5.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.6.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.6.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.6.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.6.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.6.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.6.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.6.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.6.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.6.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.6.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.7.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.7.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.7.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.7.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.7.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.7.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.7.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.7.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.7.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.7.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.8.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.8.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.8.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.8.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.8.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.8.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.8.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.8.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.8.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.8.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.9.input_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.9.mlp.down_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.9.mlp.gate_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.9.mlp.up_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.9.post_attention_layernorm.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.9.self_attn.k_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.9.self_attn.o_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.9.self_attn.q_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.layers.9.self_attn.rotary_emb.inv_freq": "pytorch_model-00001-of-00002.bin",
"model.layers.9.self_attn.v_proj.weight": "pytorch_model-00001-of-00002.bin",
"model.norm.weight": "pytorch_model-00002-of-00002.bin"
}
}
{
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "[PAD]",
"unk_token": "<unk>"
}
{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": {
"__type": "AddedToken",
"content": "<s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"clean_up_tokenization_spaces": false,
"eos_token": {
"__type": "AddedToken",
"content": "</s>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
},
"model_max_length": 1000000000000000019884624838656,
"pad_token": null,
"sp_model_kwargs": {},
"special_tokens_map_file": "chinese_alpaca_lora_7b/special_tokens_map.json",
"tokenizer_class": "LlamaTokenizer",
"unk_token": {
"__type": "AddedToken",
"content": "<unk>",
"lstrip": false,
"normalized": true,
"rstrip": false,
"single_word": false
}
}
# https://github.com/Dao-AILab/flash-attention/blob/v0.2.8/flash_attn/flash_attention.py
import torch
import torch.nn as nn
from einops import rearrange
try: # v1
from flash_attn.flash_attn_interface import \
flash_attn_unpadded_qkvpacked_func
except: # v2
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import pad_input, unpad_input
class FlashAttention(nn.Module):
"""Implement the scaled dot product attention with softmax.
Arguments
---------
softmax_scale: The temperature to use for the softmax attention.
(default: 1/sqrt(d_keys) where d_keys is computed at
runtime)
attention_dropout: The dropout rate to apply to the attention
(default: 0.0)
"""
def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
super().__init__()
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
max_s=None, need_weights=False):
"""Implements the multihead softmax attention.
Arguments
---------
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
if unpadded: (nnz, 3, h, d)
key_padding_mask: a bool tensor of shape (B, S)
"""
assert not need_weights
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.is_cuda
if cu_seqlens is None:
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
if key_padding_mask is None:
qkv = rearrange(qkv, 'b s ... -> (b s) ...')
max_s = seqlen
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
device=qkv.device)
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
output_unpad = flash_attn_unpadded_qkvpacked_func(
x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
indices, batch_size, seqlen),
'b s (h d) -> b s h d', h=nheads)
else:
assert max_s is not None
output = flash_attn_unpadded_qkvpacked_func(
qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
softmax_scale=self.softmax_scale, causal=causal
)
return output, None
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange
from timm.models.layers import DropPath, to_2tuple
from torch import nn
from transformers import LlamaConfig, LlamaForCausalLM
try:
from .flash_attention import FlashAttention
has_flash_attn = True
except:
print('FlashAttention is not installed.')
has_flash_attn = False
class CrossAttention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
proj_drop=0., attn_head_dim=None, out_dim=None):
super().__init__()
if out_dim is None:
out_dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim ** -0.5
assert all_head_dim == dim
self.q = nn.Linear(dim, all_head_dim, bias=False)
self.k = nn.Linear(dim, all_head_dim, bias=False)
self.v = nn.Linear(dim, all_head_dim, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.k_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, out_dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, k=None, v=None):
B, N, C = x.shape
N_k = k.shape[1]
N_v = v.shape[1]
q_bias, k_bias, v_bias = None, None, None
if self.q_bias is not None:
q_bias = self.q_bias
k_bias = self.k_bias
v_bias = self.v_bias
q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1)) # (B, N_head, N_q, N_k)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentiveBlock(nn.Module):
def __init__(self, dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, attn_head_dim=None, out_dim=None):
super().__init__()
self.norm1_q = norm_layer(dim)
self.norm1_k = norm_layer(dim)
self.norm1_v = norm_layer(dim)
self.cross_attn = CrossAttention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
proj_drop=drop, attn_head_dim=attn_head_dim, out_dim=out_dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x_q, x_kv, pos_q, pos_k, bool_masked_pos, rel_pos_bias=None):
x_q = self.norm1_q(x_q + pos_q)
x_k = self.norm1_k(x_kv + pos_k)
x_v = self.norm1_v(x_kv)
x = self.cross_attn(x_q, k=x_k, v=x_v)
return x
class AttentionPoolingBlock(AttentiveBlock):
def forward(self, x):
x_q = x.mean(1, keepdim=True)
x_kv, pos_q, pos_k = x, 0, 0
x = super().forward(x_q, x_kv, pos_q, pos_k, bool_masked_pos=None, rel_pos_bias=None)
x = x.squeeze(1)
return x
class RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
try:
from apex.normalization import FusedRMSNorm
RMSNorm = FusedRMSNorm # noqa
print('Discovered apex.normalization.FusedRMSNorm - will use it instead of RMSNorm')
except ImportError:
# using the normal RMSNorm
pass
except Exception:
print('discovered apex but it failed to load, falling back to RMSNorm')
pass
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False, force_fp32=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
self.force_fp32 = force_fp32
@torch.cuda.amp.autocast(enabled=False)
def forward(self, x):
if self.force_fp32:
output_type = x.dtype
out = x.float().mul_(self.gamma.float()) if self.inplace else x.float() * self.gamma.float()
return out.to(dtype=output_type)
else:
out = x.mul_(self.gamma) if self.inplace else x * self.gamma
return out
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., use_flash_attn=False,
causal=False, norm_layer=nn.LayerNorm, qk_normalization=False):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.use_flash_attn = use_flash_attn
if use_flash_attn:
self.causal = causal
self.inner_attn = FlashAttention(attention_dropout=attn_drop)
self.qk_normalization = qk_normalization
self.q_norm = norm_layer(dim) if qk_normalization else nn.Identity()
self.k_norm = norm_layer(dim) if qk_normalization else nn.Identity()
def _naive_attn(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
if self.qk_normalization:
B_, H_, N_, D_ = q.shape
q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(B_, N_, H_, D_).transpose(1, 2)
attn = ((q * self.scale) @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
qkv = self.qkv(x)
qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
if self.qk_normalization:
q, k, v = qkv.unbind(2)
q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
qkv = torch.stack([q, k, v], dim=2)
context, _ = self.inner_attn(
qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
)
outs = self.proj(rearrange(context, 'b s h d -> b s (h d)'))
outs = self.proj_drop(outs)
return outs
def forward(self, x):
x = self._naive_attn(x) if not self.use_flash_attn else self._flash_attn(x)
return x
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
bias=True, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class Block(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_flash_attn=False, with_cp=False,
qk_normalization=False, layerscale_force_fp32=False):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
use_flash_attn=use_flash_attn, causal=False, norm_layer=norm_layer,
qk_normalization=qk_normalization)
self.ls1 = LayerScale(dim, init_values=init_values,
force_fp32=layerscale_force_fp32) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.ls2 = LayerScale(dim, init_values=init_values,
force_fp32=layerscale_force_fp32) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.with_cp = with_cp
def forward(self, x):
def _inner_forward(x):
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
if self.with_cp:
return checkpoint.checkpoint(_inner_forward, x)
else:
return _inner_forward(x)
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x, **kwargs):
x = self.proj(x)
_, _, H, W = x.shape
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x, H, W
class InternVL_C(nn.Module):
def __init__(self, in_chans=3, patch_size=14, img_size=224, qkv_bias=False, drop_path_rate=0.0,
embed_dim=3200, num_heads=25, mlp_ratio=4, init_values=0.1, qk_normalization=True, depth=48,
use_flash_attn=True, with_cp=True, layerscale_force_fp32=False, context_length: int = 80,
transformer_width=4096, llm_path=None, attn_pool_num_heads=16, clip_embed_dim=768):
super().__init__()
use_flash_attn = use_flash_attn and has_flash_attn
if use_flash_attn and not has_flash_attn:
print('Warning: Flash Attention is not available, use_flash_attn is set to False.')
self.use_flash_attn = use_flash_attn
self.context_length = context_length
self.embed_dim = embed_dim
self.transformer_width = transformer_width
""" text encoder of InternVL """
llama_config = LlamaConfig.from_pretrained(llm_path)
model = LlamaForCausalLM(llama_config)
self.transformer = model.model
self.transformer.gradient_checkpointing = True
self.text_projection = nn.Parameter(torch.empty(transformer_width, clip_embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
""" image encoder of InternVL """
norm_layer_for_blocks = partial(RMSNorm, eps=1e-6)
self.norm_layer_for_blocks = norm_layer_for_blocks
self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
num_patches = self.patch_embed.num_patches
self.num_patches = num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
self.blocks = nn.ModuleList([
Block(embed_dim, num_heads, mlp_ratio, qkv_bias=qkv_bias,
norm_layer=norm_layer_for_blocks,
drop_path=dpr[i], init_values=init_values, attn_drop=0.,
use_flash_attn=use_flash_attn,
with_cp=with_cp,
qk_normalization=qk_normalization,
layerscale_force_fp32=layerscale_force_fp32)
for i in range(depth)])
self.clip_projector = AttentionPoolingBlock(
dim=embed_dim, num_heads=attn_pool_num_heads, qkv_bias=True, qk_scale=None,
drop=0., attn_drop=0., norm_layer=partial(nn.LayerNorm, eps=1e-5), out_dim=clip_embed_dim)
@property
def dtype(self):
return self.patch_embed.proj.weight.dtype
def forward_features(self, x):
x, _, _ = self.patch_embed(x.type(self.dtype))
batch_size, seq_len, _ = x.size()
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
for idx, blk in enumerate(self.blocks):
x = blk(x)
return x
def encode_image(self, image):
x = self.forward_features(image)
x = self.clip_projector(x)
return x
def encode_text(self, text):
text_key_padding_mask = text > 0
x = self.transformer(input_ids=text, attention_mask=text_key_padding_mask).last_hidden_state
x = x[torch.arange(x.shape[0]), text_key_padding_mask.sum(1) - 1]
x = x @ self.text_projection
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
return logits_per_image, logits_per_text
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from transformers import LlamaTokenizer
from .configuration_intern_vit import InternVisionConfig
from .configuration_internvl import InternVLConfig
from .modeling_intern_vit import InternVisionModel
from .modeling_internvl import InternVL_C, InternVL_G, InternVLModel
__all__ = ['InternVisionConfig', 'InternVisionModel', 'InternVLConfig',
'InternVLModel', 'InternVL_C', 'InternVL_G']
# Prefix the text "summarize:"
class InternVLTokenizer(nn.Module):
def __init__(self, model_path):
super(InternVLTokenizer, self).__init__()
self.tokenizer = LlamaTokenizer.from_pretrained(model_path)
self.tokenizer.pad_token = ' ' # allow padding
self.tokenizer.add_eos_token = True
def forward(self, text, prefix='summarize:'):
if type(text) == str:
text = prefix + text
elif type(text) == list:
text = [prefix + item for item in text]
text = self.tokenizer(text, return_tensors='pt', max_length=80, truncation=True, padding='max_length').input_ids
return text
def build_transform(task, image_size=224, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
if task == 'retrieval':
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=mean, std=std)])
else:
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
T.CenterCrop(image_size),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
return transform
def load_internvl_c_huggingface(ckpt_path, device, task):
model = InternVL_C.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
if model.config.use_backbone_lora:
model.vision_model.merge_and_unload()
model.vision_model = model.vision_model.model
if model.config.use_qllama_lora:
model.qllama.merge_and_unload()
model.qllama = model.qllama.model
if model.config.force_image_size is not None:
image_size = model.config.force_image_size
else:
image_size = model.config.vision_config.image_size
transform = build_transform(task, image_size)
tokenizer = InternVLTokenizer(ckpt_path)
return model, transform, tokenizer
def load_internvl_g_huggingface(ckpt_path, device, task):
model = InternVL_G.from_pretrained(ckpt_path, torch_dtype=torch.float16).to(device)
if model.config.use_backbone_lora:
model.vision_model.merge_and_unload()
model.vision_model = model.vision_model.model
if model.config.use_qllama_lora:
model.qllama.merge_and_unload()
model.qllama = model.qllama.model
if model.config.force_image_size is not None:
image_size = model.config.force_image_size
else:
image_size = model.config.vision_config.image_size
transform = build_transform(task, image_size)
tokenizer = InternVLTokenizer(ckpt_path)
return model, transform, tokenizer
# --------------------------------------------------------
# InternVL
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import os
from typing import Union
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
logger = logging.get_logger(__name__)
class InternVisionConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
instantiate a vision encoder according to the specified arguments, defining the model architecture.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
num_channels (`int`, *optional*, defaults to 3):
Number of color channels in the input images (e.g., 3 for RGB).
patch_size (`int`, *optional*, defaults to 14):
The size (resolution) of each patch.
image_size (`int`, *optional*, defaults to 224):
The size (resolution) of each image.
qkv_bias (`bool`, *optional*, defaults to `False`):
Whether to add a bias to the queries and values in the self-attention layers.
hidden_size (`int`, *optional*, defaults to 3200):
Dimensionality of the encoder layers and the pooler layer.
num_attention_heads (`int`, *optional*, defaults to 25):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (`int`, *optional*, defaults to 12800):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
qk_normalization (`bool`, *optional*, defaults to `True`):
Whether to normalize the queries and keys in the self-attention layers.
num_hidden_layers (`int`, *optional*, defaults to 48):
Number of hidden layers in the Transformer encoder.
use_flash_attn (`bool`, *optional*, defaults to `True`):
Whether to use flash attention mechanism.
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
The epsilon used by the layer normalization layers.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
drop_path_rate (`float`, *optional*, defaults to 0.0):
Dropout rate for stochastic depth.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_factor (`float`, *optional*, defaults to 0.1):
A factor for layer scale.
"""
model_type = 'intern_vit_6b'
def __init__(
self,
num_channels=3,
patch_size=14,
image_size=224,
qkv_bias=False,
hidden_size=3200,
num_attention_heads=25,
intermediate_size=12800,
qk_normalization=True,
num_hidden_layers=48,
use_flash_attn=True,
hidden_act='gelu',
layer_norm_eps=1e-6,
dropout=0.0,
drop_path_rate=0.0,
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=0.1,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.dropout = dropout
self.drop_path_rate = drop_path_rate
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.image_size = image_size
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.qkv_bias = qkv_bias
self.qk_normalization = qk_normalization
self.use_flash_attn = use_flash_attn
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if 'vision_config' in config_dict:
config_dict = config_dict['vision_config']
if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
)
return cls.from_dict(config_dict, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment