"vscode:/vscode.git/clone" did not exist on "affbaa609d4d3afee36ff08866ecc82dcc0d36de"
Commit eea86168 authored by lvskiller's avatar lvskiller
Browse files

monkey_model

parent 80836a45
{
"architectures": [
"MonkeyLMHeadModel"
],
"attn_dropout_prob": 0.0,
"auto_map": {
"AutoConfig": "configuration_qwen.QWenConfig",
"AutoModelForCausalLM": "modeling_monkey.MonkeyLMHeadModel"
},
"bf16": true,
"emb_dropout_prob": 0.0,
"fp16": false,
"fp32": false,
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 22016,
"kv_channels": 128,
"layer_norm_epsilon": 1e-06,
"max_position_embeddings": 8192,
"model_type": "monkey",
"no_bias": true,
"num_attention_heads": 32,
"num_hidden_layers": 32,
"onnx_safe": null,
"rotary_emb_base": 10000,
"rotary_pct": 1.0,
"scale_attn_weights": true,
"seq_length": 2048,
"tie_word_embeddings": false,
"tokenizer_type": "QWenTokenizer",
"torch_dtype": "bfloat16",
"transformers_version": "4.32.0",
"use_cache": false,
"use_dynamic_ntk": true,
"use_flash_attn": false,
"use_logn_attn": true,
"visual": {
"heads": 16,
"image_size": 896,
"image_start_id": 151857,
"layers": 48,
"mlp_ratio": 4.9231,
"output_dim": 4096,
"patch_size": 14,
"width": 1664,
"lora_repeat_num":0
},
"vocab_size": 151936
}
\ No newline at end of file
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from transformers import PretrainedConfig
class MonkeyConfig(PretrainedConfig):
model_type = "monkey"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
emb_dropout_prob=0.0,
attn_dropout_prob=0.0,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
max_position_embeddings=8192,
scale_attn_weights=True,
use_cache=True,
bf16=False,
fp16=False,
fp32=False,
kv_channels=128,
rotary_pct=1.0,
rotary_emb_base=10000,
use_dynamic_ntk=True,
use_logn_attn=True,
use_flash_attn="auto",
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.emb_dropout_prob = emb_dropout_prob
self.attn_dropout_prob = attn_dropout_prob
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.bf16 = bf16
self.fp16 = fp16
self.fp32 = fp32
self.kv_channels = kv_channels
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.no_bias = no_bias
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs
)
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from transformers import PretrainedConfig
class QwenConfig(PretrainedConfig):
model_type = "monkey"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
emb_dropout_prob=0.0,
attn_dropout_prob=0.0,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
max_position_embeddings=8192,
scale_attn_weights=True,
use_cache=True,
bf16=False,
fp16=False,
fp32=False,
kv_channels=128,
rotary_pct=1.0,
rotary_emb_base=10000,
use_dynamic_ntk=True,
use_logn_attn=True,
use_flash_attn="auto",
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.emb_dropout_prob = emb_dropout_prob
self.attn_dropout_prob = attn_dropout_prob
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.bf16 = bf16
self.fp16 = fp16
self.fp32 = fp32
self.kv_channels = kv_channels
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.no_bias = no_bias
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs
)
import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
try:
from einops import rearrange
except ImportError:
rearrange = None
from torch import nn
from monkey_model.modeling_qwen import QWenModel,QWenPreTrainedModel,QWenLMHeadModel
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
logger = logging.get_logger(__name__)
class MonkeyModel(QWenModel):
def __init__(self, config):
super().__init__(config)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
assert (bos_pos[0] == eos_pos[0]).all()
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
images = []
for i, a, b in img_pos:
image = input_ids[i][a + 1 : b - 1].tolist()
image = image[ : image.index(self.config.visual['image_start_id'] + 2)]
images.append(bytes(image).decode('utf-8'))
windows,images_448 = self.visual.encode(images)
patch_list = []
lora_idx = 0
for col in windows:
for image_patch in col:
patch_list.append(self.visual(image_patch,idx=lora_idx))
lora_idx += 1
global_feat = self.visual(images_448)
local_feat = torch.cat(patch_list,dim=1)
images = torch.cat([local_feat,global_feat],dim=1)
assert images.shape[0] == len(images)
else:
images = None
return super().forward(input_ids,
past_key_values,
attention_mask,
token_type_ids,
position_ids,
head_mask,inputs_embeds,
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
images)
class MonkeyLMHeadModel(QWenLMHeadModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
def __init__(self, config):
super().__init__(config)
assert (
config.bf16 + config.fp16 + config.fp32 <= 1
), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
if autoset_precision:
if SUPPORT_BF16:
logger.warn(
"The model is automatically converting to bf16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.bf16 = True
elif SUPPORT_FP16:
logger.warn(
"The model is automatically converting to fp16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.fp16 = True
else:
config.fp32 = True
if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
if config.fp32:
if SUPPORT_BF16:
logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
elif SUPPORT_FP16:
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
self.transformer = MonkeyModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.bf16:
self.transformer.bfloat16()
self.lm_head.bfloat16()
if config.fp16:
self.transformer.half()
self.lm_head.half()
self.post_init()
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
try:
from einops import rearrange
except ImportError:
rearrange = None
from torch import nn
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
from .configuration_qwen import QWenConfig
from .qwen_generation_utils import (
HistoryType,
make_context,
decode_tokens,
get_stop_words_ids,
StopWordsLogitsProcessor,
)
from .visual import VisionTransformer
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "qwen"
_CONFIG_FOR_DOC = "QWenConfig"
QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
_ERROR_BAD_CHAT_FORMAT = """\
We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
"""
_SENTINEL = object()
_ERROR_STREAM_IN_CHAT = """\
Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
"""
apply_rotary_emb_func = None
rms_norm = None
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class QWenAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
self.seq_length = config.seq_length
self.hidden_size = config.hidden_size
self.split_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.scale_attn_weights = True
self.projection_size = config.kv_channels * config.num_attention_heads
assert self.projection_size % config.num_attention_heads == 0
self.hidden_size_per_attention_head = (
self.projection_size // config.num_attention_heads
)
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
self.c_proj = nn.Linear(
config.hidden_size, self.projection_size, bias=not config.no_bias
)
self.is_fp32 = not (config.bf16 or config.fp16)
self.bf16 = config.bf16
self.use_dynamic_ntk = config.use_dynamic_ntk
self.use_logn_attn = config.use_logn_attn
logn_list = [
math.log(i, self.seq_length) if i > self.seq_length else 1
for i in range(1, 32768)
]
self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
attn_weights = attn_weights / torch.full(
[],
value.size(-1) ** 0.5,
dtype=attn_weights.dtype,
device=attn_weights.device,
)
query_length, key_length = query.size(-2), key.size(-2)
# causal_mask = self.bias[
# :, :, key_length - query_length : key_length, :key_length
# ]
# mask_value = torch.finfo(attn_weights.dtype).min
# mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
# attn_weights.device
# )
# attn_weights = torch.where(
# causal_mask, attn_weights.to(attn_weights.dtype), mask_value
# )
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2)
return attn_output, attn_weights
def _upcast_and_reordered_attn(
self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
):
bsz, num_heads, q_seq_len, dk = query.size()
_, _, k_seq_len, _ = key.size()
attn_weights = torch.empty(
bsz * num_heads,
q_seq_len,
k_seq_len,
dtype=torch.float32,
device=query.device,
)
scale_factor = 1.0
if self.scale_attn_weights:
scale_factor /= float(value.size(-1)) ** 0.5
with autocast(enabled=False):
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
-1, dk, k_seq_len
)
attn_weights = torch.baddbmm(
attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = registered_causal_mask[
:, :, key_length - query_length : key_length, :key_length
]
mask_value = torch.finfo(attn_weights.dtype).min
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
attn_weights.device
)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if attn_weights.dtype != torch.float32:
raise RuntimeError(
"Error with upcasting, attn_weights does not have dtype torch.float32"
)
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor
def _merge_heads(self, tensor, num_heads, attn_head_size):
tensor = tensor.contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb: Optional[List[torch.Tensor]] = None,
registered_causal_mask: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
):
mixed_x_layer = self.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
if rotary_pos_emb is not None:
cur_len = query.shape[1]
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb)
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
if use_cache:
present = (key, value)
else:
present = None
if self.use_logn_attn and not self.training:
if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
seq_start = key.size(1) - query.size(1)
seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
query = query * logn_tensor.expand_as(query)
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
attn_output, attn_weight = self._attn(
query, key, value, registered_causal_mask, attention_mask, head_mask
)
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
)
attn_output = self.c_proj(context_layer)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weight,)
return outputs
class QWenMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.w1 = nn.Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
self.w2 = nn.Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
ff_dim_in = config.intermediate_size // 2
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
def forward(self, hidden_states):
a1 = self.w1(hidden_states)
a2 = self.w2(hidden_states)
intermediate_parallel = a1 * F.silu(a2)
output = self.c_proj(intermediate_parallel)
return output
class QWenBlock(nn.Module):
def __init__(self, config):
super().__init__()
hidden_size = config.hidden_size
self.bf16 = config.bf16
self.ln_1 = RMSNorm(
hidden_size,
eps=config.layer_norm_epsilon,
)
self.attn = QWenAttention(config)
self.ln_2 = RMSNorm(
hidden_size,
eps=config.layer_norm_epsilon,
)
self.mlp = QWenMLP(config)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb: Optional[List[torch.Tensor]] = None,
registered_causal_mask: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
):
layernorm_output = self.ln_1(hidden_states)
attn_outputs = self.attn(
layernorm_output,
rotary_pos_emb,
registered_causal_mask=registered_causal_mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0]
outputs = attn_outputs[1:]
residual = hidden_states
layernorm_input = attn_output + residual
layernorm_output = self.ln_2(layernorm_input)
residual = layernorm_input
mlp_output = self.mlp(layernorm_output)
hidden_states = residual + mlp_output
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
return outputs
class QWenPreTrainedModel(PreTrainedModel):
config_class = QWenConfig
base_model_prefix = "transformer"
is_parallelizable = False
supports_gradient_checkpointing = True
_no_split_modules = ["QWenBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, RMSNorm):
module.weight.data.fill_(1.0)
for name, p in module.named_parameters():
if name == "c_proj.weight":
p.data.normal_(
mean=0.0,
std=(
self.config.initializer_range
/ math.sqrt(2 * self.config.num_hidden_layers)
),
)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, QWenModel):
module.gradient_checkpointing = value
class QWenModel(QWenPreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config):
super().__init__(config)
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
self.embed_dim = config.hidden_size
self.gradient_checkpointing = False
self.use_dynamic_ntk = config.use_dynamic_ntk
self.seq_length = config.seq_length
self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
self.drop = nn.Dropout(config.emb_dropout_prob)
if config.rotary_pct == 1.0:
self.rotary_ndims = None
else:
assert config.rotary_pct < 1
self.rotary_ndims = int(
config.kv_channels * config.rotary_pct
)
dim = (
self.rotary_ndims
if self.rotary_ndims is not None
else config.kv_channels
)
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
self.use_flash_attn = config.use_flash_attn
self.is_fp32 = not (config.bf16 or config.fp16)
self.registered_causal_mask = None
# if (
# self.use_flash_attn
# and flash_attn_unpadded_func is not None
# and not self.is_fp32
# ):
# self.registered_causal_mask = None
# else:
# max_positions = config.max_position_embeddings
# self.register_buffer(
# "registered_causal_mask",
# torch.tril(
# torch.ones((max_positions, max_positions), dtype=torch.bool)
# ).view(1, 1, max_positions, max_positions),
# persistent=False,
# )
self.h = nn.ModuleList(
[
QWenBlock(
config
)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = RMSNorm(
self.embed_dim,
eps=config.layer_norm_epsilon,
)
self.visual = VisionTransformer(**config.visual)
self.post_init()
def get_input_embeddings(self):
return self.wte
def set_input_embeddings(self, new_embeddings):
self.wte = new_embeddings
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
images=None
):
if images is None:
if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
assert (bos_pos[0] == eos_pos[0]).all()
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
images = []
for i, a, b in img_pos:
image = input_ids[i][a + 1 : b - 1].tolist()
image = image[ : image.index(self.config.visual['image_start_id'] + 2)]
images.append(bytes(image).decode('utf-8'))
images = self.visual.encode(images)
assert images.shape[0] == len(images)
else:
images = None
else:
bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
assert (bos_pos[0] == eos_pos[0]).all()
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
encoder_attention_mask = None
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_length
)
hidden_states = inputs_embeds
kv_seq_len = hidden_states.size()[1]
if past_key_values[0] is not None:
# past key values[0][0] shape: bs * seq_len * head_num * dim
kv_seq_len += past_key_values[0][0].shape[1]
if (
self.use_dynamic_ntk
and kv_seq_len == hidden_states.size()[1]
and not self.training
):
context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
ntk_alpha = 2 ** math.ceil(context_value) - 1
ntk_alpha = max(ntk_alpha, 1)
else:
ntk_alpha = self.rotary_emb._ntk_alpha_cached
rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
for idx in range(len(rotary_pos_emb)):
rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
hidden_states = self.drop(hidden_states)
if images is not None:
for idx, (i, a, b) in enumerate(img_pos):
hidden_states[i][a + 1 : b] = images[idx]
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
rotary_pos_emb,
self.registered_causal_mask,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
rotary_pos_emb=rotary_pos_emb,
registered_causal_mask=self.registered_causal_mask,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class QWenLMHeadModel(QWenPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
def __init__(self, config):
super().__init__(config)
assert (
config.bf16 + config.fp16 + config.fp32 <= 1
), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
if autoset_precision:
if SUPPORT_BF16:
logger.warn(
"The model is automatically converting to bf16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.bf16 = True
elif SUPPORT_FP16:
logger.warn(
"The model is automatically converting to fp16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.fp16 = True
else:
config.fp32 = True
if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
if config.fp32:
if SUPPORT_BF16:
logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
elif SUPPORT_FP16:
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
self.transformer = QWenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.bf16:
self.transformer.bfloat16()
self.lm_head.bfloat16()
if config.fp16:
self.transformer.half()
self.lm_head.half()
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None)
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
)
return model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
return tuple(
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
)
for layer_past in past_key_values
)
def chat(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
append_history: bool = True,
stream: Optional[bool] = _SENTINEL,
stop_words_ids: Optional[List[List[int]]] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
) -> Tuple[str, HistoryType]:
generation_config = generation_config if generation_config is not None else self.generation_config
assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
if max_window_size is None:
max_window_size = generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
history=history,
system=system,
max_window_size=max_window_size,
chat_format=generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
generation_config.chat_format, tokenizer
))
input_ids = torch.tensor([context_tokens]).to(self.device)
outputs = self.generate(
input_ids,
stop_words_ids=stop_words_ids,
return_dict_in_generate=False,
generation_config=generation_config,
**kwargs,
)
response = decode_tokens(
outputs[0],
tokenizer,
raw_text_len=len(raw_text),
context_length=len(context_tokens),
chat_format=generation_config.chat_format,
verbose=False,
errors='replace'
)
if append_history:
history.append((query, response))
return response, history
def chat_pretrain(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
append_history: bool = False,
stream: Optional[bool] = _SENTINEL,
stop_words_ids: Optional[List[List[int]]] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
) -> Tuple[str, HistoryType]:
generation_config = generation_config if generation_config is not None else self.generation_config
assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
if max_window_size is None:
max_window_size = generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
history=history,
system=system,
max_window_size=max_window_size,
chat_format=generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
generation_config.chat_format, tokenizer
))
input_ids = torch.tensor([context_tokens]).to(self.device)
outputs = self.generate(
input_ids,
stop_words_ids=stop_words_ids,
return_dict_in_generate=False,
generation_config=generation_config,
**kwargs,
)
response = decode_tokens(
outputs[0],
tokenizer,
raw_text_len=len(raw_text),
context_length=len(context_tokens),
chat_format=generation_config.chat_format,
verbose=False,
errors='replace'
)
if append_history:
history.append((query, response))
return response, history
def chat_stream(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
stop_words_ids: Optional[List[List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
) -> Generator[str, Any, None]:
generation_config = generation_config if generation_config is not None else self.generation_config
assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
if max_window_size is None:
max_window_size = generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
history=history,
system=system,
max_window_size=max_window_size,
chat_format=generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
generation_config.chat_format, tokenizer
))
if stop_words_ids is not None:
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=generation_config.eos_token_id,
)
if logits_processor is None:
logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
logits_processor.append(stop_words_logits_processor)
input_ids = torch.tensor([context_tokens]).to(self.device)
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
self.__class__.generate_stream = NewGenerationMixin.generate
self.__class__.sample_stream = NewGenerationMixin.sample_stream
stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
def stream_generator():
outputs = []
for token in self.generate_stream(
input_ids,
return_dict_in_generate=False,
generation_config=stream_config,
logits_processor=logits_processor,
seed=-1,
**kwargs):
outputs.append(token.item())
yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
return stream_generator()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
generation_config = generation_config if generation_config is not None else self.generation_config
# Process stop_words_ids.
stop_words_ids = kwargs.pop("stop_words_ids", None)
if stop_words_ids is None and generation_config is not None:
stop_words_ids = getattr(generation_config, "stop_words_ids", None)
if stop_words_ids is None:
stop_words_ids = getattr(generation_config, "stop_words_ids", None)
if stop_words_ids is not None:
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=generation_config.eos_token_id,
)
if logits_processor is None:
logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
logits_processor.append(stop_words_logits_processor)
return super().generate(
inputs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
**kwargs,
)
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
self.dim = dim
self.base = base
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
if importlib.util.find_spec("einops") is None:
raise RuntimeError("einops is required for Rotary Embedding")
self._rotary_pos_emb_cache = None
self._seq_len_cached = 0
self._ntk_alpha_cached = 1.0
def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
seqlen = max_seq_len + offset
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
self.inv_freq = 1.0 / (
base
** (
torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
/ self.dim
)
)
self._seq_len_cached = max(2 * seqlen, 16)
self._ntk_alpha_cached = ntk_alpha
seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
from einops import rearrange
emb = rearrange(emb, "n d -> 1 n 1 d")
cos, sin = emb.cos(), emb.sin()
self._rotary_pos_emb_cache = [cos, sin]
def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
cos, sin = self._rotary_pos_emb_cache
return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]
def _rotate_half(x):
from einops import rearrange
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs):
cos, sin = freqs
if apply_rotary_emb_func is not None and t.is_cuda:
t_ = t.float()
cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
return output
else:
rot_dim = freqs[0].shape[-1]
cos, sin = freqs
t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
t_ = t_.float()
t_pass_ = t_pass_.float()
t_ = (t_ * cos) + (_rotate_half(t_) * sin)
return torch.cat((t_, t_pass_), dim=-1).type_as(t)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
if rms_norm is not None and x.is_cuda:
return rms_norm(x, self.weight, self.eps)
else:
output = self._norm(x.float()).type_as(x)
return output * self.weight
This source diff could not be displayed because it is too large. You can view the blob instead.
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Generation support."""
from typing import Tuple, List, Union, Iterable
import numpy as np
import torch
import torch.nn.functional as F
from transformers import PreTrainedTokenizer
from transformers import logging
from transformers.generation import LogitsProcessor
logger = logging.get_logger(__name__)
# Types.
HistoryType = List[Tuple[str, str]]
TokensType = List[int]
BatchTokensType = List[List[int]]
def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
for tokens in batch:
context_length = len(tokens)
if context_length < seq_length:
tokens.extend([pad_id] * (seq_length - context_length))
return batch
def get_ltor_masks_and_position_ids(
data,
eod_token,
reset_position_ids,
reset_attention_mask,
eod_mask_loss,
):
"""Build masks and position id for left to right model."""
# Extract batch size and sequence length.
micro_batch_size, seq_length = data.size()
# Attention mask (lower triangular).
if reset_attention_mask:
att_mask_batch = micro_batch_size
else:
att_mask_batch = 1
attention_mask = torch.tril(
torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
).view(att_mask_batch, 1, seq_length, seq_length)
# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
# We need to clone as the ids will be modifed based on batch index.
if reset_position_ids:
position_ids = position_ids.clone()
if reset_position_ids or reset_attention_mask:
# Loop through the batches:
for b in range(micro_batch_size):
# Find indecies where EOD token is.
eod_index = position_ids[b, data[b] == eod_token]
# Detach indecies from positions if going to modify positions.
if reset_position_ids:
eod_index = eod_index.clone()
# Loop through EOD indecies:
prev_index = 0
for j in range(eod_index.size()[0]):
i = eod_index[j]
# Mask attention loss.
if reset_attention_mask:
attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
# Reset positions.
if reset_position_ids:
position_ids[b, (i + 1) :] -= i + 1 - prev_index
prev_index = i + 1
# Convert attention mask to binary:
attention_mask = attention_mask < 0.5
return attention_mask, loss_mask, position_ids
def get_batch(context_tokens: torch.LongTensor, eod_id: int):
"""Generate batch from context tokens."""
# Move to GPU.
tokens = context_tokens.contiguous().to(context_tokens.device)
# Get the attention mask and postition ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
eod_id,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False,
)
return tokens, attention_mask, position_ids
def get_stop_words_ids(chat_format, tokenizer):
if chat_format == "raw":
stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
elif chat_format == "chatml":
stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
return stop_words_ids
def make_context(
tokenizer: PreTrainedTokenizer,
query: str,
history: List[Tuple[str, str]] = None,
system: str = "",
max_window_size: int = 6144,
chat_format: str = "chatml",
):
if history is None:
history = []
if chat_format == "chatml":
im_start, im_end = "<|im_start|>", "<|im_end|>"
im_start_tokens = [tokenizer.im_start_id]
im_end_tokens = [tokenizer.im_end_id]
nl_tokens = tokenizer.encode("\n")
def _tokenize_str(role, content):
return f"{role}\n{content}", tokenizer.encode(
role, allowed_special=set(tokenizer.IMAGE_ST)
) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))
system_text, system_tokens_part = _tokenize_str("system", system)
system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
raw_text = ""
context_tokens = []
for turn_query, turn_response in reversed(history):
query_text, query_tokens_part = _tokenize_str("user", turn_query)
query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
if turn_response is not None:
response_text, response_tokens_part = _tokenize_str(
"assistant", turn_response
)
response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
prev_chat = (
f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
)
else:
next_context_tokens = nl_tokens + query_tokens + nl_tokens
prev_chat = f"\n{im_start}{query_text}{im_end}\n"
current_context_size = (
len(system_tokens) + len(next_context_tokens) + len(context_tokens)
)
if current_context_size < max_window_size:
context_tokens = next_context_tokens + context_tokens
raw_text = prev_chat + raw_text
else:
break
context_tokens = system_tokens + context_tokens
raw_text = f"{im_start}{system_text}{im_end}" + raw_text
context_tokens += (
nl_tokens
+ im_start_tokens
+ _tokenize_str("user", query)[1]
+ im_end_tokens
+ nl_tokens
+ im_start_tokens
+ tokenizer.encode("assistant")
+ nl_tokens
)
raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
elif chat_format == "raw":
raw_text = query
context_tokens = tokenizer.encode(raw_text)
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
return raw_text, context_tokens
def _decode_default(
tokens: List[int],
*,
stop_words: List[str],
eod_words: List[str],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace',
):
trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
if verbose:
print("\nRaw Generate: ", trim_decode_tokens)
end_reason = f"Gen length {len(tokens)}"
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
for eod_word in eod_words:
if eod_word in trim_decode_tokens:
end_reason = f"Gen {eod_word!r}"
trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print("\nEnd Reason:", end_reason)
print("\nGenerate: ", trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def _decode_chatml(
tokens: List[int],
*,
stop_words: List[str],
eod_token_ids: List[int],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
verbose: bool = False,
return_end_reason: bool = False,
errors: str='replace'
):
end_reason = f"Gen length {len(tokens)}"
eod_token_idx = context_length
for eod_token_idx in range(context_length, len(tokens)):
if tokens[eod_token_idx] in eod_token_ids:
end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
break
trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
if verbose:
print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
print("\nRaw Generate:", trim_decode_tokens)
print("\nEnd Reason:", end_reason)
for stop_word in stop_words:
trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
trim_decode_tokens = trim_decode_tokens.strip()
if verbose:
print("\nGenerate:", trim_decode_tokens)
if return_end_reason:
return trim_decode_tokens, end_reason
else:
return trim_decode_tokens
def decode_tokens(
tokens: Union[torch.LongTensor, TokensType],
tokenizer: PreTrainedTokenizer,
raw_text_len: int,
context_length: int,
chat_format: str,
verbose: bool = False,
return_end_reason: bool = False,
errors: str="replace",
) -> str:
if torch.is_tensor(tokens):
tokens = tokens.cpu().numpy().tolist()
if chat_format == "chatml":
return _decode_chatml(
tokens,
stop_words=[],
eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
context_length=context_length,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
elif chat_format == "raw":
return _decode_default(
tokens,
stop_words=["<|endoftext|>"],
eod_words=["<|endoftext|>"],
tokenizer=tokenizer,
raw_text_len=raw_text_len,
verbose=verbose,
return_end_reason=return_end_reason,
errors=errors,
)
else:
raise NotImplementedError(f"Unknown chat format {chat_format!r}")
class StopWordsLogitsProcessor(LogitsProcessor):
"""
:class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration.
Args:
stop_words_ids (:obj:`List[List[int]]`):
List of list of token ids of stop ids. In order to get the tokens of the words
that should not appear in the generated text, use :obj:`tokenizer(bad_word,
add_prefix_space=True).input_ids`.
eos_token_id (:obj:`int`):
The id of the `end-of-sequence` token.
"""
def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
raise ValueError(
f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
)
if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
raise ValueError(
f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
)
if any(
any(
(not isinstance(token_id, (int, np.integer)) or token_id < 0)
for token_id in stop_word_ids
)
for stop_word_ids in stop_words_ids
):
raise ValueError(
f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
)
self.stop_words_ids = list(
filter(
lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
)
)
self.eos_token_id = eos_token_id
for stop_token_seq in self.stop_words_ids:
assert (
len(stop_token_seq) > 0
), "Stop words token sequences {} cannot have an empty list".format(
stop_words_ids
)
def __call__(
self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
stopped_samples = self._calc_stopped_samples(input_ids)
for i, should_stop in enumerate(stopped_samples):
if should_stop:
scores[i, self.eos_token_id] = float(2**15)
return scores
def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
if len(tokens) == 0:
# if bad word tokens is just one token always ban it
return True
elif len(tokens) > len(prev_tokens):
# if bad word tokens are longer then prev input_ids they can't be equal
return False
elif prev_tokens[-len(tokens) :].tolist() == tokens:
# if tokens match
return True
else:
return False
def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
stopped_samples = []
for prev_input_ids_slice in prev_input_ids:
match = False
for stop_token_seq in self.stop_words_ids:
if self._tokens_match(prev_input_ids_slice, stop_token_seq):
# if tokens do not match continue
match = True
break
stopped_samples.append(match)
return stopped_samples
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
"""This function has been mostly taken from huggingface conversational
ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-
conversational-ai-with-transfer-learning-2d818ac26313"""
if top_k > 0:
# Remove all tokens with a probability less than the
# last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
# Cconvert to 1D
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token
# above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
for i in range(sorted_indices.size(0)):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
{
"pad_token": "<|endoftext|>"
}
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Tokenization classes for QWen."""
import base64
import logging
import os
import requests
import unicodedata
from typing import Collection, Dict, List, Set, Tuple, Union, Any, Callable, Optional
import tiktoken
import numpy as np
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw
from transformers import PreTrainedTokenizer, AddedToken
from transformers.utils import try_to_load_from_cache
import matplotlib.colors as mcolors
from matplotlib.font_manager import FontProperties
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken", "ttf": "SimSun.ttf"}
PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# as the default behavior is changed to allow special tokens in
# regular texts, the surface forms of special tokens need to be
# as different as possible to minimize the impact
EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
SPECIAL_TOKENS = (
ENDOFTEXT,
IMSTART,
IMEND,
) + EXTRAS
IMG_TOKEN_SPAN = 1280
def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
with open(tiktoken_bpe_file, "rb") as f:
contents = f.read()
return {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in contents.splitlines() if line)
}
def _list_find(
input_list: List[Any],
candidates: Tuple[Any],
start: int = 0,
):
for i in range(start, len(input_list)):
if input_list[i] in candidates:
return i
return -1
def _replace_closed_tag(
input_tokens: List[Any],
start_tags: Union[Any, Tuple[Any]],
end_tags: Union[Any, Tuple[Any]],
inclusive_replace_func: Callable,
exclusive_replace_func: Callable = lambda x: x,
):
if isinstance(start_tags, (str, int)):
start_tags = (start_tags,)
if isinstance(end_tags, (str, int)):
end_tags = (end_tags,)
assert len(start_tags) == len(end_tags)
output_tokens = []
end = 0
while True:
start = _list_find(input_tokens, start_tags, end)
if start == -1:
break
output_tokens.extend(exclusive_replace_func(input_tokens[end : start]))
tag_idx = start_tags.index(input_tokens[start])
end = _list_find(input_tokens, (end_tags[tag_idx],), start)
if end == -1:
raise ValueError("Unclosed image token")
output_tokens.extend(inclusive_replace_func(input_tokens[start : end + 1]))
end += 1
output_tokens.extend(exclusive_replace_func(input_tokens[end : ]))
return output_tokens
class QWenTokenizer(PreTrainedTokenizer):
"""QWen tokenizer."""
vocab_files_names = VOCAB_FILES_NAMES
def __init__(
self,
vocab_file,
errors="replace",
image_start_tag='<img>',
image_end_tag='</img>',
image_pad_tag='<imgpad>',
ref_start_tag='<ref>',
ref_end_tag='</ref>',
box_start_tag='<box>',
box_end_tag='</box>',
quad_start_tag='<quad>',
quad_end_tag='</quad>',
**kwargs,
):
super().__init__(**kwargs)
self.image_start_tag = image_start_tag
self.image_end_tag = image_end_tag
self.image_pad_tag = image_pad_tag
self.ref_start_tag = ref_start_tag
self.ref_end_tag = ref_end_tag
self.box_start_tag = box_start_tag
self.box_end_tag = box_end_tag
self.quad_start_tag = quad_start_tag
self.quad_end_tag = quad_end_tag
self.IMAGE_ST = (
ref_start_tag, ref_end_tag,
box_start_tag, box_end_tag,
quad_start_tag, quad_end_tag,
image_start_tag, image_end_tag,
image_pad_tag
)
self.errors = errors # how to handle errors in decoding
self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
self.special_tokens = {
token: index
for index, token in enumerate(
SPECIAL_TOKENS + self.IMAGE_ST, start=len(self.mergeable_ranks)
)
}
self.img_start_id = self.special_tokens[self.image_start_tag]
self.img_end_id = self.special_tokens[self.image_end_tag]
self.img_pad_id = self.special_tokens[self.image_pad_tag]
self.ref_start_id = self.special_tokens[self.ref_start_tag]
self.ref_end_id = self.special_tokens[self.ref_end_tag]
self.box_start_id = self.special_tokens[self.box_start_tag]
self.box_end_id = self.special_tokens[self.box_end_tag]
self.quad_start_id = self.special_tokens[self.quad_start_tag]
self.quad_end_id = self.special_tokens[self.quad_end_tag]
enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
mergeable_ranks=self.mergeable_ranks,
special_tokens=self.special_tokens,
)
assert (
len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
self.decoder = {
v: k for k, v in self.mergeable_ranks.items()
} # type: dict[int, bytes|str]
self.decoder.update({v: k for k, v in self.special_tokens.items()})
self.tokenizer = enc # type: tiktoken.Encoding
self.eod_id = self.tokenizer.eot_token
self.im_start_id = self.special_tokens[IMSTART]
self.im_end_id = self.special_tokens[IMEND]
def __getstate__(self):
# for pickle lovers
state = self.__dict__.copy()
del state['tokenizer']
return state
def __setstate__(self, state):
# tokenizer is not python native; don't pass it; rebuild it
self.__dict__.update(state)
enc = tiktoken.Encoding(
"Qwen",
pat_str=PAT_STR,
mergeable_ranks=self.mergeable_ranks,
special_tokens=self.special_tokens,
)
self.tokenizer = enc
def __len__(self) -> int:
return self.tokenizer.n_vocab
def get_vocab(self) -> Dict[bytes, int]:
return self.mergeable_ranks
def convert_tokens_to_ids(
self, tokens: Union[bytes, str, List[Union[bytes, str]]]
) -> List[int]:
ids = []
if isinstance(tokens, (str, bytes)):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.mergeable_ranks.get(tokens)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
else:
ids.append(self.mergeable_ranks.get(token))
return ids
def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
if not special_tokens and new_tokens:
raise ValueError('Adding regular tokens is not supported')
for token in new_tokens:
surface_form = token.content if isinstance(token, AddedToken) else token
if surface_form not in SPECIAL_TOKENS + self.IMAGE_ST:
raise ValueError('Adding unknown special tokens is not supported')
return 0
def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
"""
Save only the vocabulary of the tokenizer (vocabulary).
Returns:
`Tuple(str)`: Paths to the files saved.
"""
file_path = os.path.join(save_directory, "qwen.tiktoken")
with open(file_path, "w", encoding="utf8") as w:
for k, v in self.mergeable_ranks.items():
line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
w.write(line)
return (file_path,)
def tokenize(
self,
text: str,
allowed_special: Union[Set, str] = "all",
disallowed_special: Union[Collection, str] = (),
**kwargs,
) -> List[Union[bytes, str]]:
"""
Converts a string in a sequence of tokens.
Args:
text (`str`):
The sequence to be encoded.
allowed_special (`Literal["all"]` or `set`):
The surface forms of the tokens to be encoded as special tokens in regular texts.
Default to "all".
disallowed_special (`Literal["all"]` or `Collection`):
The surface forms of the tokens that should not be in regular texts and trigger errors.
Default to an empty tuple.
kwargs (additional keyword arguments, *optional*):
Will be passed to the underlying model specific encode method.
Returns:
`List[bytes|str]`: The list of tokens.
"""
tokens = []
text = unicodedata.normalize("NFC", text)
# this implementation takes a detour: text -> token id -> token surface forms
for t in self.tokenizer.encode(
text, allowed_special=allowed_special, disallowed_special=disallowed_special
):
tokens.append(self.decoder[t])
def _encode_imgurl(img_tokens):
assert img_tokens[0] == self.image_start_tag and img_tokens[-1] == self.image_end_tag
img_tokens = img_tokens[1:-1]
img_url = b''.join(img_tokens)
out_img_tokens = list(map(self.decoder.get, img_url))
if len(out_img_tokens) > IMG_TOKEN_SPAN:
raise ValueError("The content in {}..{} is too long".format(
self.image_start_tag, self.image_end_tag))
out_img_tokens.extend([self.image_pad_tag] * (IMG_TOKEN_SPAN - len(out_img_tokens)))
out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag]
return out_img_tokens
return _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl)
def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
"""
Converts a sequence of tokens in a single string.
"""
text = ""
temp = b""
for t in tokens:
if isinstance(t, str):
if temp:
text += temp.decode("utf-8", errors=self.errors)
temp = b""
text += t
elif isinstance(t, bytes):
temp += t
else:
raise TypeError("token should only be of type types or str")
if temp:
text += temp.decode("utf-8", errors=self.errors)
return text
@property
def vocab_size(self):
return self.tokenizer.n_vocab
def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
"""Converts an id to a token, special tokens included"""
if index in self.decoder:
return self.decoder[index]
raise ValueError("unknown ids")
def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
"""Converts a token to an id using the vocab, special tokens included"""
if token in self.special_tokens:
return self.special_tokens[token]
if token in self.mergeable_ranks:
return self.mergeable_ranks[token]
raise ValueError("unknown token")
def _tokenize(self, text: str, **kwargs):
"""
Converts a string in a sequence of tokens (string), using the tokenizer. Split in words for word-based
vocabulary or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces).
Do NOT take care of added tokens.
"""
raise NotImplementedError
def _decode(
self,
token_ids: Union[int, List[int]],
skip_special_tokens: bool = False,
errors: str = None,
**kwargs,
) -> str:
if isinstance(token_ids, int):
token_ids = [token_ids]
def _decode_imgurl(img_token_ids):
assert img_token_ids[0] == self.img_start_id and img_token_ids[-1] == self.img_end_id
img_token_ids = img_token_ids[1:-1]
img_token_ids = img_token_ids[ : img_token_ids.index(self.img_pad_id)]
img_url = bytes(img_token_ids).decode('utf-8')
return [self.img_start_id] + self.tokenizer.encode(img_url) + [self.img_end_id]
token_ids = _replace_closed_tag(token_ids, self.img_start_id, self.img_end_id, _decode_imgurl)
if skip_special_tokens:
token_ids = [i for i in token_ids if i < self.eod_id]
return self.tokenizer.decode(token_ids, errors=errors or self.errors)
def to_list_format(self, text: str):
text = unicodedata.normalize("NFC", text)
token_ids = self.tokenizer.encode(
text, allowed_special=set(self.IMAGE_ST + (ENDOFTEXT,)))
def _encode_vl_info(tokens):
if len(tokens) == 0:
return []
if tokens[0] == self.img_start_id and tokens[-1] == self.img_end_id:
key = 'image'
elif tokens[0] == self.ref_start_id and tokens[-1] == self.ref_end_id:
key = 'ref'
elif tokens[0] == self.box_start_id and tokens[-1] == self.box_end_id:
key = 'box'
elif tokens[0] == self.quad_start_id and tokens[-1] == self.quad_end_id:
key = 'quad'
else:
_tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
return [{'text': b''.join(map(_tobytes, map(self.decoder.get, tokens))).decode('utf-8')}]
_tobytes = lambda x: x.encode('utf-8') if isinstance(x, str) else x
val = b''.join(map(_tobytes, map(self.decoder.get, tokens[1:-1]))).decode('utf-8')
return [{key: val}]
return _replace_closed_tag(
token_ids,
(self.img_start_id, self.ref_start_id, self.box_start_id, self.quad_start_id),
(self.img_end_id, self.ref_end_id, self.box_end_id, self.quad_end_id),
_encode_vl_info,
_encode_vl_info,
)
def from_list_format(self, list_format: List[Dict]):
text = ''
num_images = 0
for ele in list_format:
if 'image' in ele:
num_images += 1
text += f'Picture {num_images}:'
text += self.image_start_tag + ele['image'] + self.image_end_tag
text += '\n'
elif 'text' in ele:
text += ele['text']
elif 'box' in ele:
if 'ref' in ele:
text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
for box in ele['box']:
text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
else:
raise ValueError("Unsupport element: " + str(ele))
return text
def _fetch_latest_picture(self, response, history):
if history is None:
history = []
_history = history + [(response, None)]
for q, r in _history[::-1]:
for ele in self.to_list_format(q)[::-1]:
if 'image' in ele:
return ele['image']
return None
def _fetch_all_box_with_ref(self, text):
list_format = self.to_list_format(text)
output = []
for i, ele in enumerate(list_format):
if 'box' in ele:
bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
assert len(bbox) == 4
output.append({'box': bbox})
if i > 0 and 'ref' in list_format[i-1]:
output[-1]['ref'] = list_format[i-1]['ref'].strip()
return output
def draw_bbox_on_latest_picture(
self,
response,
history=None,
) -> Optional[Image.Image]:
image = self._fetch_latest_picture(response, history)
if image is None:
return None
if image.startswith("http://") or image.startswith("https://"):
image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
h, w = image.height, image.width
else:
image = np.asarray(Image.open(image).convert("RGB"))
h, w = image.shape[0], image.shape[1]
visualizer = Visualizer(image)
boxes = self._fetch_all_box_with_ref(response)
if not boxes:
return None
color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()]) # init color
for box in boxes:
if 'ref' in box: # random new color for new refexps
color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()])
x1, y1, x2, y2 = box['box']
x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
visualizer.draw_box((x1, y1, x2, y2), alpha=1, edge_color=color)
if 'ref' in box:
visualizer.draw_text(box['ref'], (x1, y1), color=color, horizontal_alignment="left")
return visualizer.output
import colorsys
import logging
import math
import numpy as np
import matplotlib as mpl
import matplotlib.colors as mplc
import matplotlib.figure as mplfigure
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image
import random
logger = logging.getLogger(__name__)
class VisImage:
def __init__(self, img, scale=1.0):
self.img = img
self.scale = scale
self.width, self.height = img.shape[1], img.shape[0]
self._setup_figure(img)
def _setup_figure(self, img):
fig = mplfigure.Figure(frameon=False)
self.dpi = fig.get_dpi()
# add a small 1e-2 to avoid precision lost due to matplotlib's truncation
# (https://github.com/matplotlib/matplotlib/issues/15363)
fig.set_size_inches(
(self.width * self.scale + 1e-2) / self.dpi,
(self.height * self.scale + 1e-2) / self.dpi,
)
self.canvas = FigureCanvasAgg(fig)
# self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
ax.axis("off")
self.fig = fig
self.ax = ax
self.reset_image(img)
def reset_image(self, img):
img = img.astype("uint8")
self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
def save(self, filepath):
self.fig.savefig(filepath)
def get_image(self):
canvas = self.canvas
s, (width, height) = canvas.print_to_buffer()
buffer = np.frombuffer(s, dtype="uint8")
img_rgba = buffer.reshape(height, width, 4)
rgb, alpha = np.split(img_rgba, [3], axis=2)
return rgb.astype("uint8")
class Visualizer:
def __init__(self, img_rgb, metadata=None, scale=1.0):
self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
self.font_path = try_to_load_from_cache("Qwen/Qwen-VL-Chat", "SimSun.ttf")
self.output = VisImage(self.img, scale=scale)
self.cpu_device = torch.device("cpu")
# too small texts are useless, therefore clamp to 14
self._default_font_size = max(
np.sqrt(self.output.height * self.output.width) // 30, 15 // scale
)
def draw_text(
self,
text,
position,
*,
font_size=None,
color="g",
horizontal_alignment="center",
rotation=0,
):
if not font_size:
font_size = self._default_font_size
# since the text background is dark, we don't want the text to be dark
color = np.maximum(list(mplc.to_rgb(color)), 0.2)
color[np.argmax(color)] = max(0.8, np.max(color))
x, y = position
self.output.ax.text(
x,
y,
text,
size=font_size * self.output.scale,
fontproperties=FontProperties(fname=self.font_path),
bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
verticalalignment="top",
horizontalalignment=horizontal_alignment,
color=color,
zorder=10,
rotation=rotation,
)
return self.output
def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
x0, y0, x1, y1 = box_coord
width = x1 - x0
height = y1 - y0
linewidth = max(self._default_font_size / 4, 1)
self.output.ax.add_patch(
mpl.patches.Rectangle(
(x0, y0),
width,
height,
fill=False,
edgecolor=edge_color,
linewidth=linewidth * self.output.scale,
alpha=alpha,
linestyle=line_style,
)
)
return self.output
def get_output(self):
return self.output
{
"auto_map": {
"AutoTokenizer": [
"tokenization_qwen.QWenTokenizer",
null
]
},
"clean_up_tokenization_spaces": true,
"model_max_length": 2048,
"padding_side": "right",
"tokenizer_class": "QWenTokenizer"
}
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
import math
import requests
from io import BytesIO
from functools import partial
from PIL import Image
from typing import Callable, Optional, Sequence, Tuple, List
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import trunc_normal_
from torchvision import transforms
from torchvision.transforms import InterpolationMode
def reconstruct_matrix(windows):
temp =[]
for col in windows:
temp.append(torch.cat((col),dim=3))
all_img = torch.cat(temp,dim=2)
return all_img
def sliding_window(matrix, window_size, stride):
b,c,height, width = matrix.shape
window_rows = (height - window_size[0]) // stride + 1
window_cols = (width - window_size[1]) // stride + 1
windows = []
for i in range(window_rows):
windows_col = []
for j in range(window_cols):
window = matrix[:,:, i*stride:i*stride+window_size[0], j*stride:j*stride+window_size[1]]
windows_col.append(window)
windows.append(windows_col)
return windows
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class Resampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def __init__(
self,
grid_size,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=nn.LayerNorm
):
super().__init__()
self.num_queries = grid_size ** 2
self.embed_dim = embed_dim
self.num_heads = num_heads
self.pos_embed = nn.Parameter(
torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
).requires_grad_(False)
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
trunc_normal_(self.query, std=.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
else:
self.kv_proj = nn.Identity()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, attn_mask=None):
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
attn_mask=attn_mask)[0]
return out.permute(1, 0, 2)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
class Lora_Adapter(nn.Module):
def __init__(self,
d_model=None,
out_feat=None,
r=16,
dropout=0.05):
super().__init__()
self.d_model = d_model
self.out_feat = out_feat
self.r = r
self.lora_scale = nn.Parameter(torch.ones(1))
self.lora_a = nn.Linear(self.d_model, self.r,bias=False)
self.lora_b = nn.Linear(self.r, self.out_feat,bias=False)
self.lora_dropout = nn.Dropout(p=dropout)
with torch.no_grad():
nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_b.weight)
def forward(self, x ):
#residual = x if residual is None else residual
x = self.lora_dropout(x)
down = self.lora_a(x)
up = self.lora_b(down)
up = up * self.lora_scale
output = up
return output
class VisualAttention(nn.Module):
"""self-attention layer class.
Self-attention layer takes input with size [s, b, h]
and returns output of the same size.
"""
def __init__(self, embed_dim, num_heads,
bias=True, kdim=None, vdim=None,lora_repeat_num=4):
super(VisualAttention, self).__init__()
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.num_heads = num_heads
# Per attention head and per partition values.
assert embed_dim % num_heads == 0
self.hidden_size_per_attention_head = embed_dim // num_heads
self.num_attention_heads_per_partition = num_heads
self.hidden_size_per_partition = embed_dim
# Strided linear layer.
assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently'
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim)
self.in_proj_lora = []
for _ in range(lora_repeat_num):
self.in_proj_lora.append(Lora_Adapter(d_model=embed_dim,out_feat=3 * embed_dim))
self.in_proj_lora = nn.ModuleList(self.in_proj_lora)
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.out_proj_lora = []
for _ in range(lora_repeat_num):
self.out_proj_lora.append(Lora_Adapter(d_model=embed_dim,out_feat=embed_dim))
self.out_proj_lora = nn.ModuleList(self.out_proj_lora)
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
def forward(self, query, key, value, attn_mask = None,idx = None):
# query/key/value: [sq, b, h]
sq, b, _ = query.size()
assert query is key, 'Only Support Self-Attention Currently'
sk = sq
mixed_x_layer = self.in_proj(query)
if idx == None:
pass
else:
lora_res = self.in_proj_lora[idx](query)
mixed_x_layer += lora_res
# [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + \
(self.num_attention_heads_per_partition,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
query_layer, key_layer, value_layer = mixed_x_layer.split(
self.hidden_size_per_attention_head, dim=-1)
# [sq, b, np, hn] -> [sq, b * np, hn]
query_layer = query_layer.view(sq,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.view(sk,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
q_scaled = query_layer / self.norm_factor
if attn_mask is not None:
attention_probs = torch.baddbmm(attn_mask, q_scaled, key_layer.transpose(-2, -1))
else:
attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1))
attention_probs = attention_probs.softmax(dim=-1)
value_layer = value_layer.view(sk,
b * self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head).transpose(0, 1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer)
# change view [b, np, sq, hn]
context_layer = context_layer.view(b,
self.num_attention_heads_per_partition,
sq, self.hidden_size_per_attention_head)
# [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + \
(self.hidden_size_per_partition,)
context_layer = context_layer.view(*new_context_layer_shape)
output = self.out_proj(context_layer)
if idx == None:
pass
else:
lora_res = self.out_proj_lora[idx](context_layer)
output += lora_res
return output
class VisualAttentionBlock(nn.Module):
def __init__(
self,
d_model: int,
n_head: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
is_cross_attention: bool = False,
lora_repeat_num = 4,
):
super().__init__()
self.ln_1 = norm_layer(d_model)
if is_cross_attention:
self.ln_1_kv = norm_layer(d_model)
self.ln_2 = norm_layer(d_model)
mlp_width = int(d_model * mlp_ratio)
self.attn = VisualAttention(d_model, n_head,lora_repeat_num = lora_repeat_num)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, mlp_width)),
("gelu", act_layer()),
("c_proj", nn.Linear(mlp_width, d_model))
]))
self.mlp_lora = []
for _ in range(lora_repeat_num):
self.mlp_lora.append(Lora_Adapter(d_model=d_model,out_feat=d_model,r=32))
self.mlp_lora = nn.ModuleList(self.mlp_lora)
def attention(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
idx = None
):
k_x = k_x if k_x is not None else q_x
v_x = v_x if v_x is not None else q_x
attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None
return self.attn(q_x, k_x, v_x, attn_mask=attn_mask,idx=idx)
def forward(
self,
q_x: torch.Tensor,
k_x: Optional[torch.Tensor] = None,
v_x: Optional[torch.Tensor] = None,
attn_mask: Optional[torch.Tensor] = None,
idx = None
):
k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
x = q_x + self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask,idx=idx)
residual = x
x = x + self.mlp(self.ln_2(x))
if idx == None:
pass
else:
x += self.mlp_lora[idx](residual)
return x
class TransformerBlock(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
mlp_ratio: float = 4.0,
act_layer: Callable = nn.GELU,
norm_layer: Callable = nn.LayerNorm,
lora_repeat_num=4
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList([
VisualAttentionBlock(
width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer,lora_repeat_num=lora_repeat_num)
for _ in range(layers)
])
def get_cast_dtype(self) -> torch.dtype:
return self.resblocks[0].mlp.c_fc.weight.dtype
def get_cast_device(self) -> torch.device:
return self.resblocks[0].mlp.c_fc.weight.device
def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None,idx=None):
for r in self.resblocks:
x = r(x, attn_mask=attn_mask,idx=idx)
return x
class VisionTransformer(nn.Module):
def __init__(
self,
image_size: int,
patch_size: int,
width: int,
layers: int,
heads: int,
mlp_ratio: float,
n_queries: int = 256,
output_dim: int = 512,
lora_repeat_num: int = 4,
**kwargs
):
super().__init__()
image_height, image_width = self.image_size = (image_size, image_size)
patch_height, patch_width = self.patch_size = (patch_size, patch_size)
self.grid_size = (image_height // patch_height, image_width // patch_width)
self.output_dim = output_dim
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)
self.image_transform = transforms.Compose([
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC
),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
# class embeddings and positional embeddings
scale = width ** -0.5
self.positional_embedding = nn.Parameter(scale * torch.randn(256, width))
norm_layer = partial(nn.LayerNorm, eps=1e-6)
act_layer = nn.GELU
self.ln_pre = norm_layer(width)
self.transformer = TransformerBlock(
width,
layers,
heads,
mlp_ratio,
act_layer=act_layer,
norm_layer=norm_layer,
lora_repeat_num=lora_repeat_num
)
self.attn_pool = Resampler(
grid_size=int(math.sqrt(n_queries)),
embed_dim=output_dim,
num_heads=output_dim // 128,
kv_dim=width,
norm_layer=norm_layer,
)
self.ln_post = norm_layer(output_dim)
self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim))
def forward(self, x: torch.Tensor,idx=None):
x = x.to(
dtype=self.transformer.get_cast_dtype(),
device=self.transformer.get_cast_device(),
)
# to patches
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = x + get_abs_pos(self.positional_embedding, x.size(1))
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x,idx=idx)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.attn_pool(x)
x = self.ln_post(x)
x = x @ self.proj
return x
def encode(self, image_paths: List[str]):
images = []
for image_path in image_paths:
if image_path.startswith("http://") or image_path.startswith("https://"):
image = Image.open(requests.get(image_path, stream=True).raw)
else:
image = Image.open(image_path)
image = image.convert("RGB")
images.append(self.image_transform(image))
images = torch.stack(images, dim=0)
B,C,H,W = images.shape
windows = sliding_window(images,window_size=(448,448),stride=448)
images_448 = F.interpolate(images, size=(448,448), mode='bicubic')
return windows,images_448
if __name__ == "__main__":
pass
visual = VisionTransformer(
image_size= 896,
patch_size= 14,
width=1664,
layers = 48,
heads= 16,
mlp_ratio = 4.9231,
output_dim= 4096)
img = torch.randn(1,3,896,896)
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType
# Define LoRA Config
lora_config = LoraConfig(
r=16,
lora_alpha=32,
target_modules=["in_proj","out_proj","c_fc","c_proj"],
lora_dropout=0.05,
bias="none",
)
# prepare int-8 model for training
model = visual
# add LoRA adaptor
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
print(model)
print(visual)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment